Browse Source

!165 Fix some issues for Suppress Privacy - modified on 2021.2.3

From: @itcomee
Reviewed-by: 
Signed-off-by:
tags/v1.2.1
mindspore-ci-bot Gitee 4 years ago
parent
commit
feb920c1b8
5 changed files with 38 additions and 37 deletions
  1. +4
    -8
      examples/privacy/sup_privacy/sup_privacy.py
  2. +1
    -1
      examples/privacy/sup_privacy/sup_privacy_config.py
  3. +28
    -23
      mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py
  4. +1
    -1
      mindarmour/privacy/sup_privacy/train/model.py
  5. +4
    -4
      tests/ut/python/privacy/sup_privacy/test_model_train.py

+ 4
- 8
examples/privacy/sup_privacy/sup_privacy.py View File

@@ -21,7 +21,6 @@ from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
@@ -91,16 +90,16 @@ def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, m
"""

networks_l5 = LeNet5()
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5,
masklayers,
policy="local_train",
end_epoch=epoch_size,
batch_num=(int)(samples/cfg.batch_size),
start_epoch=start_epoch,
mask_times=mask_times,
networks=networks_l5,
lr=lr,
sparse_end=sparse_thd,
sparse_start=sparse_start,
mask_layers=masklayers)
sparse_start=sparse_start)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.SGD(networks_l5.trainable_params(), lr)
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size),
@@ -130,9 +129,6 @@ def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, m
dataset_sink_mode=False)

print("============== Starting SUPP Testing ==============")
ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(networks_l5, param_dict)
ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'),
batch_size=cfg.batch_size)
acc = model_instance.eval(ds_eval, dataset_sink_mode=False)


+ 1
- 1
examples/privacy/sup_privacy/sup_privacy_config.py View File

@@ -20,7 +20,7 @@ from easydict import EasyDict as edict

mnist_cfg = edict({
'num_classes': 10, # the number of classes of model's output
'epoch_size': 1, # training epochs
'epoch_size': 10, # training epochs
'batch_size': 32, # batch size for training
'image_height': 32, # the height of training samples
'image_width': 32, # the width of training samples


+ 28
- 23
mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py View File

@@ -35,20 +35,20 @@ class SuppressPrivacyFactory:
pass

@staticmethod
def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None,
lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None):
def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3,
mask_times=500, lr=0.10, sparse_end=0.90, sparse_start=0.0):
"""
Args:
policy (str): Training policy for suppress privacy training. "local_train" means local training.
end_epoch (int): The last epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 .
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size .
start_epoch (int): The first epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 .
mask_times (int): The num of suppress operations.
networks (Cell): The training network.
lr (float): Learning rate.
sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0 .
sparse_start (float): The sparsity to start, 0.0 <= sparse_start < sparse_end < 1.0 .
mask_layers (list): Description of the training network layers that need to be suppressed.
policy (str): Training policy for suppress privacy training. Default: "local_train", means local training.
end_epoch (int): The last epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 10.
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size. Default: 20.
start_epoch (int): The first epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 3.
mask_times (int): The num of suppress operations. Default: 500.
lr (Union[float, int]): Learning rate, 0 < lr <= 0.5. Default: 0.10.
sparse_end (float): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90.
sparse_start (float): The sparsity to start, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0.

