From 04bd152e283480ad99e10ba56894fa8d98fc1310 Mon Sep 17 00:00:00 2001 From: itcomee Date: Wed, 3 Feb 2021 19:45:35 +0800 Subject: [PATCH] suppress based privacy model, 2021.2.3 --- examples/privacy/sup_privacy/sup_privacy.py | 12 ++--- examples/privacy/sup_privacy/sup_privacy_config.py | 2 +- mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py | 51 ++++++++++++---------- mindarmour/privacy/sup_privacy/train/model.py | 2 +- .../python/privacy/sup_privacy/test_model_train.py | 8 ++-- 5 files changed, 38 insertions(+), 37 deletions(-) diff --git a/examples/privacy/sup_privacy/sup_privacy.py b/examples/privacy/sup_privacy/sup_privacy.py index cfe0466..47a4fae 100644 --- a/examples/privacy/sup_privacy/sup_privacy.py +++ b/examples/privacy/sup_privacy/sup_privacy.py @@ -21,7 +21,6 @@ from mindspore.train.callback import ModelCheckpoint from mindspore.train.callback import CheckpointConfig from mindspore.train.callback import LossMonitor from mindspore.nn.metrics import Accuracy -from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset as ds import mindspore.dataset.vision.c_transforms as CV import mindspore.dataset.transforms.c_transforms as C @@ -91,16 +90,16 @@ def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, m """ networks_l5 = LeNet5() - suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", + suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5, + masklayers, + policy="local_train", end_epoch=epoch_size, batch_num=(int)(samples/cfg.batch_size), start_epoch=start_epoch, mask_times=mask_times, - networks=networks_l5, lr=lr, sparse_end=sparse_thd, - sparse_start=sparse_start, - mask_layers=masklayers) + sparse_start=sparse_start) net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_opt = nn.SGD(networks_l5.trainable_params(), lr) config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), @@ -130,9 +129,6 @@ def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, m dataset_sink_mode=False) print("============== Starting SUPP Testing ==============") - ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' - param_dict = load_checkpoint(ckpt_file_name) - load_param_into_net(networks_l5, param_dict) ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'), batch_size=cfg.batch_size) acc = model_instance.eval(ds_eval, dataset_sink_mode=False) diff --git a/examples/privacy/sup_privacy/sup_privacy_config.py b/examples/privacy/sup_privacy/sup_privacy_config.py index 88c9df2..e011554 100644 --- a/examples/privacy/sup_privacy/sup_privacy_config.py +++ b/examples/privacy/sup_privacy/sup_privacy_config.py @@ -20,7 +20,7 @@ from easydict import EasyDict as edict mnist_cfg = edict({ 'num_classes': 10, # the number of classes of model's output - 'epoch_size': 1, # training epochs + 'epoch_size': 10, # training epochs 'batch_size': 32, # batch size for training 'image_height': 32, # the height of training samples 'image_width': 32, # the width of training samples diff --git a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py index 7bfe98f..6fede64 100644 --- a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py +++ b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py @@ -35,20 +35,20 @@ class SuppressPrivacyFactory: pass @staticmethod - def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None, - lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None): + def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3, + mask_times=500, lr=0.10, sparse_end=0.90, sparse_start=0.0): """ Args: - policy (str): Training policy for suppress privacy training. "local_train" means local training. - end_epoch (int): The last epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . - batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size . - start_epoch (int): The first epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . - mask_times (int): The num of suppress operations. networks (Cell): The training network. - lr (float): Learning rate. - sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0 . - sparse_start (float): The sparsity to start, 0.0 <= sparse_start < sparse_end < 1.0 . mask_layers (list): Description of the training network layers that need to be suppressed. + policy (str): Training policy for suppress privacy training. Default: "local_train", means local training. + end_epoch (int): The last epoch in suppress operations, 0 0.5: + msg = "learning rate should be smaller than 0.5, but got {}".format(self.lr) + LOGGER.error(TAG, msg) + raise ValueError(msg) + if self.mask_start_epoch > self.mask_end_epoch: - msg = "start_epoch error: {}".format(self.mask_start_epoch) + msg = "start_epoch should not be greater than end_epoch, but got start_epoch and end_epoch are: " \ + "{}, {}".format(self.mask_start_epoch, self.mask_end_epoch) LOGGER.error(TAG, msg) raise ValueError(msg) if self.mask_end_epoch > 100: - msg = "end_epoch error: {}".format(self.mask_end_epoch) + msg = "The end_epoch should be smaller than 100, but got {}".format(self.mask_end_epoch) LOGGER.error(TAG, msg) raise ValueError(msg) @@ -152,13 +156,14 @@ class SuppressCtrl(Cell): LOGGER.error(TAG, msg) raise ValueError(msg) - if self.sparse_end > 1.00 or self.sparse_end <= 0: - msg = "sparse_end error: {}".format(self.sparse_end) + if self.sparse_end >= 1.00 or self.sparse_end <= 0: + msg = "sparse_end should be in range (0, 1), but got {}".format(self.sparse_end) LOGGER.error(TAG, msg) raise ValueError(msg) if self.sparse_start >= self.sparse_end: - msg = "sparse_start error: {}".format(self.sparse_start) + msg = "sparse_start should be smaller than sparse_end, but got sparse_start and sparse_end are: " \ + "{}, {}".format(self.sparse_start, self.sparse_end) LOGGER.error(TAG, msg) raise ValueError(msg) diff --git a/mindarmour/privacy/sup_privacy/train/model.py b/mindarmour/privacy/sup_privacy/train/model.py index eedf793..8ac19f3 100644 --- a/mindarmour/privacy/sup_privacy/train/model.py +++ b/mindarmour/privacy/sup_privacy/train/model.py @@ -96,7 +96,7 @@ class SuppressModel(Model): """ def __init__(self, - network=None, + network, **kwargs): check_param_type('networks', network, Cell) diff --git a/tests/ut/python/privacy/sup_privacy/test_model_train.py b/tests/ut/python/privacy/sup_privacy/test_model_train.py index 7c8b275..976f037 100644 --- a/tests/ut/python/privacy/sup_privacy/test_model_train.py +++ b/tests/ut/python/privacy/sup_privacy/test_model_train.py @@ -56,16 +56,16 @@ def test_suppress_model_with_pynative_mode(): lr = 0.01 masklayers_lenet5 = [] masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1)) - suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", + suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5, + masklayers_lenet5, + policy="local_train", end_epoch=epochs, batch_num=batch_num, start_epoch=1, mask_times=mask_times, - networks=networks_l5, lr=lr, sparse_end=0.50, - sparse_start=0.0, - mask_layers=masklayers_lenet5) + sparse_start=0.0) net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_opt = nn.SGD(networks_l5.trainable_params(), lr) model_instance = SuppressModel(