From c13cd9391a612f8e1f46feb60c3a91c4cd7ed2ed Mon Sep 17 00:00:00 2001 From: itcomee Date: Sun, 29 Nov 2020 00:24:47 +0800 Subject: [PATCH] suppress privacy model, refer to "Deep Leakage from Gradients" https://arxiv.org/abs/1906.08935 --- examples/privacy/README.md | 51 +- examples/privacy/sup_privacy/__init__.py | 0 examples/privacy/sup_privacy/sup_privacy.py | 154 +++++ examples/privacy/sup_privacy/sup_privacy_config.py | 32 ++ mindarmour/privacy/sup_privacy/__init__.py | 27 + .../privacy/sup_privacy/mask_monitor/__init__.py | 0 .../privacy/sup_privacy/mask_monitor/masker.py | 98 ++++ .../privacy/sup_privacy/sup_ctrl/__init__.py | 0 mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py | 640 +++++++++++++++++++++ mindarmour/privacy/sup_privacy/train/__init__.py | 0 mindarmour/privacy/sup_privacy/train/model.py | 325 +++++++++++ tests/ut/python/privacy/__init__.py | 6 +- tests/ut/python/privacy/sup_privacy/__init__.py | 16 + .../python/privacy/sup_privacy/test_model_train.py | 85 +++ 14 files changed, 1416 insertions(+), 18 deletions(-) create mode 100644 examples/privacy/sup_privacy/__init__.py create mode 100644 examples/privacy/sup_privacy/sup_privacy.py create mode 100644 examples/privacy/sup_privacy/sup_privacy_config.py create mode 100644 mindarmour/privacy/sup_privacy/__init__.py create mode 100644 mindarmour/privacy/sup_privacy/mask_monitor/__init__.py create mode 100644 mindarmour/privacy/sup_privacy/mask_monitor/masker.py create mode 100644 mindarmour/privacy/sup_privacy/sup_ctrl/__init__.py create mode 100644 mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py create mode 100644 mindarmour/privacy/sup_privacy/train/__init__.py create mode 100644 mindarmour/privacy/sup_privacy/train/model.py create mode 100644 tests/ut/python/privacy/sup_privacy/__init__.py create mode 100644 tests/ut/python/privacy/sup_privacy/test_model_train.py diff --git a/examples/privacy/README.md b/examples/privacy/README.md index be94d59..837d4b5 100644 --- 
a/examples/privacy/README.md +++ b/examples/privacy/README.md @@ -1,33 +1,54 @@ # Application demos of privacy stealing and privacy protection + ## Introduction + Although machine learning could obtain a generic model based on training data, it has been proved that the trained - model may disclose the information of training data (such as the membership inference attack). Differential - privacy training - is an effective - method proposed - to overcome this problem, in which Gaussian noise is added while training. There are mainly three parts for - differential privacy(DP) training: noise-generating mechanism, DP optimizer and DP monitor. We have implemented - a novel noise-generating mechanisms: adaptive decay noise mechanism. DP - monitor is used to compute the privacy budget while training. + model may disclose the information of training data (such as the membership inference attack). +Differential privacy training is an effective method proposed to overcome this problem, in which Gaussian noise is + added while training. There are mainly three parts for differential privacy(DP) training: noise-generating + mechanism, DP optimizer and DP monitor. We have implemented a novel noise-generating mechanisms: adaptive decay + noise mechanism. DP monitor is used to compute the privacy budget while training. +Suppress Privacy training is a novel method to protect privacy distinct from the noise addition method + (such as DP), in which the negligible model parameter is removed gradually to achieve a better balance between + accuracy and privacy. ## 1. Adaptive decay DP training + With adaptive decay mechanism, the magnitude of the Gaussian noise would be decayed as the training step grows, which resulting a stable convergence. + ```sh -$ cd examples/privacy/diff_privacy -$ python lenet5_dp_ada_gaussian.py +cd examples/privacy/diff_privacy +python lenet5_dp_ada_gaussian.py ``` + ## 2. 
Adaptive norm clip training + With adaptive norm clip mechanism, the norm clip of the gradients would be changed according to the norm values of them, which can adjust the ratio of noise and original gradients. + ```sh -$ cd examples/privacy/diff_privacy -$ python lenet5_dp.py +cd examples/privacy/diff_privacy +python lenet5_dp.py ``` + ## 3. Membership inference evaluation + By this evaluation method, we could judge whether a sample is belongs to training dataset or not. + +```sh +cd examples/privacy/membership_inference_attack +python train.py --data_path home_path_to_cifar100 --ckpt_path ./ +python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt +``` + +## 4. Suppress privacy training + +With the suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected + layers) are set to zero as the training step grows, which can + achieve a better balance between accuracy and privacy. + ```sh -$ cd examples/privacy/membership_inference_attack -$ python train.py --data_path home_path_to_cifar100 --ckpt_path ./ -$ python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt +cd examples/privacy/sup_privacy +python sup_privacy.py ``` diff --git a/examples/privacy/sup_privacy/__init__.py b/examples/privacy/sup_privacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/privacy/sup_privacy/sup_privacy.py b/examples/privacy/sup_privacy/sup_privacy.py new file mode 100644 index 0000000..cfe0466 --- /dev/null +++ b/examples/privacy/sup_privacy/sup_privacy.py @@ -0,0 +1,154 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Training example of suppress-based privacy. +""" +import os +import mindspore.nn as nn +from mindspore import context +from mindspore.train.callback import ModelCheckpoint +from mindspore.train.callback import CheckpointConfig +from mindspore.train.callback import LossMonitor +from mindspore.nn.metrics import Accuracy +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset as ds +import mindspore.dataset.vision.c_transforms as CV +import mindspore.dataset.transforms.c_transforms as C +from mindspore.dataset.vision.utils import Inter +import mindspore.common.dtype as mstype + +from examples.common.networks.lenet5.lenet5_net import LeNet5 + +from sup_privacy_config import mnist_cfg as cfg +from mindarmour.privacy.sup_privacy import SuppressModel +from mindarmour.privacy.sup_privacy import SuppressMasker +from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory +from mindarmour.privacy.sup_privacy import MaskLayerDes + +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +LOGGER.set_level('INFO') +TAG = 'Lenet5_Suppress_train' + + +def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, samples=None, num_parallel_workers=1, sparse=True): + """ + create dataset for training or testing + """ + # define dataset + ds1 = ds.MnistDataset(data_path, num_samples=samples) + + # define operation parameters + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width), + 
interpolation=Inter.LINEAR) + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + # apply map operations on images + if not sparse: + one_hot_enco = C.OneHot(10) + ds1 = ds1.map(input_columns="label", operations=one_hot_enco, num_parallel_workers=num_parallel_workers) + type_cast_op = C.TypeCast(mstype.float32) + ds1 = ds1.map(input_columns="label", operations=type_cast_op, + num_parallel_workers=num_parallel_workers) + ds1 = ds1.map(input_columns="image", operations=resize_op, + num_parallel_workers=num_parallel_workers) + ds1 = ds1.map(input_columns="image", operations=rescale_op, + num_parallel_workers=num_parallel_workers) + ds1 = ds1.map(input_columns="image", operations=hwc2chw_op, + num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + buffer_size = 10000 + ds1 = ds1.shuffle(buffer_size=buffer_size) + ds1 = ds1.batch(batch_size, drop_remainder=True) + ds1 = ds1.repeat(repeat_size) + + return ds1 + +def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, mask_times=1000, + sparse_thd=0.90, sparse_start=0.0, masklayers=None): + """ + local train by suppress-based privacy + """ + + networks_l5 = LeNet5() + suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", + end_epoch=epoch_size, + batch_num=(int)(samples/cfg.batch_size), + start_epoch=start_epoch, + mask_times=mask_times, + networks=networks_l5, + lr=lr, + sparse_end=sparse_thd, + sparse_start=sparse_start, + mask_layers=masklayers) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.SGD(networks_l5.trainable_params(), lr) + config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), + keep_checkpoint_max=10) + + # Create the SuppressModel model for training. 
+ model_instance = SuppressModel(network=networks_l5, + loss_fn=net_loss, + optimizer=net_opt, + metrics={"Accuracy": Accuracy()}) + model_instance.link_suppress_ctrl(suppress_ctrl_instance) + + # Create a Masker for Suppress training. The function of the Masker is to + # enforce suppress operation while training. + suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) + + mnist_path = "./MNIST_unzip/" #"../../MNIST_unzip/" + ds_train = generate_mnist_dataset(os.path.join(mnist_path, "train"), + batch_size=cfg.batch_size, repeat_size=1, samples=samples) + + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory="./trained_ckpt_file/", + config=config_ck) + + print("============== Starting SUPP Training ==============") + model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + dataset_sink_mode=False) + + print("============== Starting SUPP Testing ==============") + ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + param_dict = load_checkpoint(ckpt_file_name) + load_param_into_net(networks_l5, param_dict) + ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'), + batch_size=cfg.batch_size) + acc = model_instance.eval(ds_eval, dataset_sink_mode=False) + print("============== SUPP Accuracy: {} ==============".format(acc)) + +if __name__ == "__main__": + # This configure can run in pynative mode + context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target) + + masklayers_lenet5 = [] # determine which layer should be masked + + masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, True, 10)) + masklayers_lenet5.append(MaskLayerDes("conv2.weight", False, True, 150)) + masklayers_lenet5.append(MaskLayerDes("fc1.weight", True, False, -1)) + masklayers_lenet5.append(MaskLayerDes("fc2.weight", True, False, -1)) + masklayers_lenet5.append(MaskLayerDes("fc3.weight", True, False, 50)) + + # do suppress privacy train, with 
stronger privacy protection and better performance than Differential Privacy + mnist_suppress_train(10, 3, 0.10, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5) # used diff --git a/examples/privacy/sup_privacy/sup_privacy_config.py b/examples/privacy/sup_privacy/sup_privacy_config.py new file mode 100644 index 0000000..88c9df2 --- /dev/null +++ b/examples/privacy/sup_privacy/sup_privacy_config.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +network config setting, will be used in sup_privacy.py +""" + +from easydict import EasyDict as edict + +mnist_cfg = edict({ + 'num_classes': 10, # the number of classes of model's output + 'epoch_size': 1, # training epochs + 'batch_size': 32, # batch size for training + 'image_height': 32, # the height of training samples + 'image_width': 32, # the width of training samples + 'save_checkpoint_steps': 1875, # the interval steps for saving checkpoint file of the model + 'keep_checkpoint_max': 10, # the maximum number of checkpoint files would be saved + 'device_target': 'Ascend', # device used + 'data_path': './MNIST_unzip', # the path of training and testing data set + 'dataset_sink_mode': False, # whether deliver all training data to device one time +}) diff --git a/mindarmour/privacy/sup_privacy/__init__.py b/mindarmour/privacy/sup_privacy/__init__.py new file mode 100644 index 0000000..ce07347 --- /dev/null +++ b/mindarmour/privacy/sup_privacy/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module provides Suppress Privacy feature to protect user privacy. 
+""" +from .mask_monitor.masker import SuppressMasker +from .train.model import SuppressModel +from .sup_ctrl.conctrl import SuppressPrivacyFactory +from .sup_ctrl.conctrl import SuppressCtrl +from .sup_ctrl.conctrl import MaskLayerDes + +__all__ = ['SuppressMasker', + 'SuppressModel', + 'SuppressPrivacyFactory', + 'SuppressCtrl', + 'MaskLayerDes'] diff --git a/mindarmour/privacy/sup_privacy/mask_monitor/__init__.py b/mindarmour/privacy/sup_privacy/mask_monitor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mindarmour/privacy/sup_privacy/mask_monitor/masker.py b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py new file mode 100644 index 0000000..0a3a4e7 --- /dev/null +++ b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py @@ -0,0 +1,98 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Masker module of suppress-based privacy.. +""" +from mindspore.train.callback import Callback +from mindarmour.utils.logger import LogUtil +from mindarmour.utils._check_param import check_param_type +from mindarmour.privacy.sup_privacy.train.model import SuppressModel +from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl + +LOGGER = LogUtil.get_instance() +TAG = 'suppress masker' + +class SuppressMasker(Callback): + """ + Args: + args (Union[int, float, numpy.ndarray, list, str]): Parameters + used for creating a suppress privacy monitor. 
+ kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword + parameters used for creating a suppress privacy monitor. + model (SuppressModel): SuppressModel instance. + suppress_ctrl (SuppressCtrl): SuppressCtrl instance. + + Examples: + networks_l5 = LeNet5() + masklayers = [] + masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) + suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", + end_epoch=10, + batch_num=(int)(10000/cfg.batch_size), + start_epoch=3, + mask_times=100, + networks=networks_l5, + lr=lr, + sparse_end=0.90, + sparse_start=0.0, + mask_layers=masklayers) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) + config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) + model_instance = SuppressModel(network=networks_l5, + loss_fn=net_loss, + optimizer=net_opt, + metrics={"Accuracy": Accuracy()}) + model_instance.link_suppress_ctrl(suppress_ctrl_instance) + ds_train = generate_mnist_dataset("./MNIST_unzip/train", + batch_size=cfg.batch_size, repeat_size=1, samples=samples) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory="./trained_ckpt_file/", + config=config_ck) + model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + dataset_sink_mode=False) + """ + + def __init__(self, model=None, suppress_ctrl=None): + + super(SuppressMasker, self).__init__() + + self._model = check_param_type('model', model, SuppressModel) + self._suppress_ctrl = check_param_type('suppress_ctrl', suppress_ctrl, SuppressCtrl) + + def step_end(self, run_context): + """ + Update mask matrix tensor used for SuppressModel instance. + + Args: + run_context (RunContext): Include some information of the model. 
+ """ + cb_params = run_context.original_args() + cur_step = cb_params.cur_step_num + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + + if self._suppress_ctrl is not None and self._model.network_end is not None: + self._suppress_ctrl.update_status(cb_params.cur_epoch_num, cur_step, cur_step_in_epoch) + + if not self._suppress_ctrl.mask_initialized: + raise ValueError("Not initialize network!") + if self._suppress_ctrl.to_do_mask: + self._suppress_ctrl.update_mask(self._suppress_ctrl.networks, cur_step) + LOGGER.info(TAG, "suppress update") + elif not self._suppress_ctrl.to_do_mask and self._suppress_ctrl.mask_started: + self._suppress_ctrl.reset_zeros() + if cur_step_in_epoch % 100 == 1: + self._suppress_ctrl.calc_theoretical_sparse_for_conv() + _, _, _ = self._suppress_ctrl.calc_actual_sparse_for_conv( + self._suppress_ctrl.networks) diff --git a/mindarmour/privacy/sup_privacy/sup_ctrl/__init__.py b/mindarmour/privacy/sup_privacy/sup_ctrl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py new file mode 100644 index 0000000..7bfe98f --- /dev/null +++ b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py @@ -0,0 +1,640 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +control function of suppress-based privacy. 
+""" +import math +import numpy as np + +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore.nn import Cell + +from mindarmour.utils.logger import LogUtil +from mindarmour.utils._check_param import check_int_positive, check_value_positive, \ + check_value_non_negative, check_param_type +LOGGER = LogUtil.get_instance() +TAG = 'Suppression training.' + + +class SuppressPrivacyFactory: + """ Factory class of SuppressCtrl mechanisms""" + def __init__(self): + pass + + @staticmethod + def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None, + lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None): + """ + Args: + policy (str): Training policy for suppress privacy training. "local_train" means local training. + end_epoch (int): The last epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100. + batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size. + start_epoch (int): The first epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100. + mask_times (int): The num of suppress operations. + networks (Cell): The training network. + lr (float): Learning rate. + sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0. + sparse_start (float): The sparsity to start, 0.0 <= sparse_start < sparse_end < 1.0. + mask_layers (list): Description of the training network layers that need to be suppressed. + + Returns: + SuppressCtrl, class of Suppress Privacy Mechanism. 
+ + Examples: + networks_l5 = LeNet5() + masklayers = [] + masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) + suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", + end_epoch=10, + batch_num=(int)(10000/cfg.batch_size), + start_epoch=3, + mask_times=100, + networks=networks_l5, + lr=lr, + sparse_end=0.90, + sparse_start=0.0, + mask_layers=masklayers) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) + config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) + model_instance = SuppressModel(network=networks_l5, + loss_fn=net_loss, + optimizer=net_opt, + metrics={"Accuracy": Accuracy()}) + model_instance.link_suppress_ctrl(suppress_ctrl_instance) + ds_train = generate_mnist_dataset("./MNIST_unzip/train", + batch_size=cfg.batch_size, repeat_size=1, samples=samples) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory="./trained_ckpt_file/", + config=config_ck) + model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + dataset_sink_mode=False) + """ + if policy == "local_train": + return SuppressCtrl(networks, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end, + sparse_start, mask_layers) + msg = "Only local training is supported now, federal training will be supported " \ + "in the future. But got {}.".format(policy) + LOGGER.error(TAG, msg) + raise ValueError(msg) + +class SuppressCtrl(Cell): + """ + Args: + networks (Cell): The training network. + end_epoch (int): The last epoch in suppress operations. + batch_num (int): The num of grad operation in an epoch. + mask_start_epoch (int): The first epoch in suppress operations. + mask_times (int): The num of suppress operations. + lr (Union[float, int]): Learning rate. + sparse_end (Union[float, int]): The sparsity to reach. 
+ sparse_start (float): The sparsity to start. + mask_layers (list): Description of those layers that need to be suppressed. + """ + def __init__(self, networks, end_epoch, batch_num, mask_start_epoch=3, mask_times=500, lr=0.05, + sparse_end=0.60, + sparse_start=0.0, + mask_layers=None): + super(SuppressCtrl, self).__init__() + self.networks = check_param_type('networks', networks, Cell) + self.mask_end_epoch = check_int_positive('end_epoch', end_epoch) + self.batch_num = check_int_positive('batch_num', batch_num) + self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch) + self.mask_times = check_int_positive('mask_times', mask_times) + self.lr = check_value_positive('lr', lr) + self.sparse_end = check_value_non_negative('sparse_end', sparse_end) + self.sparse_start = check_value_non_negative('sparse_start', sparse_start) + self.mask_layers = check_param_type('mask_layers', mask_layers, list) + + self.weight_lower_bound = 0.005 # all network weight will be larger than this value + self.sparse_vibra = 0.02 # the sparsity may have certain range of variations + self.sparse_valid_max_weight = 0.20 # if max network weight is less than this value, suppress operation stop temporarily + self.add_noise_thd = 0.50 # if network weight is more than this value, noise is forced + self.noise_volume = 0.01 # noise volume 0.01 + self.base_ground_thd = 0.0000001 # if network weight is less than this value, will be considered as 0 + self.model = None # SuppressModel instance + self.grads_mask_list = [] # list for Grad Mask Matrix tensor + self.de_weight_mask_list = [] # list for weight Mask Matrix tensor + self.to_do_mask = False # the flag means suppress operation is toggled immediately + self.mask_started = False # the flag means suppress operation has been started + self.mask_start_step = 0 # suppress operation is actually started at this step + self.mask_prev_step = 0 # previous suppress operation is done at this step + self.cur_sparse = 0.0 # current 
sparsity to which one suppress will get + self.mask_all_steps = (self.mask_end_epoch-mask_start_epoch+1)*batch_num # the number of steps contained in all suppress operations + self.mask_step_interval = self.mask_all_steps/mask_times # the number of steps contained in one suppress operation + self.mask_initialized = False # flag means the initialization is done + + if self.mask_start_epoch > self.mask_end_epoch: + msg = "start_epoch error: {}".format(self.mask_start_epoch) + LOGGER.error(TAG, msg) + raise ValueError(msg) + + if self.mask_end_epoch > 100: + msg = "end_epoch error: {}".format(self.mask_end_epoch) + LOGGER.error(TAG, msg) + raise ValueError(msg) + + if self.mask_step_interval < 0: + msg = "step_interval error: {}".format(self.mask_step_interval) + LOGGER.error(TAG, msg) + raise ValueError(msg) + + if self.sparse_end > 1.00 or self.sparse_end <= 0: + msg = "sparse_end error: {}".format(self.sparse_end) + LOGGER.error(TAG, msg) + raise ValueError(msg) + + if self.sparse_start >= self.sparse_end: + msg = "sparse_start error: {}".format(self.sparse_start) + LOGGER.error(TAG, msg) + raise ValueError(msg) + + if mask_layers is not None: + mask_layer_id = 0 + for one_mask_layer in mask_layers: + if not isinstance(one_mask_layer, MaskLayerDes): + msg = "mask_layer instance error!" 
+ LOGGER.error(TAG, msg) + raise ValueError(msg) + layer_name = one_mask_layer.layer_name + mask_layer_id2 = 0 + for one_mask_layer_2 in mask_layers: + if mask_layer_id != mask_layer_id2 and layer_name in one_mask_layer_2.layer_name: + msg = "mask_layers repeat item : {} in {} and {}".format(layer_name, + mask_layer_id, + mask_layer_id2) + LOGGER.error(TAG, msg) + raise ValueError(msg) + mask_layer_id2 = mask_layer_id2 + 1 + mask_layer_id = mask_layer_id + 1 + + if networks is not None: + m = 0 + for layer in networks.get_parameters(expand=True): + one_mask_layer = None + if mask_layers is not None: + one_mask_layer = get_one_mask_layer(mask_layers, layer.name) + if one_mask_layer is not None and not one_mask_layer.inited: + one_mask_layer.inited = True + shape = P.Shape()(layer) + mul_mask_array = np.ones(shape, dtype=np.float32) + grad_mask_cell = GradMaskInCell(mul_mask_array, + one_mask_layer.is_add_noise, + one_mask_layer.is_lower_clip, + one_mask_layer.min_num, + one_mask_layer.upper_bound) + grad_mask_cell.mask_able = True + self.grads_mask_list.append(grad_mask_cell) + add_mask_array = np.zeros(shape, dtype=np.float32) + + de_weight_cell = DeWeightInCell(add_mask_array) + de_weight_cell.mask_able = True + self.de_weight_mask_list.append(de_weight_cell) + msg = "do mask {}, {}".format(m, one_mask_layer.layer_name) + LOGGER.info(TAG, msg) + elif one_mask_layer is not None and one_mask_layer.inited: + msg = "repeated match masked setting {}=>{}.".format(one_mask_layer.layer_name, layer.name) + LOGGER.error(TAG, msg) + raise ValueError(msg) + else: + shape = np.shape([1]) + mul_mask_array = np.ones(shape, dtype=np.float32) + grad_mask_cell = GradMaskInCell(mul_mask_array, False, False, -1) + grad_mask_cell.mask_able = False + + self.grads_mask_list.append(grad_mask_cell) + add_mask_array = np.zeros(shape, dtype=np.float32) + de_weight_cell = DeWeightInCell(add_mask_array) + de_weight_cell.mask_able = False + self.de_weight_mask_list.append(de_weight_cell) + m 
+= 1 + self.mask_initialized = True + msg = "init SuppressCtrl by networks" + LOGGER.info(TAG, msg) + msg = "complete init mask for lenet5.step_interval: {}".format(self.mask_step_interval) + LOGGER.info(TAG, msg) + + for one_mask_layer in mask_layers: + if not one_mask_layer.inited: + msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name) + LOGGER.error(TAG, msg) + raise ValueError(msg) + + def update_status(self, cur_epoch, cur_step, cur_step_in_epoch): + """ + Update the suppress operation status. + + Args: + cur_epoch (int): Current epoch of the whole training process. + cur_step (int): Current step of the whole training process. + cur_step_in_epoch (int): Current step of the current epoch. + """ + if not self.mask_initialized: + self.mask_started = False + elif (self.mask_start_epoch <= cur_epoch <= self.mask_end_epoch) or self.mask_started: + if not self.mask_started: + self.mask_started = True + self.mask_start_step = cur_step + if cur_step >= (self.mask_prev_step + self.mask_step_interval): + self.mask_prev_step = cur_step + self.to_do_mask = True + # force the last suppression operation at step batch_num-2 of the final masking epoch + elif cur_epoch == self.mask_end_epoch and cur_step_in_epoch == self.batch_num-2: + self.mask_prev_step = cur_step + self.to_do_mask = True + else: + self.to_do_mask = False + else: + self.to_do_mask = False + self.mask_started = False + + def update_mask(self, networks, cur_step): + """ + Update add mask arrays and multiply mask arrays of network layers. + + Args: + networks (Cell): The training network. + cur_step (int): Current step of the whole training process. 
+ """ + if self.sparse_end <= 0.0: + return + + self.cur_sparse = self.sparse_end +\ + (self.sparse_start - self.sparse_end)*\ + math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) + m = 0 + for layer in networks.get_parameters(expand=True): + if self.grads_mask_list[m].mask_able: + weight_array = layer.data.asnumpy() + weight_avg = np.mean(weight_array) + weight_array_flat = weight_array.flatten() + weight_array_flat_abs = np.abs(weight_array_flat) + weight_abs_avg = np.mean(weight_array_flat_abs) + weight_array_flat_abs.sort() + len_array = weight_array.size + weight_abs_max = np.max(weight_array_flat_abs) + if m == 0 and weight_abs_max < self.sparse_valid_max_weight: + msg = "give up this masking .." + LOGGER.info(TAG, msg) + return + if self.grads_mask_list[m].min_num > 0: + sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs, + self.cur_sparse, m) + else: + actual_stop_pos = min(int(len_array * self.cur_sparse), len_array - 1) # clamp: cur_sparse >= 1.0 would index past the end + sparse_weight_thd = weight_array_flat_abs[actual_stop_pos] + + self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, m) + + msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format( + layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, + weight_abs_max, weight_avg, weight_abs_avg) + LOGGER.info(TAG, msg) + m = m + 1 + + def update_mask_layer(self, weight_array_flat, sparse_weight_thd, sparse_stop_pos, weight_abs_max, layer_index): + """ + Update add mask arrays and multiply mask arrays of one single layer. + + Args: + weight_array_flat (numpy.ndarray): The flattened weight array of the layer's parameters. + sparse_weight_thd (float): The weight threshold of sparse operation. + sparse_stop_pos (int): The maximum number of elements to be suppressed. + weight_abs_max (float): The maximum absolute value of weights. + layer_index (int): The index of target layer. 
+ """ + grad_mask_cell = self.grads_mask_list[layer_index] + mul_mask_array_flat = grad_mask_cell.mul_mask_array_flat + de_weight_cell = self.de_weight_mask_list[layer_index] + add_mask_array_flat = de_weight_cell.add_mask_array_flat + min_num = grad_mask_cell.min_num + is_add_noise = grad_mask_cell.is_add_noise + is_lower_clip = grad_mask_cell.is_lower_clip + upper_bound = grad_mask_cell.upper_bound + + if not self.grads_mask_list[layer_index].mask_able: + return + m = 0 + n = 0 + p = 0 + q = 0 + # add noise on weights if not masking or clipping. + weight_noise_bound = min(self.add_noise_thd, max(self.noise_volume*10, weight_abs_max*0.75)) + for i in range(0, weight_array_flat.size): + if abs(weight_array_flat[i]) <= sparse_weight_thd: + if m < weight_array_flat.size - min_num and m < sparse_stop_pos: + # to mask + mul_mask_array_flat[i] = 0.0 + add_mask_array_flat[i] = weight_array_flat[i] / self.lr + m = m + 1 + else: + # not mask + if weight_array_flat[i] > 0.0: + add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr + else: + add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr + p = p + 1 + elif is_lower_clip and abs(weight_array_flat[i]) <= \ + self.weight_lower_bound and sparse_weight_thd > self.weight_lower_bound*0.5: + # not mask + mul_mask_array_flat[i] = 1.0 + if weight_array_flat[i] > 0.0: + add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr + else: + add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr + p = p + 1 + elif abs(weight_array_flat[i]) > upper_bound: + mul_mask_array_flat[i] = 1.0 + if weight_array_flat[i] > 0.0: + add_mask_array_flat[i] = (weight_array_flat[i] - upper_bound) / self.lr + else: + add_mask_array_flat[i] = (weight_array_flat[i] + upper_bound) / self.lr + n = n + 1 + else: + # not mask + mul_mask_array_flat[i] = 1.0 + if is_add_noise and abs(weight_array_flat[i]) > weight_noise_bound > 0.0: + # add noise 
+ add_mask_array_flat[i] = np.random.uniform(-self.noise_volume, self.noise_volume) / self.lr + q = q + 1 + else: + add_mask_array_flat[i] = 0.0 + + grad_mask_cell.update() + de_weight_cell.update() + msg = "Dimension of mask tensor is {}D, which located in the {}-th layer of the network. \n The number of " \ + "suppressed elements, max-clip elements, min-clip elements and noised elements are {}, {}, {}, {}"\ + .format(len(grad_mask_cell.mul_mask_array_shape), layer_index, m, n, p, q) + LOGGER.info(TAG, msg) + + def calc_sparse_thd(self, array_flat, sparse_value, layer_index): + """ + Calculate the suppression threshold of one weight array. + + Args: + array_flat (numpy.ndarray): The flattened weight array. + sparse_value (float): The target sparse value of weight array. + + Returns: + - float, the sparse threshold of this array. + + - int, the number of weight elements to be suppressed. + + - int, the larger number of weight elements to be suppressed. + """ + size = len(array_flat) + sparse_max_thd = 1.0 - min(self.grads_mask_list[layer_index].min_num, size) / size + pos = int(size*min(sparse_max_thd, sparse_value)) + thd = array_flat[pos] + farther_stop_pos = int(size*min(sparse_max_thd, max(0, sparse_value + self.sparse_vibra / 2.0))) + return thd, pos, farther_stop_pos + + def reset_zeros(self): + """ + Set add mask arrays to be zero. + """ + for de_weight_cell in self.de_weight_mask_list: + de_weight_cell.reset_zeros() + + def calc_theoretical_sparse_for_conv(self): + """ + Compute the actual sparsity of the mask matrix for conv1 layer and conv2 layer.
+ """ + array_mul_mask_flat_conv1 = self.grads_mask_list[0].mul_mask_array_flat + array_mul_mask_flat_conv2 = self.grads_mask_list[1].mul_mask_array_flat + sparse = 0.0 + sparse_value_1 = 0.0 + sparse_value_2 = 0.0 + full = 0.0 + full_conv1 = 0.0 + full_conv2 = 0.0 + for i in range(0, array_mul_mask_flat_conv1.size): + full += 1.0 + full_conv1 += 1.0 + if array_mul_mask_flat_conv1[i] <= 0.0: + sparse += 1.0 + sparse_value_1 += 1.0 + + for i in range(0, array_mul_mask_flat_conv2.size): + full = full + 1.0 + full_conv2 = full_conv2 + 1.0 + if array_mul_mask_flat_conv2[i] <= 0.0: + sparse = sparse + 1.0 + sparse_value_2 += 1.0 + sparse = sparse/full + sparse_value_1 = sparse_value_1/full_conv1 + sparse_value_2 = sparse_value_2/full_conv2 + msg = "conv sparse mask={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2) + LOGGER.info(TAG, msg) + return sparse, sparse_value_1, sparse_value_2 + + def calc_actual_sparse_for_conv(self, networks): + """ + Compute the actual sparsity of the network for conv1 layer and conv2 layer. + + Args: + networks (Cell): The training network.
+ """ + sparse = 0.0 + sparse_value_1 = 0.0 + sparse_value_2 = 0.0 + full = 0.0 + full_conv1 = 0.0 + full_conv2 = 0.0 + + array_cur_conv1 = np.ones(np.shape([1]), dtype=np.float32) + array_cur_conv2 = np.ones(np.shape([1]), dtype=np.float32) + for layer in networks.get_parameters(expand=True): + if "conv1.weight" in layer.name: + array_cur_conv1 = layer.data.asnumpy() + if "conv2.weight" in layer.name: + array_cur_conv2 = layer.data.asnumpy() + + array_mul_mask_flat_conv1 = array_cur_conv1.flatten() + array_mul_mask_flat_conv2 = array_cur_conv2.flatten() + + for i in range(0, array_mul_mask_flat_conv1.size): + full += 1.0 + full_conv1 += 1.0 + if abs(array_mul_mask_flat_conv1[i]) <= self.base_ground_thd: + sparse += 1.0 + sparse_value_1 += 1.0 + + for i in range(0, array_mul_mask_flat_conv2.size): + full = full + 1.0 + full_conv2 = full_conv2 + 1.0 + if abs(array_mul_mask_flat_conv2[i]) <= self.base_ground_thd: + sparse = sparse + 1.0 + sparse_value_2 += 1.0 + + sparse = sparse / full + sparse_value_1 = sparse_value_1 / full_conv1 + sparse_value_2 = sparse_value_2 / full_conv2 + msg = "conv sparse fact={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2) + LOGGER.info(TAG, msg) + return sparse, sparse_value_1, sparse_value_2 + + def calc_actual_sparse_for_fc1(self, networks): + self.calc_actual_sparse_for_layer(networks, "fc1.weight") + + def calc_actual_sparse_for_layer(self, networks, layer_name): + """ + Compute the actual sparsity of one network layer. + + Args: + networks (Cell): The training network. + layer_name (str): The name of target layer.
+ """ + check_param_type('networks', networks, Cell) + check_param_type('layer_name', layer_name, str) + + sparse = 0.0 + full = 0.0 + + array_cur = None + for layer in networks.get_parameters(expand=True): + if layer_name in layer.name: + array_cur = layer.data.asnumpy() + + if array_cur is None: + msg = "no such layer to calc sparse: {} ".format(layer_name) + LOGGER.info(TAG, msg) + return + + array_cur_flat = array_cur.flatten() + + for i in range(0, array_cur_flat.size): + full += 1.0 + if abs(array_cur_flat[i]) <= self.base_ground_thd: + sparse += 1.0 + + sparse = sparse / full + msg = "{} sparse fact={} ".format(layer_name, sparse) + LOGGER.info(TAG, msg) + +def get_one_mask_layer(mask_layers, layer_name): + """ + Returns the layer definitions that need to be suppressed. + + Args: + mask_layers (list): The layers that need to be suppressed. + layer_name (str): The name of target layer. + + Returns: + Union[MaskLayerDes, None], the layer definitions that need to be suppressed. + """ + for each_mask_layer in mask_layers: + if each_mask_layer.layer_name in layer_name: + return each_mask_layer + return None + +class MaskLayerDes: + """ + Describe the layer that need to be suppressed. + + Args: + layer_name (str): Layer name, get the name of one layer as following: + for layer in networks.get_parameters(expand=True): + if layer.name == "conv": ... + is_add_noise (bool): If True, the weight of this layer can add noise. + If False, the weight of this layer can not add noise. + is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. + If False, the weights of this layer won't be clipped. + min_num (int): The number of weights left that not be suppressed, which need to be greater than 0. + upper_bound (float): max value of weight in this layer, default value is 1.20 . 
+ """ + def __init__(self, layer_name, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): + self.layer_name = check_param_type('layer_name', layer_name, str) + self.is_add_noise = check_param_type('is_add_noise', is_add_noise, bool) + self.is_lower_clip = check_param_type('is_lower_clip', is_lower_clip, bool) + self.min_num = check_param_type('min_num', min_num, int) + self.upper_bound = check_value_positive('upper_bound', upper_bound) + self.inited = False + +class GradMaskInCell(Cell): + """ + Define the mask matrix for gradients masking. + + Args: + array (numpy.ndarray): The mask array. + is_add_noise (bool): If True, the weight of this layer can add noise. + If False, the weight of this layer can not add noise. + is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. + If False, the weights of this layer won't be clipped. + min_num (int): The number of weights left that not be suppressed, which need to be greater than 0. + upper_bound (float): max value of weight in this layer, default value is 1.20 + """ + def __init__(self, array, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): + super(GradMaskInCell, self).__init__() + self.mul_mask_array_shape = array.shape + mul_mask_array = array.copy() + self.mul_mask_array_flat = mul_mask_array.flatten() + self.mul_mask_tensor = Tensor(array, mstype.float32) + self.mask_able = False + self.is_add_noise = is_add_noise + self.is_lower_clip = is_lower_clip + self.min_num = min_num + self.upper_bound = check_value_positive('upper_bound', upper_bound) + + def construct(self): + """ + Return the mask matrix for optimization. + """ + return self.mask_able, self.mul_mask_tensor + + def update(self): + """ + Update the mask tensor. + """ + self.mul_mask_tensor = Tensor(self.mul_mask_array_flat.reshape(self.mul_mask_array_shape), mstype.float32) + +class DeWeightInCell(Cell): + """ + Define the mask matrix for de-weight masking. 
+ + Args: + array (numpy.ndarray): The mask array. + """ + def __init__(self, array): + super(DeWeightInCell, self).__init__() + self.add_mask_array_shape = array.shape + add_mask_array = array.copy() + self.add_mask_array_flat = add_mask_array.flatten() + self.add_mask_tensor = Tensor(array, mstype.float32) + self.mask_able = False + self.zero_mask_tensor = Tensor(np.zeros(array.shape, np.float32), mstype.float32) + self.just_update = -1.0 + + def construct(self): + """ + Return the mask matrix for optimization. + """ + if self.just_update > 0.0: + return self.mask_able, self.add_mask_tensor + return self.mask_able, self.zero_mask_tensor + + def update(self): + """ + Update the mask tensor. + """ + self.just_update = 1.0 + self.add_mask_tensor = Tensor(self.add_mask_array_flat.reshape(self.add_mask_array_shape), mstype.float32) + + def reset_zeros(self): + """ + Make the de-weight operation expired. + """ + self.just_update = -1.0 diff --git a/mindarmour/privacy/sup_privacy/train/__init__.py b/mindarmour/privacy/sup_privacy/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mindarmour/privacy/sup_privacy/train/model.py b/mindarmour/privacy/sup_privacy/train/model.py new file mode 100644 index 0000000..eedf793 --- /dev/null +++ b/mindarmour/privacy/sup_privacy/train/model.py @@ -0,0 +1,325 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +suppress-based privacy model.
+""" +from easydict import EasyDict as edict + +from mindspore.train.model import Model +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from mindspore.train.amp import _config_level +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell +from mindspore.parallel._utils import _get_parallel_mode +from mindspore.train.model import ParallelMode +from mindspore.train.amp import _do_keep_batchnorm_fp32 +from mindspore.train.amp import _add_loss_network +from mindspore import nn +from mindspore import context +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.parallel._utils import _get_gradients_mean +from mindspore.parallel._utils import _get_device_num +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.nn import Cell +from mindarmour.utils._check_param import check_param_type +from mindarmour.utils.logger import LogUtil +from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl + +LOGGER = LogUtil.get_instance() +TAG = 'Mask model' + +GRADIENT_CLIP_TYPE = 1 +_grad_scale = C.MultitypeFuncGraph("grad_scale") +_reciprocal = P.Reciprocal() + + +@_grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + """ grad scaling """ + return grad*F.cast(_reciprocal(scale), F.dtype(grad)) + + +class SuppressModel(Model): + """ + This class is overload mindspore.train.model.Model. + + Args: + network (Cell): The training network. + loss_fn (Cell): Computes softmax cross entropy between logits and labels. + optimizer (Optimizer): optimizer instance. + metrics (Union[dict, set]): Calculates the accuracy for classification and multilabel data. + kwargs: Keyword parameters used for creating a suppress model. 
+ + Examples: + networks_l5 = LeNet5() + masklayers = [] + masklayers.append(MaskLayerDes("conv1.weight", False, True, 10)) + suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", + end_epoch=10, + batch_num=(int)(10000/cfg.batch_size), + start_epoch=3, + mask_times=100, + networks=networks_l5, + lr=lr, + sparse_end=0.90, + sparse_start=0.0, + mask_layers=masklayers) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) + config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) + model_instance = SuppressModel(network=networks_l5, + loss_fn=net_loss, + optimizer=net_opt, + metrics={"Accuracy": Accuracy()}) + model_instance.link_suppress_ctrl(suppress_ctrl_instance) + ds_train = generate_mnist_dataset("./MNIST_unzip/train", + batch_size=cfg.batch_size, repeat_size=1, samples=samples) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory="./trained_ckpt_file/", + config=config_ck) + model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + dataset_sink_mode=False) + """ + + def __init__(self, + network=None, + **kwargs): + + check_param_type('networks', network, Cell) + + self.network_end = None + self._train_one_step = None + + super(SuppressModel, self).__init__(network, **kwargs) + + def link_suppress_ctrl(self, suppress_pri_ctrl): + """ + Link self and SuppressCtrl instance. + + Args: + suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance. + """ + check_param_type('suppress_pri_ctrl', suppress_pri_ctrl, Cell) + if not isinstance(suppress_pri_ctrl, SuppressCtrl): + msg = "SuppressCtrl instance error!"
+ LOGGER.error(TAG, msg) + raise ValueError(msg) + + suppress_pri_ctrl.model = self + if self._train_one_step is not None: + self._train_one_step.link_suppress_ctrl(suppress_pri_ctrl) + + def _build_train_network(self): + """Build train network""" + network = self._network + + ms_mode = context.get_context("mode") + if ms_mode != context.PYNATIVE_MODE: + raise ValueError("Only PYNATIVE_MODE is supported for suppress privacy now.") + + if self._optimizer: + network = self._amp_build_train_network(network, + self._optimizer, + self._loss_fn, + level=self._amp_level, + keep_batchnorm_fp32=self._keep_bn_fp32) + else: + raise ValueError("_optimizer is none") + + self._train_one_step = network + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, + ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() + + self.network_end = self._train_one_step.network + return network + + def _amp_build_train_network(self, network, optimizer, loss_fn=None, + level='O0', **kwargs): + """ + Build the mixed precision training cell automatically. + + Args: + network (Cell): Definition of the network. + loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, + the `network` should have the loss inside. Default: None. + optimizer (Optimizer): Optimizer to update the Parameter. + level (str): Supports [O0, O2]. Default: "O0". + - O0: Do not change. + - O2: Cast network to float16, keep batchnorm and `loss_fn` + (if set) run in float32, using dynamic loss scale. + cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` + or `mstype.float32`. If set to `mstype.float16`, use `float16` + mode to train. If set, overwrite the level setting. + keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, + overwrite the level setting. + loss_scale_manager (Union[None, LossScaleManager]): If None, not + scale the loss, or else scale the loss by LossScaleManager. + If set, overwrite the level setting. 
+ """ + validator.check_value_type('network', network, nn.Cell, None) + validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) + validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) + self._check_kwargs(kwargs) + config = dict(_config_level[level], **kwargs) + config = edict(config) + + if config.cast_model_type == mstype.float16: + network.to_float(mstype.float16) + + if config.keep_batchnorm_fp32: + _do_keep_batchnorm_fp32(network) + + if loss_fn: + network = _add_loss_network(network, loss_fn, + config.cast_model_type) + + if _get_parallel_mode() in ( + ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network = _VirtualDatasetCell(network) + + loss_scale = 1.0 + if config.loss_scale_manager is not None: + print("----model config have loss scale manager !") + network = TrainOneStepCell(network, optimizer, sens=loss_scale).set_train() + return network + + +class _TupleAdd(nn.Cell): + """ + Add two tuples of data. + """ + def __init__(self): + super(_TupleAdd, self).__init__() + self.add = P.TensorAdd() + self.hyper_map = C.HyperMap() + + def construct(self, input1, input2): + """Add two tuples of data.""" + out = self.hyper_map(self.add, input1, input2) + return out + + +class _TupleMul(nn.Cell): + """ + Multiply two tuples of data. + """ + def __init__(self): + super(_TupleMul, self).__init__() + self.mul = P.Mul() + self.hyper_map = C.HyperMap() + + def construct(self, input1, input2): + """Multiply two tuples of data.""" + out = self.hyper_map(self.mul, input1, input2) + #print(out) + return out + +# come from nn.cell_wrapper.TrainOneStepCell +class TrainOneStepCell(Cell): + r""" + Network training package class. + + Wraps the network with an optimizer. The resulting Cell is trained with input data and label. + Backward graph will be created in the construct function to do parameter updating. Different + parallel modes are available to run the training. + + Args: + network (Cell): The training network.
+ optimizer (Cell): Optimizer for updating the weights. + sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. + + Inputs: + - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. + - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. + + Outputs: + Tensor, a scalar Tensor with shape :math:`()`. + """ + def __init__(self, network, optimizer, sens=1.0): + super(TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, sens_param=True) # for mindspore 0.7x + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self._tuple_add = _TupleAdd() + self._tuple_mul = _TupleMul() + parallel_mode = _get_parallel_mode() + if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): + self.reducer_flag = True + if self.reducer_flag: + mean = _get_gradients_mean() # for mindspore 0.7x + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + self.do_privacy = False + self.grad_mask_tup = () # tuple containing grad_mask(cell) + self.de_weight_tup = () # tuple containing de_weight(cell) + self._suppress_pri_ctrl = None + + def link_suppress_ctrl(self, suppress_pri_ctrl): + """ + Set Suppress Mask for grad_mask_tup and de_weight_tup. + + Args: + suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance. 
+ """ + self._suppress_pri_ctrl = suppress_pri_ctrl + if self._suppress_pri_ctrl.grads_mask_list: + for grad_mask_cell in self._suppress_pri_ctrl.grads_mask_list: + self.grad_mask_tup += (grad_mask_cell,) + self.do_privacy = True + for de_weight_cell in self._suppress_pri_ctrl.de_weight_mask_list: + self.de_weight_tup += (de_weight_cell,) + else: + self.do_privacy = False + + def construct(self, data, label): + """ + Construct a compute flow. + """ + weights = self.weights + loss = self.network(data, label) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(data, label, sens) + + new_grads = () + m = 0 + for grad in grads: + if self.do_privacy and self._suppress_pri_ctrl.mask_started: + enable_mask, grad_mask = self.grad_mask_tup[m]() + enable_de_weight, de_weight_array = self.de_weight_tup[m]() + + if enable_mask and enable_de_weight: + grad_n = self._tuple_add(de_weight_array, self._tuple_mul(grad, grad_mask)) + new_grads = new_grads + (grad_n,) + else: + new_grads = new_grads + (grad,) + else: + new_grads = new_grads + (grad,) + m = m + 1 + + if self.reducer_flag: + new_grads = self.grad_reducer(new_grads) + + return F.depend(loss, self.optimizer(new_grads)) diff --git a/tests/ut/python/privacy/__init__.py b/tests/ut/python/privacy/__init__.py index d57ca64..7d52212 100644 --- a/tests/ut/python/privacy/__init__.py +++ b/tests/ut/python/privacy/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This package includes unit tests for differential-privacy training and -privacy breach estimation. 
+This package includes unit tests for differential-privacy training, + suppress-privacy training and privacy breach estimation. """ diff --git a/tests/ut/python/privacy/sup_privacy/__init__.py b/tests/ut/python/privacy/sup_privacy/__init__.py new file mode 100644 index 0000000..549de01 --- /dev/null +++ b/tests/ut/python/privacy/sup_privacy/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This package includes unit tests for suppress-privacy training. +""" diff --git a/tests/ut/python/privacy/sup_privacy/test_model_train.py b/tests/ut/python/privacy/sup_privacy/test_model_train.py new file mode 100644 index 0000000..7c8b275 --- /dev/null +++ b/tests/ut/python/privacy/sup_privacy/test_model_train.py @@ -0,0 +1,85 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Suppress Privacy model test. 
+""" +import pytest +import numpy as np + +from mindspore import nn +from mindspore import context +from mindspore.train.callback import ModelCheckpoint +from mindspore.train.callback import CheckpointConfig +from mindspore.train.callback import LossMonitor +from mindspore.nn.metrics import Accuracy +import mindspore.dataset as ds + +from ut.python.utils.mock_net import Net as LeNet5 + +from mindarmour.privacy.sup_privacy import SuppressModel +from mindarmour.privacy.sup_privacy import SuppressMasker +from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory +from mindarmour.privacy.sup_privacy import MaskLayerDes + +def dataset_generator(batch_size, batches): + """mock training data.""" + data = np.random.random((batches*batch_size, 1, 32, 32)).astype( + np.float32) + label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) + for i in range(batches): + yield data[i*batch_size:(i + 1)*batch_size],\ + label[i*batch_size:(i + 1)*batch_size] + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_suppress_model_with_pynative_mode(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + networks_l5 = LeNet5() + epochs = 5 + batch_num = 10 + batch_size = 32 + mask_times = 10 + lr = 0.01 + masklayers_lenet5 = [] + masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1)) + suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", + end_epoch=epochs, + batch_num=batch_num, + start_epoch=1, + mask_times=mask_times, + networks=networks_l5, + lr=lr, + sparse_end=0.50, + sparse_start=0.0, + mask_layers=masklayers_lenet5) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.SGD(networks_l5.trainable_params(), lr) + model_instance = SuppressModel( + network=networks_l5, + loss_fn=net_loss, + optimizer=net_opt, + metrics={"Accuracy": 
Accuracy()}) + model_instance.link_suppress_ctrl(suppress_ctrl_instance) + suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) + config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory="./trained_ckpt_file/", + config=config_ck) + ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batch_num), ['data', 'label']) + + model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + dataset_sink_mode=False)