Returns:
SuppressCtrl, class of Suppress Privavy Mechanism.
@@ -84,8 +84,8 @@ class SuppressPrivacyFactory:
dataset_sink_mode=False)
"""
if policy == "local_train":
return SuppressCtrl(networks, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end,
sparse_start, mask_layers)
return SuppressCtrl(networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr,
sparse_end, sparse_start)
msg = "Only local training is supported now, federal training will be supported " \
"in the future. But got {}.".format(policy)
LOGGER.error(TAG, msg)
@@ -95,6 +95,7 @@ class SuppressCtrl(Cell):
"""
Args:
networks (Cell): The training network.
mask_layers (list): Description of those layers that need to be suppressed.
end_epoch (int): The last epoch in suppress operations.
batch_num (int): The num of grad operation in an epoch.
mask_start_epoch (int): The first epoch in suppress operations.
@@ -102,14 +103,12 @@ class SuppressCtrl(Cell):
lr (Union[float, int]): Learning rate.
sparse_end (Union[float, int]): The sparsity to reach.
sparse_start (float): The sparsity to start.
mask_layers (list): Description of those layers that need to be suppressed.
"""
def __init__(self, networks, end_epoch, batch_num, mask_start_epoch=3, mask_times=500, lr=0.05,
sparse_end=0.60,
sparse_start=0.0,
mask_layers=None):
def __init__(self, networks, mask_layers, end_epoch, batch_num, mask_start_epoch, mask_times, lr,
sparse_end, sparse_start):
super(SuppressCtrl, self).__init__()
self.networks = check_param_type('networks', networks, Cell)
self.mask_layers = check_param_type('mask_layers', mask_layers, list)
self.mask_end_epoch = check_int_positive('end_epoch', end_epoch)
self.batch_num = check_int_positive('batch_num', batch_num)
self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch)
@@ -117,7 +116,6 @@ class SuppressCtrl(Cell):
self.lr = check_value_positive('lr', lr)
self.sparse_end = check_value_non_negative('sparse_end', sparse_end)
self.sparse_start = check_value_non_negative('sparse_start', sparse_start)
self.mask_layers = check_param_type('mask_layers', mask_layers, list)

self.weight_lower_bound = 0.005 # all network weight will be larger than this value
self.sparse_vibra = 0.02 # the sparsity may have certain range of variations
@@ -137,13 +135,19 @@ class SuppressCtrl(Cell):
self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation
self.mask_initialized = False # flag means the initialization is done

if self.lr > 0.5:
msg = "learning rate should be smaller than 0.5, but got {}".format(self.lr)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.mask_start_epoch > self.mask_end_epoch:
msg = "start_epoch error: {}".format(self.mask_start_epoch)
msg = "start_epoch should not be greater than end_epoch, but got start_epoch and end_epoch are: " \
"{}, {}".format(self.mask_start_epoch, self.mask_end_epoch)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.mask_end_epoch > 100:
msg = "end_epoch error: {}".format(self.mask_end_epoch)
msg = "The end_epoch should be smaller than 100, but got {}".format(self.mask_end_epoch)
LOGGER.error(TAG, msg)
raise ValueError(msg)

@@ -152,13 +156,14 @@ class SuppressCtrl(Cell):
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.sparse_end > 1.00 or self.sparse_end <= 0:
msg = "sparse_end error: {}".format(self.sparse_end)
if self.sparse_end >= 1.00 or self.sparse_end <= 0:
msg = "sparse_end should be in range (0, 1), but got {}".format(self.sparse_end)
LOGGER.error(TAG, msg)
raise ValueError(msg)

if self.sparse_start >= self.sparse_end:
msg = "sparse_start error: {}".format(self.sparse_start)
msg = "sparse_start should be smaller than sparse_end, but got sparse_start and sparse_end are: " \
"{}, {}".format(self.sparse_start, self.sparse_end)
LOGGER.error(TAG, msg)
raise ValueError(msg)



+ 1
- 1
mindarmour/privacy/sup_privacy/train/model.py View File

@@ -96,7 +96,7 @@ class SuppressModel(Model):
"""

def __init__(self,
network=None,
network,
**kwargs):

check_param_type('networks', network, Cell)


+ 4
- 4
tests/ut/python/privacy/sup_privacy/test_model_train.py View File

@@ -56,16 +56,16 @@ def test_suppress_model_with_pynative_mode():
lr = 0.01
masklayers_lenet5 = []
masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1))
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train",
suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5,
masklayers_lenet5,
policy="local_train",
end_epoch=epochs,
batch_num=batch_num,
start_epoch=1,
mask_times=mask_times,
networks=networks_l5,
lr=lr,
sparse_end=0.50,
sparse_start=0.0,
mask_layers=masklayers_lenet5)
sparse_start=0.0)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.SGD(networks_l5.trainable_params(), lr)
model_instance = SuppressModel(


Loading…
Cancel
Save