
suppress privacy model, refer to "Deep Leakage from Gradients"

https://arxiv.org/abs/1906.08935
tags/v1.2.1 · itcomee committed 4 years ago · commit c13cd9391a
14 changed files with 1416 additions and 18 deletions

1. examples/privacy/README.md (+36 -15)
2. examples/privacy/sup_privacy/__init__.py (+0 -0)
3. examples/privacy/sup_privacy/sup_privacy.py (+154 -0)
4. examples/privacy/sup_privacy/sup_privacy_config.py (+32 -0)
5. mindarmour/privacy/sup_privacy/__init__.py (+27 -0)
6. mindarmour/privacy/sup_privacy/mask_monitor/__init__.py (+0 -0)
7. mindarmour/privacy/sup_privacy/mask_monitor/masker.py (+98 -0)
8. mindarmour/privacy/sup_privacy/sup_ctrl/__init__.py (+0 -0)
9. mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py (+640 -0)
10. mindarmour/privacy/sup_privacy/train/__init__.py (+0 -0)
11. mindarmour/privacy/sup_privacy/train/model.py (+325 -0)
12. tests/ut/python/privacy/__init__.py (+3 -3)
13. tests/ut/python/privacy/sup_privacy/__init__.py (+16 -0)
14. tests/ut/python/privacy/sup_privacy/test_model_train.py (+85 -0)

examples/privacy/README.md (+36 -15)

@@ -1,33 +1,54 @@
# Application demos of privacy stealing and privacy protection

## Introduction

Although machine learning can obtain a generic model based on training data, it has been proved that the trained
model may disclose the information of training data (such as through the membership inference attack).
Differential privacy training is an effective method proposed to overcome this problem, in which Gaussian noise is
added during training. There are mainly three parts to differential privacy (DP) training: the noise-generating
mechanism, the DP optimizer and the DP monitor. We have implemented a novel noise-generating mechanism: the
adaptive decay noise mechanism. The DP monitor is used to compute the privacy budget during training.
Suppress privacy training is a novel method to protect privacy, distinct from noise-addition methods
(such as DP): negligible model parameters are removed gradually during training to achieve a better balance
between accuracy and privacy.

## 1. Adaptive decay DP training

With the adaptive decay mechanism, the magnitude of the Gaussian noise decays as the training step grows, which
results in a stable convergence.

```sh
cd examples/privacy/diff_privacy
python lenet5_dp_ada_gaussian.py
```

## 2. Adaptive norm clip training

With the adaptive norm-clip mechanism, the clipping norm of the gradients is adjusted according to their norm
values, which controls the ratio between the noise and the original gradients.

```sh
cd examples/privacy/diff_privacy
python lenet5_dp.py
```

## 3. Membership inference evaluation

With this evaluation method, we can judge whether or not a sample belongs to the training dataset.

```sh
cd examples/privacy/membership_inference_attack
python train.py --data_path home_path_to_cifar100 --ckpt_path ./
python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt
```

## 4. Suppress privacy training

With the suppress privacy mechanism, the values of some trainable parameters (such as those of convolutional and
fully connected layers) are gradually set to zero as the training step grows, which can achieve a better balance
between accuracy and privacy.

```sh
cd examples/privacy/sup_privacy
python sup_privacy.py
```
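
For intuition, "suppression" here means zeroing the smallest-magnitude trainable weights until a target sparsity is reached. Below is a minimal, self-contained sketch of that idea (illustrative only, not part of this commit; the real implementation lives in `mindarmour/privacy/sup_privacy` below):

```python
import numpy as np

def suppress_weights(weights, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of a weight array."""
    k = int(weights.size * sparsity)
    if k == 0:
        return weights
    threshold = np.sort(np.abs(weights).flatten())[k - 1]  # magnitude cut-off
    mask = (np.abs(weights) > threshold).astype(weights.dtype)
    return weights * mask  # suppressed weights are exactly zero

w = np.random.randn(4, 4).astype(np.float32)
print(suppress_weights(w, 0.5))  # about half of the entries become zero
```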

examples/privacy/sup_privacy/__init__.py (+0 -0)


examples/privacy/sup_privacy/sup_privacy.py (+154 -0)

@@ -0,0 +1,154 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training example of suppress-based privacy.
"""
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision.utils import Inter
import mindspore.common.dtype as mstype

from examples.common.networks.lenet5.lenet5_net import LeNet5

from sup_privacy_config import mnist_cfg as cfg
from mindarmour.privacy.sup_privacy import SuppressModel
from mindarmour.privacy.sup_privacy import SuppressMasker
from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory
from mindarmour.privacy.sup_privacy import MaskLayerDes

from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
TAG = 'Lenet5_Suppress_train'


def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, samples=None, num_parallel_workers=1, sparse=True):
"""
Create a dataset for training or testing.
"""
# define dataset
ds1 = ds.MnistDataset(data_path, num_samples=samples)

# define operation parameters
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0

# define map operations
resize_op = CV.Resize((resize_height, resize_width),
interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)

# apply map operations on images
if not sparse:
one_hot_enco = C.OneHot(10)
ds1 = ds1.map(input_columns="label", operations=one_hot_enco, num_parallel_workers=num_parallel_workers)
type_cast_op = C.TypeCast(mstype.float32)
ds1 = ds1.map(input_columns="label", operations=type_cast_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=resize_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=rescale_op,
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
num_parallel_workers=num_parallel_workers)

# apply DatasetOps
buffer_size = 10000
ds1 = ds1.shuffle(buffer_size=buffer_size)
ds1 = ds1.batch(batch_size, drop_remainder=True)
ds1 = ds1.repeat(repeat_size)

return ds1

def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, mask_times=1000,
sparse_thd=0.90, sparse_start=0.0, masklayers=None):
"""
local train by suppress-based privacy
"""

networks_l5 = LeNet5()
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
end_epoch=epoch_size,
batch_num=int(samples / cfg.batch_size),
start_epoch=start_epoch,
mask_times=mask_times,
networks=networks_l5,
lr=lr,
sparse_end=sparse_thd,
sparse_start=sparse_start,
mask_layers=masklayers)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.SGD(networks_l5.trainable_params(), lr)
config_ck = CheckpointConfig(save_checkpoint_steps=int(samples / cfg.batch_size),
keep_checkpoint_max=10)

# Create the SuppressModel model for training.
model_instance = SuppressModel(network=networks_l5,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
model_instance.link_suppress_ctrl(suppress_ctrl_instance)

# Create a Masker for Suppress training. The function of the Masker is to
# enforce suppress operation while training.
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance)

mnist_path = "./MNIST_unzip/" #"../../MNIST_unzip/"
ds_train = generate_mnist_dataset(os.path.join(mnist_path, "train"),
batch_size=cfg.batch_size, repeat_size=1, samples=samples)

ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory="./trained_ckpt_file/",
config=config_ck)

print("============== Starting SUPP Training ==============")
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker],
dataset_sink_mode=False)

print("============== Starting SUPP Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(networks_l5, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'),
batch_size=cfg.batch_size)
acc = model_instance.eval(ds_eval, dataset_sink_mode=False)
print("============== SUPP Accuracy: %s ==============", acc)

if __name__ == "__main__":
# This configuration runs in PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target)

masklayers_lenet5 = [] # determine which layer should be masked

masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, True, 10))
masklayers_lenet5.append(MaskLayerDes("conv2.weight", False, True, 150))
masklayers_lenet5.append(MaskLayerDes("fc1.weight", True, False, -1))
masklayers_lenet5.append(MaskLayerDes("fc2.weight", True, False, -1))
masklayers_lenet5.append(MaskLayerDes("fc3.weight", True, False, 50))

# do suppress privacy training, with stronger privacy protection and better performance than differential privacy
mnist_suppress_train(10, 3, 0.10, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5)
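
For readability, here is a restatement of the mask-layer settings used above, with the positional arguments of MaskLayerDes (as defined in conctrl.py in this commit) spelled out:

```python
from mindarmour.privacy.sup_privacy import MaskLayerDes

# MaskLayerDes(layer_name, is_add_noise, is_lower_clip, min_num, upper_bound=1.20)
conv1_mask = MaskLayerDes("conv1.weight",  # parameter to suppress
                          False,           # is_add_noise: no noise forced onto large weights
                          True,            # is_lower_clip: clip surviving weights away from zero
                          10)              # min_num: always keep at least 10 weights unsuppressed
```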

examples/privacy/sup_privacy/sup_privacy_config.py (+32 -0)

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config settings, used in sup_privacy.py.
"""

from easydict import EasyDict as edict

mnist_cfg = edict({
'num_classes': 10, # the number of classes of model's output
'epoch_size': 1, # training epochs
'batch_size': 32, # batch size for training
'image_height': 32, # the height of training samples
'image_width': 32, # the width of training samples
'save_checkpoint_steps': 1875, # the interval, in steps, for saving model checkpoint files
'keep_checkpoint_max': 10, # the maximum number of checkpoint files to be kept
'device_target': 'Ascend', # device used
'data_path': './MNIST_unzip', # the path of the training and testing dataset
'dataset_sink_mode': False, # whether to deliver all training data to the device at once
})
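
mnist_cfg is a plain EasyDict, so its fields can be read either as attributes or as dictionary keys; a minimal sketch:

```python
from easydict import EasyDict as edict

cfg = edict({'batch_size': 32, 'epoch_size': 1})
assert cfg.batch_size == 32    # attribute-style access
assert cfg['epoch_size'] == 1  # dict-style access still works
```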

mindarmour/privacy/sup_privacy/__init__.py (+27 -0)

@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides the suppress privacy feature to protect user privacy.
"""
from .mask_monitor.masker import SuppressMasker
from .train.model import SuppressModel
from .sup_ctrl.conctrl import SuppressPrivacyFactory
from .sup_ctrl.conctrl import SuppressCtrl
from .sup_ctrl.conctrl import MaskLayerDes

__all__ = ['SuppressMasker',
'SuppressModel',
'SuppressPrivacyFactory',
'SuppressCtrl',
'MaskLayerDes']
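
With this module in place, the public suppress privacy API can be imported directly from the package, as the example scripts do:

```python
from mindarmour.privacy.sup_privacy import (
    MaskLayerDes, SuppressCtrl, SuppressMasker, SuppressModel, SuppressPrivacyFactory)
```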

mindarmour/privacy/sup_privacy/mask_monitor/__init__.py (+0 -0)


mindarmour/privacy/sup_privacy/mask_monitor/masker.py (+98 -0)

@@ -0,0 +1,98 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Masker module of suppress-based privacy.
"""
from mindspore.train.callback import Callback
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_param_type
from mindarmour.privacy.sup_privacy.train.model import SuppressModel
from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl

LOGGER = LogUtil.get_instance()
TAG = 'suppress masker'

class SuppressMasker(Callback):
"""
Args:
model (SuppressModel): SuppressModel instance.
suppress_ctrl (SuppressCtrl): SuppressCtrl instance.

Examples:
networks_l5 = LeNet5()
masklayers = []
masklayers.append(MaskLayerDes("conv1.weight", False, True, 10))
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
end_epoch=10,
batch_num=(int)(10000/cfg.batch_size),
start_epoch=3,
mask_times=100,
networks=networks_l5,
lr=lr,
sparse_end=0.90,
sparse_start=0.0,
mask_layers=masklayers)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0)
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10)
model_instance = SuppressModel(network=networks_l5,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
model_instance.link_suppress_ctrl(suppress_ctrl_instance)
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance)
ds_train = generate_mnist_dataset("./MNIST_unzip/train",
batch_size=cfg.batch_size, repeat_size=1, samples=samples)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory="./trained_ckpt_file/",
config=config_ck)
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker],
dataset_sink_mode=False)
"""

def __init__(self, model=None, suppress_ctrl=None):

super(SuppressMasker, self).__init__()

self._model = check_param_type('model', model, SuppressModel)
self._suppress_ctrl = check_param_type('suppress_ctrl', suppress_ctrl, SuppressCtrl)

def step_end(self, run_context):
"""
Update mask matrix tensor used for SuppressModel instance.

Args:
run_context (RunContext): Include some information of the model.
"""
cb_params = run_context.original_args()
cur_step = cb_params.cur_step_num
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

if self._suppress_ctrl is not None and self._model.network_end is not None:
self._suppress_ctrl.update_status(cb_params.cur_epoch_num, cur_step, cur_step_in_epoch)

if not self._suppress_ctrl.mask_initialized:
raise ValueError("Not initialize network!")
if self._suppress_ctrl.to_do_mask:
self._suppress_ctrl.update_mask(self._suppress_ctrl.networks, cur_step)
LOGGER.info(TAG, "suppress update")
elif not self._suppress_ctrl.to_do_mask and self._suppress_ctrl.mask_started:
self._suppress_ctrl.reset_zeros()
if cur_step_in_epoch % 100 == 1:
self._suppress_ctrl.calc_theoretical_sparse_for_conv()
_, _, _ = self._suppress_ctrl.calc_actual_sparse_for_conv(
self._suppress_ctrl.networks)
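
To make the masking cadence concrete: update_status fires a suppress operation roughly every mask_all_steps / mask_times steps. A back-of-the-envelope sketch using values assumed from the example script above (masking epochs 3 to 10, 1875 batches per epoch, mask_times=1000):

```python
# Values assumed from examples/privacy/sup_privacy/sup_privacy.py
start_epoch, end_epoch, batch_num, mask_times = 3, 10, 1875, 1000

mask_all_steps = (end_epoch - start_epoch + 1) * batch_num  # 15000 steps in the masking window
mask_step_interval = mask_all_steps / mask_times            # one suppress operation every 15 steps
print(mask_all_steps, mask_step_interval)                   # 15000 15.0
```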

mindarmour/privacy/sup_privacy/sup_ctrl/__init__.py (+0 -0)


mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py (+640 -0)

@@ -0,0 +1,640 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
control function of suppress-based privacy.
"""
import math
import numpy as np

from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.nn import Cell

from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_int_positive, check_value_positive, \
check_value_non_negative, check_param_type
LOGGER = LogUtil.get_instance()
TAG = 'Suppression training.'


class SuppressPrivacyFactory:
""" Factory class of SuppressCtrl mechanisms"""
def __init__(self):
pass

@staticmethod
def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None,
lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None):
"""
Args:
policy (str): Training policy for suppress privacy training. "local_train" means local training.
end_epoch (int): The last epoch of the suppress operations, 0 < start_epoch <= end_epoch <= 100.
batch_num (int): The number of batches in an epoch, should be equal to num_samples/batch_size.
start_epoch (int): The first epoch of the suppress operations, 0 < start_epoch <= end_epoch <= 100.
mask_times (int): The number of suppress operations.
networks (Cell): The training network.
lr (float): Learning rate.
sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0.
sparse_start (float): The sparsity to start from, 0.0 <= sparse_start < sparse_end < 1.0.
mask_layers (list): Descriptions of the training network layers that need to be suppressed.

Returns:
SuppressCtrl, an instance of the suppress privacy mechanism.

Examples:
networks_l5 = LeNet5()
masklayers = []
masklayers.append(MaskLayerDes("conv1.weight", False, True, 10))
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
end_epoch=10,
batch_num=(int)(10000/cfg.batch_size),
start_epoch=3,
mask_times=100,
networks=networks_l5,
lr=lr,
sparse_end=0.90,
sparse_start=0.0,
mask_layers=masklayers)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0)
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10)
model_instance = SuppressModel(network=networks_l5,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
model_instance.link_suppress_ctrl(suppress_ctrl_instance)
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance)
ds_train = generate_mnist_dataset("./MNIST_unzip/train",
batch_size=cfg.batch_size, repeat_size=1, samples=samples)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory="./trained_ckpt_file/",
config=config_ck)
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker],
dataset_sink_mode=False)
"""
if policy == "local_train":
return SuppressCtrl(networks, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end,
sparse_start, mask_layers)
msg = "Only local training is supported now, federal training will be supported " \
"in the future. But got {}.".format(policy)
LOGGER.error(TAG, msg)
raise ValueError(msg)

class SuppressCtrl(Cell):
"""
Args:
networks (Cell): The training network.
end_epoch (int): The last epoch of the suppress operations.
batch_num (int): The number of gradient operations in an epoch.
mask_start_epoch (int): The first epoch of the suppress operations.
mask_times (int): The number of suppress operations.
lr (Union[float, int]): Learning rate.
sparse_end (Union[float, int]): The sparsity to reach.
sparse_start (float): The sparsity to start from.
mask_layers (list): Descriptions of the layers that need to be suppressed.
"""
def __init__(self, networks, end_epoch, batch_num, mask_start_epoch=3, mask_times=500, lr=0.05,
sparse_end=0.60,
sparse_start=0.0,
mask_layers=None):
super(SuppressCtrl, self).__init__()
self.networks = check_param_type('networks', networks, Cell)
self.mask_end_epoch = check_int_positive('end_epoch', end_epoch)
self.batch_num = check_int_positive('batch_num', batch_num)
self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch)
self.mask_times = check_int_positive('mask_times', mask_times)
self.lr = check_value_positive('lr', lr)
self.sparse_end = check_value_non_negative('sparse_end', sparse_end)
self.sparse_start = check_value_non_negative('sparse_start', sparse_start)
self.mask_layers = check_param_type('mask_layers', mask_layers, list)

self.weight_lower_bound = 0.005 # all kept network weights will be larger than this value
self.sparse_vibra = 0.02 # the sparsity may vary within this range
self.sparse_valid_max_weight = 0.20 # if the max network weight is less than this value, the suppress operation stops temporarily
self.add_noise_thd = 0.50 # if a network weight is larger than this value, noise is forced onto it
self.noise_volume = 0.01 # noise volume
self.base_ground_thd = 0.0000001 # if a network weight is less than this value, it is considered to be 0
self.model = None # SuppressModel instance
self.grads_mask_list = [] # list of grad-mask matrix tensors
self.de_weight_mask_list = [] # list of weight-mask matrix tensors
self.to_do_mask = False # flag meaning a suppress operation should be executed immediately
self.mask_started = False # flag meaning the suppress operations have started
self.mask_start_step = 0 # the step at which the suppress operations actually started
self.mask_prev_step = 0 # the step at which the previous suppress operation was done
self.cur_sparse = 0.0 # the sparsity the current suppress operation aims to reach
self.mask_all_steps = (self.mask_end_epoch - mask_start_epoch + 1)*batch_num # the number of steps covered by all suppress operations
self.mask_step_interval = self.mask_all_steps/mask_times # the number of steps between two suppress operations
self.mask_initialized = False # flag meaning the initialization is done

if self.mask_start_epoch > self.mask_end_epoch:
msg = "start_epoch error: {}".format(self.mask_start_epoch)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.mask_end_epoch > 100:
msg = "end_epoch error: {}".format(self.mask_end_epoch)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.mask_step_interval < 0:
msg = "step_interval error: {}".format(self.mask_step_interval)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.sparse_end > 1.00 or self.sparse_end <= 0:
msg = "sparse_end error: {}".format(self.sparse_end)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.sparse_start >= self.sparse_end:
msg = "sparse_start error: {}".format(self.sparse_start)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if mask_layers is not None:
mask_layer_id = 0
for one_mask_layer in mask_layers:
if not isinstance(one_mask_layer, MaskLayerDes):
msg = "mask_layer instance error!"
LOGGER.error(TAG, msg)
raise ValueError(msg)
layer_name = one_mask_layer.layer_name
mask_layer_id2 = 0
for one_mask_layer_2 in mask_layers:
if mask_layer_id != mask_layer_id2 and layer_name in one_mask_layer_2.layer_name:
msg = "mask_layers repeat item : {} in {} and {}".format(layer_name,
mask_layer_id,
mask_layer_id2)
LOGGER.error(TAG, msg)
raise ValueError(msg)
mask_layer_id2 = mask_layer_id2 + 1
mask_layer_id = mask_layer_id + 1

if networks is not None:
m = 0
for layer in networks.get_parameters(expand=True):
one_mask_layer = None
if mask_layers is not None:
one_mask_layer = get_one_mask_layer(mask_layers, layer.name)
if one_mask_layer is not None and not one_mask_layer.inited:
one_mask_layer.inited = True
shape = P.Shape()(layer)
mul_mask_array = np.ones(shape, dtype=np.float32)
grad_mask_cell = GradMaskInCell(mul_mask_array,
one_mask_layer.is_add_noise,
one_mask_layer.is_lower_clip,
one_mask_layer.min_num,
one_mask_layer.upper_bound)
grad_mask_cell.mask_able = True
self.grads_mask_list.append(grad_mask_cell)
add_mask_array = np.zeros(shape, dtype=np.float32)

de_weight_cell = DeWeightInCell(add_mask_array)
de_weight_cell.mask_able = True
self.de_weight_mask_list.append(de_weight_cell)
msg = "do mask {}, {}".format(m, one_mask_layer.layer_name)
LOGGER.info(TAG, msg)
elif one_mask_layer is not None and one_mask_layer.inited:
msg = "repeated match masked setting {}=>{}.".format(one_mask_layer.layer_name, layer.name)
LOGGER.error(TAG, msg)
raise ValueError(msg)
else:
shape = np.shape([1])
mul_mask_array = np.ones(shape, dtype=np.float32)
grad_mask_cell = GradMaskInCell(mul_mask_array, False, False, -1)
grad_mask_cell.mask_able = False

self.grads_mask_list.append(grad_mask_cell)
add_mask_array = np.zeros(shape, dtype=np.float32)
de_weight_cell = DeWeightInCell(add_mask_array)
de_weight_cell.mask_able = False
self.de_weight_mask_list.append(de_weight_cell)
m += 1
self.mask_initialized = True
msg = "init SuppressCtrl by networks"
LOGGER.info(TAG, msg)
msg = "complete init mask for lenet5.step_interval: {}".format(self.mask_step_interval)
LOGGER.info(TAG, msg)

for one_mask_layer in mask_layers:
if not one_mask_layer.inited:
msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name)
LOGGER.error(TAG, msg)
raise ValueError(msg)

def update_status(self, cur_epoch, cur_step, cur_step_in_epoch):
"""
Update the suppress operation status.

Args:
cur_epoch (int): Current epoch of the whole training process.
cur_step (int): Current step of the whole training process.
cur_step_in_epoch (int): Current step of the current epoch.
"""
if not self.mask_initialized:
self.mask_started = False
elif (self.mask_start_epoch <= cur_epoch <= self.mask_end_epoch) or self.mask_started:
if not self.mask_started:
self.mask_started = True
self.mask_start_step = cur_step
if cur_step >= (self.mask_prev_step + self.mask_step_interval):
self.mask_prev_step = cur_step
self.to_do_mask = True
# execute the last suppression operation
elif cur_epoch == self.mask_end_epoch and cur_step_in_epoch == self.batch_num-2:
self.mask_prev_step = cur_step
self.to_do_mask = True
else:
self.to_do_mask = False
else:
self.to_do_mask = False
self.mask_started = False

def update_mask(self, networks, cur_step):
"""
Update add mask arrays and multiply mask arrays of network layers.

Args:
networks (Cell): The training network.
cur_step (int): Current step of the whole training process.
"""
if self.sparse_end <= 0.0:
return

self.cur_sparse = self.sparse_end +\
(self.sparse_start - self.sparse_end)*\
math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3)
m = 0
for layer in networks.get_parameters(expand=True):
if self.grads_mask_list[m].mask_able:
weight_array = layer.data.asnumpy()
weight_avg = np.mean(weight_array)
weight_array_flat = weight_array.flatten()
weight_array_flat_abs = np.abs(weight_array_flat)
weight_abs_avg = np.mean(weight_array_flat_abs)
weight_array_flat_abs.sort()
len_array = weight_array.size
weight_abs_max = np.max(weight_array_flat_abs)
if m == 0 and weight_abs_max < self.sparse_valid_max_weight:
msg = "give up this masking .."
LOGGER.info(TAG, msg)
return
if self.grads_mask_list[m].min_num > 0:
sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs,
self.cur_sparse, m)
else:
actual_stop_pos = int(len_array * self.cur_sparse)
sparse_weight_thd = weight_array_flat_abs[actual_stop_pos]

self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, m)

msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format(
layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd,
weight_abs_max, weight_avg, weight_abs_avg)
LOGGER.info(TAG, msg)
m = m + 1

def update_mask_layer(self, weight_array_flat, sparse_weight_thd, sparse_stop_pos, weight_abs_max, layer_index):
"""
Update add mask arrays and multiply mask arrays of one single layer.

Args:
weight_array_flat (numpy.ndarray): The flattened weight array of the layer's parameters.
sparse_weight_thd (float): The weight threshold of sparse operation.
sparse_stop_pos (int): The maximum number of elements to be suppressed.
weight_abs_max (float): The maximum absolute value of weights.
layer_index (int): The index of target layer.
"""
grad_mask_cell = self.grads_mask_list[layer_index]
mul_mask_array_flat = grad_mask_cell.mul_mask_array_flat
de_weight_cell = self.de_weight_mask_list[layer_index]
add_mask_array_flat = de_weight_cell.add_mask_array_flat
min_num = grad_mask_cell.min_num
is_add_noise = grad_mask_cell.is_add_noise
is_lower_clip = grad_mask_cell.is_lower_clip
upper_bound = grad_mask_cell.upper_bound

if not self.grads_mask_list[layer_index].mask_able:
return
m = 0
n = 0
p = 0
q = 0
# add noise on weights if not masking or clipping.
weight_noise_bound = min(self.add_noise_thd, max(self.noise_volume*10, weight_abs_max*0.75))
for i in range(0, weight_array_flat.size):
if abs(weight_array_flat[i]) <= sparse_weight_thd:
if m < weight_array_flat.size - min_num and m < sparse_stop_pos:
# to mask
mul_mask_array_flat[i] = 0.0
add_mask_array_flat[i] = weight_array_flat[i] / self.lr
m = m + 1
else:
# not mask
if weight_array_flat[i] > 0.0:
add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr
else:
add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr
p = p + 1
elif is_lower_clip and abs(weight_array_flat[i]) <= \
self.weight_lower_bound and sparse_weight_thd > self.weight_lower_bound*0.5:
# not mask
mul_mask_array_flat[i] = 1.0
if weight_array_flat[i] > 0.0:
add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr
else:
add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr
p = p + 1
elif abs(weight_array_flat[i]) > upper_bound:
mul_mask_array_flat[i] = 1.0
if weight_array_flat[i] > 0.0:
add_mask_array_flat[i] = (weight_array_flat[i] - upper_bound) / self.lr
else:
add_mask_array_flat[i] = (weight_array_flat[i] + upper_bound) / self.lr
n = n + 1
else:
# not mask
mul_mask_array_flat[i] = 1.0
if is_add_noise and abs(weight_array_flat[i]) > weight_noise_bound > 0.0:
# add noise
add_mask_array_flat[i] = np.random.uniform(-self.noise_volume, self.noise_volume) / self.lr
q = q + 1
else:
add_mask_array_flat[i] = 0.0

grad_mask_cell.update()
de_weight_cell.update()
msg = "Dimension of mask tensor is {}D, which located in the {}-th layer of the network. \n The number of " \
"suppressed elements, max-clip elements, min-clip elements and noised elements are {}, {}, {}, {}"\
.format(len(grad_mask_cell.mul_mask_array_shape), layer_index, m, n, p, q)
LOGGER.info(TAG, msg)

def calc_sparse_thd(self, array_flat, sparse_value, layer_index):
"""
Calculate the suppression threshold of one weight array.

Args:
array_flat (numpy.ndarray): The flattened weight array, sorted by absolute value.
sparse_value (float): The target sparse value of the weight array.
layer_index (int): The index of the target layer.

Returns:
- float, the sparse threshold of this array.

- int, the position at which suppression stops.

- int, the farther stop position that allows for sparsity variation.
"""
size = len(array_flat)
sparse_max_thd = 1.0 - min(self.grads_mask_list[layer_index].min_num, size) / size
pos = int(size*min(sparse_max_thd, sparse_value))
thd = array_flat[pos]
farther_stop_pos = int(size*min(sparse_max_thd, max(0, sparse_value + self.sparse_vibra / 2.0)))
return thd, pos, farther_stop_pos

def reset_zeros(self):
"""
Set add mask arrays to be zero.
"""
for de_weight_cell in self.de_weight_mask_list:
de_weight_cell.reset_zeros()

def calc_theoretical_sparse_for_conv(self):
"""
Compute the theoretical sparsity of the mask matrices for the conv1 and conv2 layers.
"""
array_mul_mask_flat_conv1 = self.grads_mask_list[0].mul_mask_array_flat
array_mul_mask_flat_conv2 = self.grads_mask_list[1].mul_mask_array_flat
sparse = 0.0
sparse_value_1 = 0.0
sparse_value_2 = 0.0
full = 0.0
full_conv1 = 0.0
full_conv2 = 0.0
for i in range(0, array_mul_mask_flat_conv1.size):
full += 1.0
full_conv1 += 1.0
if array_mul_mask_flat_conv1[i] <= 0.0:
sparse += 1.0
sparse_value_1 += 1.0

for i in range(0, array_mul_mask_flat_conv2.size):
full = full + 1.0
full_conv2 = full_conv2 + 1.0
if array_mul_mask_flat_conv2[i] <= 0.0:
sparse = sparse + 1.0
sparse_value_2 += 1.0
sparse = sparse/full
sparse_value_1 = sparse_value_1/full_conv1
sparse_value_2 = sparse_value_2/full_conv2
msg = "conv sparse mask={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2)
LOGGER.info(TAG, msg)
return sparse, sparse_value_1, sparse_value_2

def calc_actual_sparse_for_conv(self, networks):
"""
Compute the actual sparsity of the network's conv1 and conv2 layers.

Args:
networks (Cell): The training network.
"""
sparse = 0.0
sparse_value_1 = 0.0
sparse_value_2 = 0.0
full = 0.0
full_conv1 = 0.0
full_conv2 = 0.0

array_cur_conv1 = np.ones(np.shape([1]), dtype=np.float32)
array_cur_conv2 = np.ones(np.shape([1]), dtype=np.float32)
for layer in networks.get_parameters(expand=True):
if "conv1.weight" in layer.name:
array_cur_conv1 = layer.data.asnumpy()
if "conv2.weight" in layer.name:
array_cur_conv2 = layer.data.asnumpy()

array_mul_mask_flat_conv1 = array_cur_conv1.flatten()
array_mul_mask_flat_conv2 = array_cur_conv2.flatten()

for i in range(0, array_mul_mask_flat_conv1.size):
full += 1.0
full_conv1 += 1.0
if abs(array_mul_mask_flat_conv1[i]) <= self.base_ground_thd:
sparse += 1.0
sparse_value_1 += 1.0

for i in range(0, array_mul_mask_flat_conv2.size):
full = full + 1.0
full_conv2 = full_conv2 + 1.0
if abs(array_mul_mask_flat_conv2[i]) <= self.base_ground_thd:
sparse = sparse + 1.0
sparse_value_2 += 1.0

sparse = sparse / full
sparse_value_1 = sparse_value_1 / full_conv1
sparse_value_2 = sparse_value_2 / full_conv2
msg = "conv sparse fact={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2)
LOGGER.info(TAG, msg)
return sparse, sparse_value_1, sparse_value_2

def calc_actual_sparse_for_fc1(self, networks):
"""
Compute the actual sparsity of the fc1.weight layer.

Args:
networks (Cell): The training network.
"""
self.calc_actual_sparse_for_layer(networks, "fc1.weight")

def calc_actual_sparse_for_layer(self, networks, layer_name):
"""
Compute the actual sparsity of one network layer.

Args:
networks (Cell): The training network.
layer_name (str): The name of target layer.
"""
check_param_type('networks', networks, Cell)
check_param_type('layer_name', layer_name, str)

sparse = 0.0
full = 0.0

array_cur = None
for layer in networks.get_parameters(expand=True):
if layer_name in layer.name:
array_cur = layer.data.asnumpy()

if array_cur is None:
msg = "no such layer to calc sparse: {} ".format(layer_name)
LOGGER.info(TAG, msg)
return

array_cur_flat = array_cur.flatten()

for i in range(0, array_cur_flat.size):
full += 1.0
if abs(array_cur_flat[i]) <= self.base_ground_thd:
sparse += 1.0

sparse = sparse / full
msg = "{} sparse fact={} ".format(layer_name, sparse)
LOGGER.info(TAG, msg)

def get_one_mask_layer(mask_layers, layer_name):
"""
Returns the layer definitions that need to be suppressed.

Args:
mask_layers (list): The layers that need to be suppressed.
layer_name (str): The name of target layer.

Returns:
Union[MaskLayerDes, None], the layer definitions that need to be suppressed.
"""
for each_mask_layer in mask_layers:
if each_mask_layer.layer_name in layer_name:
return each_mask_layer
return None

class MaskLayerDes:
"""
Describe a layer that needs to be suppressed.

Args:
layer_name (str): Layer name. The name of a layer can be obtained as follows:
for layer in networks.get_parameters(expand=True):
    if layer.name == "conv": ...
is_add_noise (bool): If True, noise can be added to the weights of this layer.
If False, noise can not be added to the weights of this layer.
is_lower_clip (bool): If True, the weights of this layer would be clipped to be greater than a lower bound value.
If False, the weights of this layer won't be clipped.
min_num (int): The number of weights left unsuppressed; if it is not greater than 0, this protection is disabled.
upper_bound (float): The maximum absolute value of a weight in this layer, default: 1.20.
"""
def __init__(self, layer_name, is_add_noise, is_lower_clip, min_num, upper_bound=1.20):
self.layer_name = check_param_type('layer_name', layer_name, str)
self.is_add_noise = check_param_type('is_add_noise', is_add_noise, bool)
self.is_lower_clip = check_param_type('is_lower_clip', is_lower_clip, bool)
self.min_num = check_param_type('min_num', min_num, int)
self.upper_bound = check_value_positive('upper_bound', upper_bound)
self.inited = False

class GradMaskInCell(Cell):
"""
Define the mask matrix for gradients masking.

Args:
array (numpy.ndarray): The mask array.
is_add_noise (bool): If True, noise can be added to the weights of this layer.
If False, noise can not be added to the weights of this layer.
is_lower_clip (bool): If True, the weights of this layer would be clipped to be greater than a lower bound value.
If False, the weights of this layer won't be clipped.
min_num (int): The number of weights left unsuppressed; if it is not greater than 0, this protection is disabled.
upper_bound (float): The maximum absolute value of a weight in this layer, default: 1.20.
"""
def __init__(self, array, is_add_noise, is_lower_clip, min_num, upper_bound=1.20):
super(GradMaskInCell, self).__init__()
self.mul_mask_array_shape = array.shape
mul_mask_array = array.copy()
self.mul_mask_array_flat = mul_mask_array.flatten()
self.mul_mask_tensor = Tensor(array, mstype.float32)
self.mask_able = False
self.is_add_noise = is_add_noise
self.is_lower_clip = is_lower_clip
self.min_num = min_num
self.upper_bound = check_value_positive('upper_bound', upper_bound)

def construct(self):
"""
Return the mask matrix for optimization.
"""
return self.mask_able, self.mul_mask_tensor

def update(self):
"""
Update the mask tensor.
"""
self.mul_mask_tensor = Tensor(self.mul_mask_array_flat.reshape(self.mul_mask_array_shape), mstype.float32)

class DeWeightInCell(Cell):
"""
Define the mask matrix for de-weight masking.

Args:
array (numpy.ndarray): The mask array.
"""
def __init__(self, array):
super(DeWeightInCell, self).__init__()
self.add_mask_array_shape = array.shape
add_mask_array = array.copy()
self.add_mask_array_flat = add_mask_array.flatten()
self.add_mask_tensor = Tensor(array, mstype.float32)
self.mask_able = False
self.zero_mask_tensor = Tensor(np.zeros(array.shape, np.float32), mstype.float32)
self.just_update = -1.0

def construct(self):
"""
Return the mask matrix for optimization.
"""
if self.just_update > 0.0:
return self.mask_able, self.add_mask_tensor
return self.mask_able, self.zero_mask_tensor

def update(self):
"""
Update the mask tensor.
"""
self.just_update = 1.0
self.add_mask_tensor = Tensor(self.add_mask_array_flat.reshape(self.add_mask_array_shape), mstype.float32)

def reset_zeros(self):
"""
Invalidate the de-weight operation.
"""
self.just_update = -1.0
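
For reference, the target sparsity computed in update_mask follows a cubic ramp from sparse_start to sparse_end over the masking window. Here is a standalone sketch of that schedule (the 0.95 end sparsity and 15000 steps are assumed from the example script):

```python
def target_sparsity(step, total_steps, sparse_start=0.0, sparse_end=0.95):
    """Cubic ramp used by SuppressCtrl.update_mask."""
    return sparse_end + (sparse_start - sparse_end) * (1.0 - step / total_steps) ** 3

for step in (0, 3750, 7500, 11250, 15000):
    print(step, target_sparsity(step, 15000))  # ~0.0, ~0.549, ~0.831, ~0.935, 0.95
```

Most of the suppression therefore happens early in the masking window, and the sparsity flattens out near the target towards the end.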

mindarmour/privacy/sup_privacy/train/__init__.py (+0 -0)


mindarmour/privacy/sup_privacy/train/model.py (+325 -0)

@@ -0,0 +1,325 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
suppress-based privacy model.
"""
from easydict import EasyDict as edict

from mindspore.train.model import Model
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.train.amp import _config_level
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.train.model import ParallelMode
from mindspore.train.amp import _do_keep_batchnorm_fp32
from mindspore.train.amp import _add_loss_network
from mindspore import nn
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_gradients_mean
from mindspore.parallel._utils import _get_device_num
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.nn import Cell
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils.logger import LogUtil
from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl

LOGGER = LogUtil.get_instance()
TAG = 'Mask model'

GRADIENT_CLIP_TYPE = 1
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
""" grad scaling """
return grad*F.cast(_reciprocal(scale), F.dtype(grad))


class SuppressModel(Model):
"""
This class overloads mindspore.train.model.Model.

Args:
network (Cell): The training network.
loss_fn (Cell): Computes softmax cross entropy between logits and labels.
optimizer (Optimizer): optimizer instance.
metrics (Union[dict, set]): Calculates the accuracy for classification and multilabel data.
kwargs: Keyword parameters used for creating a suppress model.

Examples:
networks_l5 = LeNet5()
masklayers = []
masklayers.append(MaskLayerDes("conv1.weight", False, True, 10))
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
end_epoch=10,
batch_num=(int)(10000/cfg.batch_size),
start_epoch=3,
mask_times=100,
networks=networks_l5,
lr=lr,
sparse_end=0.90,
sparse_start=0.0,
mask_layers=masklayers)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0)
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10)
model_instance = SuppressModel(network=networks_l5,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
model_instance.link_suppress_ctrl(suppress_ctrl_instance)
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance)
ds_train = generate_mnist_dataset("./MNIST_unzip/train",
batch_size=cfg.batch_size, repeat_size=1, samples=samples)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory="./trained_ckpt_file/",
config=config_ck)
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker],
dataset_sink_mode=False)
"""

def __init__(self,
network=None,
**kwargs):

check_param_type('networks', network, Cell)

self.network_end = None
self._train_one_step = None

super(SuppressModel, self).__init__(network, **kwargs)

def link_suppress_ctrl(self, suppress_pri_ctrl):
"""
Link self and SuppressCtrl instance.

Args:
suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance.
"""
check_param_type('suppress_pri_ctrl', suppress_pri_ctrl, Cell)
if not isinstance(suppress_pri_ctrl, SuppressCtrl):
msg = "SuppressCtrl instance error!"
LOGGER.error(TAG, msg)
raise ValueError(msg)

suppress_pri_ctrl.model = self
if self._train_one_step is not None:
self._train_one_step.link_suppress_ctrl(suppress_pri_ctrl)

def _build_train_network(self):
"""Build train network"""
network = self._network

ms_mode = context.get_context("mode")
if ms_mode != context.PYNATIVE_MODE:
raise ValueError("Only PYNATIVE_MODE is supported for suppress privacy now.")

if self._optimizer:
network = self._amp_build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
else:
raise ValueError("_optimizer is none")

self._train_one_step = network

if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()

self.network_end = self._train_one_step.network
return network

def _amp_build_train_network(self, network, optimizer, loss_fn=None,
level='O0', **kwargs):
"""
Build the mixed precision training cell automatically.

Args:
network (Cell): Definition of the network.
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None,
the `network` should have the loss inside. Default: None.
optimizer (Optimizer): Optimizer to update the Parameter.
level (str): Supports [O0, O2]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and `loss_fn`
(if set) run in float32, using dynamic loss scale.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
or `mstype.float32`. If set to `mstype.float16`, use `float16`
mode to train. If set, overwrite the level setting.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set,
overwrite the level setting.
loss_scale_manager (Union[None, LossScaleManager]): If None, not
scale the loss, or else scale the loss by LossScaleManager.
If set, overwrite the level setting.
"""
validator.check_value_type('network', network, nn.Cell, None)
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
self._check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs)
config = edict(config)

if config.cast_model_type == mstype.float16:
network.to_float(mstype.float16)

if config.keep_batchnorm_fp32:
_do_keep_batchnorm_fp32(network)

if loss_fn:
network = _add_loss_network(network, loss_fn,
config.cast_model_type)

if _get_parallel_mode() in (
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)

loss_scale = 1.0
if config.loss_scale_manager is not None:
print("----model config have loss scale manager !")
network = TrainOneStepCell(network, optimizer, sens=loss_scale).set_train()
return network


class _TupleAdd(nn.Cell):
"""
Add two tuples of data.
"""
def __init__(self):
super(_TupleAdd, self).__init__()
self.add = P.TensorAdd()
self.hyper_map = C.HyperMap()

def construct(self, input1, input2):
"""Add two tuple of data."""
out = self.hyper_map(self.add, input1, input2)
return out


class _TupleMul(nn.Cell):
"""
Multiply two tuples of data.
"""
def __init__(self):
super(_TupleMul, self).__init__()
self.mul = P.Mul()
self.hyper_map = C.HyperMap()

def construct(self, input1, input2):
"""Add two tuple of data."""
out = self.hyper_map(self.mul, input1, input2)
#print(out)
return out

# adapted from nn.cell_wrapper.TrainOneStepCell
class TrainOneStepCell(Cell):
r"""
Network training package class.

Wraps the network with an optimizer. The resulting Cell is trained with input data and label.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.

Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) # for mindspore 0.7x
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self._tuple_add = _TupleAdd()
self._tuple_mul = _TupleMul()
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean() # for mindspore 0.7x
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

self.do_privacy = False
self.grad_mask_tup = () # tuple containing grad_mask(cell)
self.de_weight_tup = () # tuple containing de_weight(cell)
self._suppress_pri_ctrl = None

def link_suppress_ctrl(self, suppress_pri_ctrl):
"""
Set Suppress Mask for grad_mask_tup and de_weight_tup.

Args:
suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance.
"""
self._suppress_pri_ctrl = suppress_pri_ctrl
if self._suppress_pri_ctrl.grads_mask_list:
for grad_mask_cell in self._suppress_pri_ctrl.grads_mask_list:
self.grad_mask_tup += (grad_mask_cell,)
self.do_privacy = True
for de_weight_cell in self._suppress_pri_ctrl.de_weight_mask_list:
self.de_weight_tup += (de_weight_cell,)
else:
self.do_privacy = False

def construct(self, data, label):
"""
Construct a compute flow.
"""
weights = self.weights
loss = self.network(data, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(data, label, sens)

new_grads = ()
m = 0
for grad in grads:
if self.do_privacy and self._suppress_pri_ctrl.mask_started:
enable_mask, grad_mask = self.grad_mask_tup[m]()
enable_de_weight, de_weight_array = self.de_weight_tup[m]()

if enable_mask and enable_de_weight:
grad_n = self._tuple_add(de_weight_array, self._tuple_mul(grad, grad_mask))
new_grads = new_grads + (grad_n,)
else:
new_grads = new_grads + (grad,)
else:
new_grads = new_grads + (grad,)
m = m + 1

if self.reducer_flag:
new_grads = self.grad_reducer(new_grads)

return F.depend(loss, self.optimizer(new_grads))
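
The gradient rewrite above (grad * mask + de_weight) is what actually zeroes a suppressed weight: for a masked position, SuppressCtrl stores de_weight = w / lr, so a plain SGD step lands the weight exactly at zero. A standalone arithmetic sketch (plain Python, independent of MindSpore):

```python
lr = 0.05
w = 0.0123                  # a weight selected for suppression
grad, mask = 0.7, 0.0       # its original gradient is masked out
de_weight = w / lr          # injected by DeWeightInCell

new_grad = grad * mask + de_weight
w_next = w - lr * new_grad  # SGD update: w - lr * (w / lr) == 0
print(w_next)               # 0.0 (up to floating-point rounding)
```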

tests/ut/python/privacy/__init__.py (+3 -3)

@@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This package includes unit tests for differential-privacy training,
suppress-privacy training and privacy breach estimation.
"""

tests/ut/python/privacy/sup_privacy/__init__.py (+16 -0)

@@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This package includes unit tests for suppress-privacy training.
"""

tests/ut/python/privacy/sup_privacy/test_model_train.py (+85 -0)

@@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Suppress Privacy model test.
"""
import pytest
import numpy as np

from mindspore import nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn.metrics import Accuracy
import mindspore.dataset as ds

from ut.python.utils.mock_net import Net as LeNet5

from mindarmour.privacy.sup_privacy import SuppressModel
from mindarmour.privacy.sup_privacy import SuppressMasker
from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory
from mindarmour.privacy.sup_privacy import MaskLayerDes

def dataset_generator(batch_size, batches):
"""mock training data."""
data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_suppress_model_with_pynative_mode():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
networks_l5 = LeNet5()
epochs = 5
batch_num = 10
batch_size = 32
mask_times = 10
lr = 0.01
masklayers_lenet5 = []
masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1))
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
end_epoch=epochs,
batch_num=batch_num,
start_epoch=1,
mask_times=mask_times,
networks=networks_l5,
lr=lr,
sparse_end=0.50,
sparse_start=0.0,
mask_layers=masklayers_lenet5)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.SGD(networks_l5.trainable_params(), lr)
model_instance = SuppressModel(
network=networks_l5,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
model_instance.link_suppress_ctrl(suppress_ctrl_instance)
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance)
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory="./trained_ckpt_file/",
config=config_ck)
ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batch_num), ['data', 'label'])

model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker],
dataset_sink_mode=False)
