|
|
@@ -28,7 +28,6 @@ from mindarmour.utils._check_param import check_int_positive, check_value_positi |
|
|
|
LOGGER = LogUtil.get_instance() |
|
|
|
TAG = 'Suppression training.' |
|
|
|
|
|
|
|
|
|
|
|
class SuppressPrivacyFactory: |
|
|
|
""" Factory class of SuppressCtrl mechanisms""" |
|
|
|
def __init__(self): |
|
|
@@ -36,7 +35,7 @@ class SuppressPrivacyFactory: |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3, |
|
|
|
mask_times=500, lr=0.10, sparse_end=0.90, sparse_start=0.0): |
|
|
|
mask_times=1000, lr=0.10, sparse_end=0.90, sparse_start=0.0): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
networks (Cell): The training network. |
|
|
@@ -45,49 +44,50 @@ class SuppressPrivacyFactory: |
|
|
|
end_epoch (int): The last epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 10. |
|
|
|
batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size. Default: 20. |
|
|
|
start_epoch (int): The first epoch in suppress operations, 0<start_epoch<=end_epoch<=100. Default: 3. |
|
|
|
mask_times (int): The num of suppress operations. Default: 500. |
|
|
|
mask_times (int): The num of suppress operations. Default: 1000. |
|
|
|
lr (Union[float, int]): Learning rate, 0 < lr <= 0.5. Default: 0.10. |
|
|
|
sparse_end (float): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90. |
|
|
|
sparse_start (float): The sparsity to start, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0. |
|
|
|
sparse_end (Union[float, int]): The sparsity to reach, 0.0<=sparse_start<sparse_end<1.0. Default: 0.90. |
|
|
|
sparse_start (Union[float, int]): The sparsity to start, 0.0<=sparse_start<sparse_end<1.0. Default: 0.0. |
|
|
|
|
|
|
|
Returns: |
|
|
|
SuppressCtrl, class of Suppress Privavy Mechanism. |
|
|
|
|
|
|
|
Examples: |
|
|
|
networks_l5 = LeNet5() |
|
|
|
masklayers = [] |
|
|
|
masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) |
|
|
|
suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", |
|
|
|
end_epoch=10, |
|
|
|
batch_num=(int)(10000/cfg.batch_size), |
|
|
|
start_epoch=3, |
|
|
|
mask_times=100, |
|
|
|
networks=networks_l5, |
|
|
|
lr=lr, |
|
|
|
sparse_end=0.90, |
|
|
|
sparse_start=0.0, |
|
|
|
mask_layers=masklayers) |
|
|
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") |
|
|
|
net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) |
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) |
|
|
|
model_instance = SuppressModel(network=networks_l5, |
|
|
|
loss_fn=net_loss, |
|
|
|
optimizer=net_opt, |
|
|
|
metrics={"Accuracy": Accuracy()}) |
|
|
|
model_instance.link_suppress_ctrl(suppress_ctrl_instance) |
|
|
|
ds_train = generate_mnist_dataset("./MNIST_unzip/train", |
|
|
|
batch_size=cfg.batch_size, repeat_size=1, samples=samples) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", |
|
|
|
directory="./trained_ckpt_file/", |
|
|
|
config=config_ck) |
|
|
|
model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], |
|
|
|
dataset_sink_mode=False) |
|
|
|
>>> networks_l5 = LeNet5() |
|
|
|
>>> masklayers = [] |
|
|
|
>>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) |
|
|
|
>>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, |
|
|
|
>>> mask_layers=masklayers, |
|
|
|
>>> policy="local_train", |
|
|
|
>>> end_epoch=10, |
|
|
|
>>> batch_num=(int)(10000/cfg.batch_size), |
|
|
|
>>> start_epoch=3, |
|
|
|
>>> mask_times=1000, |
|
|
|
>>> lr=lr, |
|
|
|
>>> sparse_end=0.90, |
|
|
|
>>> sparse_start=0.0) |
|
|
|
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") |
|
|
|
>>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) |
|
|
|
>>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), |
|
|
|
>>> keep_checkpoint_max=10) |
|
|
|
>>> model_instance = SuppressModel(network=networks_l5, |
|
|
|
>>> loss_fn=net_loss, |
|
|
|
>>> optimizer=net_opt, |
|
|
|
>>> metrics={"Accuracy": Accuracy()}) |
|
|
|
>>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) |
|
|
|
>>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", |
|
|
|
>>> batch_size=cfg.batch_size, repeat_size=1, samples=samples) |
|
|
|
>>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", |
|
|
|
>>> directory="./trained_ckpt_file/", |
|
|
|
>>> config=config_ck) |
|
|
|
>>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], |
|
|
|
>>> dataset_sink_mode=False) |
|
|
|
""" |
|
|
|
check_param_type('policy', policy, str) |
|
|
|
if policy == "local_train": |
|
|
|
return SuppressCtrl(networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, |
|
|
|
sparse_end, sparse_start) |
|
|
|
msg = "Only local training is supported now, federal training will be supported " \ |
|
|
|
"in the future. But got {}.".format(policy) |
|
|
|
msg = "Only local training is supported now, but got {}.".format(policy) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
@@ -98,20 +98,20 @@ class SuppressCtrl(Cell): |
|
|
|
mask_layers (list): Description of those layers that need to be suppressed. |
|
|
|
end_epoch (int): The last epoch in suppress operations. |
|
|
|
batch_num (int): The num of grad operation in an epoch. |
|
|
|
mask_start_epoch (int): The first epoch in suppress operations. |
|
|
|
start_epoch (int): The first epoch in suppress operations. |
|
|
|
mask_times (int): The num of suppress operations. |
|
|
|
lr (Union[float, int]): Learning rate. |
|
|
|
sparse_end (Union[float, int]): The sparsity to reach. |
|
|
|
sparse_start (float): The sparsity to start. |
|
|
|
sparse_start (Union[float, int]): The sparsity to start. |
|
|
|
""" |
|
|
|
def __init__(self, networks, mask_layers, end_epoch, batch_num, mask_start_epoch, mask_times, lr, |
|
|
|
def __init__(self, networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, |
|
|
|
sparse_end, sparse_start): |
|
|
|
super(SuppressCtrl, self).__init__() |
|
|
|
self.networks = check_param_type('networks', networks, Cell) |
|
|
|
self.mask_layers = check_param_type('mask_layers', mask_layers, list) |
|
|
|
self.mask_end_epoch = check_int_positive('end_epoch', end_epoch) |
|
|
|
self.batch_num = check_int_positive('batch_num', batch_num) |
|
|
|
self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch) |
|
|
|
self.mask_start_epoch = check_int_positive('start_epoch', start_epoch) |
|
|
|
self.mask_times = check_int_positive('mask_times', mask_times) |
|
|
|
self.lr = check_value_positive('lr', lr) |
|
|
|
self.sparse_end = check_value_non_negative('sparse_end', sparse_end) |
|
|
@@ -131,12 +131,12 @@ class SuppressCtrl(Cell): |
|
|
|
self.mask_start_step = 0 # suppress operation is actually started at this step |
|
|
|
self.mask_prev_step = 0 # previous suppress operation is done at this step |
|
|
|
self.cur_sparse = 0.0 # current sparsity to which one suppress will get |
|
|
|
self.mask_all_steps = (self.mask_end_epoch-mask_start_epoch+1)*batch_num # the amount of step contained in all suppress operation |
|
|
|
self.mask_all_steps = (end_epoch - start_epoch + 1)*batch_num # the amount of step contained in all suppress operation |
|
|
|
self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation |
|
|
|
self.mask_initialized = False # flag means the initialization is done |
|
|
|
|
|
|
|
if self.lr > 0.5: |
|
|
|
msg = "learning rate should be smaller than 0.5, but got {}".format(self.lr) |
|
|
|
msg = "learning rate should not be greater than 0.5, but got {}".format(self.lr) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
@@ -151,11 +151,18 @@ class SuppressCtrl(Cell): |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self.mask_step_interval < 0: |
|
|
|
msg = "step_interval error: {}".format(self.mask_step_interval) |
|
|
|
if self.mask_step_interval <= 0: |
|
|
|
msg = "step_interval should be greater than 0, but got {}".format(self.mask_step_interval) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
if self.mask_step_interval <= 10 or self.mask_step_interval >= 20: |
|
|
|
msg = "mask_interval should be greater than 10, smaller than 20, but got {}".format(self.mask_step_interval) |
|
|
|
msg += "\n Precision of trained model may be poor !!! " |
|
|
|
msg += "\n please modify epoch_start, epoch_end and batch_num !" |
|
|
|
msg += "\n mask_interval = (epoch_end-epoch_start+1)*batch_num, batch_num = samples/batch_size" |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
|
|
|
|
if self.sparse_end >= 1.00 or self.sparse_end <= 0: |
|
|
|
msg = "sparse_end should be in range (0, 1), but got {}".format(self.sparse_end) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
@@ -171,7 +178,7 @@ class SuppressCtrl(Cell): |
|
|
|
mask_layer_id = 0 |
|
|
|
for one_mask_layer in mask_layers: |
|
|
|
if not isinstance(one_mask_layer, MaskLayerDes): |
|
|
|
msg = "mask_layer instance error!" |
|
|
|
msg = "mask_layers should be a list of MaskLayerDes, but got a {}".format(type(one_mask_layer)) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
layer_name = one_mask_layer.layer_name |
|
|
@@ -530,6 +537,12 @@ class SuppressCtrl(Cell): |
|
|
|
msg = "{} sparse fact={} ".format(layer_name, sparse) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
|
|
|
|
def print_paras(self): |
|
|
|
msg = "paras: start_epoch:{}, end_epoch:{}, batch_num:{}, interval:{}, lr:{}, sparse_end:{}, sparse_start:{}" \ |
|
|
|
.format(self.mask_start_epoch, self.mask_end_epoch, self.batch_num, self.mask_step_interval, |
|
|
|
self.lr, self.sparse_end, self.sparse_start) |
|
|
|
LOGGER.info(TAG, msg) |
|
|
|
|
|
|
|
def get_one_mask_layer(mask_layers, layer_name): |
|
|
|
""" |
|
|
|
Returns the layer definitions that need to be suppressed. |
|
|
|