diff --git a/examples/privacy/sup_privacy/sup_privacy.py b/examples/privacy/sup_privacy/sup_privacy.py index 47a4fae..44a1000 100644 --- a/examples/privacy/sup_privacy/sup_privacy.py +++ b/examples/privacy/sup_privacy/sup_privacy.py @@ -134,6 +134,7 @@ def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, m acc = model_instance.eval(ds_eval, dataset_sink_mode=False) print("============== SUPP Accuracy: %s ==============", acc) + suppress_ctrl_instance.print_paras() if __name__ == "__main__": # This configure can run in pynative mode context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target) diff --git a/examples/privacy/sup_privacy/sup_privacy_config.py b/examples/privacy/sup_privacy/sup_privacy_config.py index e011554..50daffd 100644 --- a/examples/privacy/sup_privacy/sup_privacy_config.py +++ b/examples/privacy/sup_privacy/sup_privacy_config.py @@ -20,13 +20,9 @@ from easydict import EasyDict as edict mnist_cfg = edict({ 'num_classes': 10, # the number of classes of model's output - 'epoch_size': 10, # training epochs 'batch_size': 32, # batch size for training 'image_height': 32, # the height of training samples 'image_width': 32, # the width of training samples - 'save_checkpoint_steps': 1875, # the interval steps for saving checkpoint file of the model 'keep_checkpoint_max': 10, # the maximum number of checkpoint files would be saved 'device_target': 'Ascend', # device used - 'data_path': './MNIST_unzip', # the path of training and testing data set - 'dataset_sink_mode': False, # whether deliver all training data to device one time }) diff --git a/mindarmour/privacy/sup_privacy/mask_monitor/masker.py b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py index 0a3a4e7..89c5045 100644 --- a/mindarmour/privacy/sup_privacy/mask_monitor/masker.py +++ b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py @@ -34,37 +34,37 @@ class SuppressMasker(Callback): suppress_ctrl (SuppressCtrl): SuppressCtrl instance. Examples: - networks_l5 = LeNet5() - masklayers = [] - masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) - suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", - end_epoch=10, - batch_num=(int)(10000/cfg.batch_size), - start_epoch=3, - mask_times=100, - networks=networks_l5, - lr=lr, - sparse_end=0.90, - sparse_start=0.0, - mask_layers=masklayers) - net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) - config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) - model_instance = SuppressModel(network=networks_l5, - loss_fn=net_loss, - optimizer=net_opt, - metrics={"Accuracy": Accuracy()}) - model_instance.link_suppress_ctrl(suppress_ctrl_instance) - ds_train = generate_mnist_dataset("./MNIST_unzip/train", - batch_size=cfg.batch_size, repeat_size=1, samples=samples) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", - directory="./trained_ckpt_file/", - config=config_ck) - model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], - dataset_sink_mode=False) + >>> networks_l5 = LeNet5() + >>> masklayers = [] + >>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) + >>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, + >>> mask_layers=masklayers, + >>> policy="local_train", + >>> end_epoch=10, + >>> batch_num=(int)(10000/cfg.batch_size), + >>> start_epoch=3, + >>> mask_times=1000, + >>> lr=lr, + >>> sparse_end=0.90, + >>> sparse_start=0.0) + >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + >>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) + >>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) + >>> model_instance = SuppressModel(network=networks_l5, + >>> loss_fn=net_loss, + >>> optimizer=net_opt, + >>> metrics={"Accuracy": Accuracy()}) + >>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) + >>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", + >>> batch_size=cfg.batch_size, repeat_size=1, samples=samples) + >>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + >>> directory="./trained_ckpt_file/", + >>> config=config_ck) + >>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + >>> dataset_sink_mode=False) """ - def __init__(self, model=None, suppress_ctrl=None): + def __init__(self, model, suppress_ctrl): super(SuppressMasker, self).__init__() diff --git a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py index 6fede64..2acb518 100644 --- a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py +++ b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py @@ -28,7 +28,6 @@ from mindarmour.utils._check_param import check_int_positive, check_value_positi LOGGER = LogUtil.get_instance() TAG = 'Suppression training.' - class SuppressPrivacyFactory: """ Factory class of SuppressCtrl mechanisms""" def __init__(self): @@ -36,7 +35,7 @@ class SuppressPrivacyFactory: @staticmethod def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3, - mask_times=500, lr=0.10, sparse_end=0.90, sparse_start=0.0): + mask_times=1000, lr=0.10, sparse_end=0.90, sparse_start=0.0): """ Args: networks (Cell): The training network. @@ -45,49 +44,50 @@ class SuppressPrivacyFactory: end_epoch (int): The last epoch in suppress operations, 0>> networks_l5 = LeNet5() + >>> masklayers = [] + >>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) + >>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, + >>> mask_layers=masklayers, + >>> policy="local_train", + >>> end_epoch=10, + >>> batch_num=(int)(10000/cfg.batch_size), + >>> start_epoch=3, + >>> mask_times=1000, + >>> lr=lr, + >>> sparse_end=0.90, + >>> sparse_start=0.0) + >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + >>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) + >>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), + >>> keep_checkpoint_max=10) + >>> model_instance = SuppressModel(network=networks_l5, + >>> loss_fn=net_loss, + >>> optimizer=net_opt, + >>> metrics={"Accuracy": Accuracy()}) + >>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) + >>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", + >>> batch_size=cfg.batch_size, repeat_size=1, samples=samples) + >>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + >>> directory="./trained_ckpt_file/", + >>> config=config_ck) + >>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + >>> dataset_sink_mode=False) """ + check_param_type('policy', policy, str) if policy == "local_train": return SuppressCtrl(networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end, sparse_start) - msg = "Only local training is supported now, federal training will be supported " \ - "in the future. But got {}.".format(policy) + msg = "Only local training is supported now, but got {}.".format(policy) LOGGER.error(TAG, msg) raise ValueError(msg) @@ -98,20 +98,20 @@ class SuppressCtrl(Cell): mask_layers (list): Description of those layers that need to be suppressed. end_epoch (int): The last epoch in suppress operations. batch_num (int): The num of grad operation in an epoch. - mask_start_epoch (int): The first epoch in suppress operations. + start_epoch (int): The first epoch in suppress operations. mask_times (int): The num of suppress operations. lr (Union[float, int]): Learning rate. sparse_end (Union[float, int]): The sparsity to reach. - sparse_start (float): The sparsity to start. + sparse_start (Union[float, int]): The sparsity to start. """ - def __init__(self, networks, mask_layers, end_epoch, batch_num, mask_start_epoch, mask_times, lr, + def __init__(self, networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end, sparse_start): super(SuppressCtrl, self).__init__() self.networks = check_param_type('networks', networks, Cell) self.mask_layers = check_param_type('mask_layers', mask_layers, list) self.mask_end_epoch = check_int_positive('end_epoch', end_epoch) self.batch_num = check_int_positive('batch_num', batch_num) - self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch) + self.mask_start_epoch = check_int_positive('start_epoch', start_epoch) self.mask_times = check_int_positive('mask_times', mask_times) self.lr = check_value_positive('lr', lr) self.sparse_end = check_value_non_negative('sparse_end', sparse_end) @@ -131,12 +131,12 @@ class SuppressCtrl(Cell): self.mask_start_step = 0 # suppress operation is actually started at this step self.mask_prev_step = 0 # previous suppress operation is done at this step self.cur_sparse = 0.0 # current sparsity to which one suppress will get - self.mask_all_steps = (self.mask_end_epoch-mask_start_epoch+1)*batch_num # the amount of step contained in all suppress operation + self.mask_all_steps = (end_epoch - start_epoch + 1)*batch_num # the amount of step contained in all suppress operation self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation self.mask_initialized = False # flag means the initialization is done if self.lr > 0.5: - msg = "learning rate should be smaller than 0.5, but got {}".format(self.lr) + msg = "learning rate should not be greater than 0.5, but got {}".format(self.lr) LOGGER.error(TAG, msg) raise ValueError(msg) @@ -151,11 +151,18 @@ class SuppressCtrl(Cell): LOGGER.error(TAG, msg) raise ValueError(msg) - if self.mask_step_interval < 0: - msg = "step_interval error: {}".format(self.mask_step_interval) + if self.mask_step_interval <= 0: + msg = "step_interval should be greater than 0, but got {}".format(self.mask_step_interval) LOGGER.error(TAG, msg) raise ValueError(msg) + if self.mask_step_interval <= 10 or self.mask_step_interval >= 20: + msg = "mask_interval should be greater than 10, smaller than 20, but got {}".format(self.mask_step_interval) + msg += "\n Precision of trained model may be poor !!! " + msg += "\n please modify epoch_start, epoch_end and batch_num !" + msg += "\n mask_interval = (epoch_end-epoch_start+1)*batch_num, batch_num = samples/batch_size" + LOGGER.info(TAG, msg) + if self.sparse_end >= 1.00 or self.sparse_end <= 0: msg = "sparse_end should be in range (0, 1), but got {}".format(self.sparse_end) LOGGER.error(TAG, msg) @@ -171,7 +178,7 @@ class SuppressCtrl(Cell): mask_layer_id = 0 for one_mask_layer in mask_layers: if not isinstance(one_mask_layer, MaskLayerDes): - msg = "mask_layer instance error!" + msg = "mask_layers should be a list of MaskLayerDes, but got a {}".format(type(one_mask_layer)) LOGGER.error(TAG, msg) raise ValueError(msg) layer_name = one_mask_layer.layer_name @@ -530,6 +537,12 @@ class SuppressCtrl(Cell): msg = "{} sparse fact={} ".format(layer_name, sparse) LOGGER.info(TAG, msg) + def print_paras(self): + msg = "paras: start_epoch:{}, end_epoch:{}, batch_num:{}, interval:{}, lr:{}, sparse_end:{}, sparse_start:{}" \ + .format(self.mask_start_epoch, self.mask_end_epoch, self.batch_num, self.mask_step_interval, + self.lr, self.sparse_end, self.sparse_start) + LOGGER.info(TAG, msg) + def get_one_mask_layer(mask_layers, layer_name): """ Returns the layer definitions that need to be suppressed. diff --git a/mindarmour/privacy/sup_privacy/train/model.py b/mindarmour/privacy/sup_privacy/train/model.py index 7b28b5c..f42ac84 100644 --- a/mindarmour/privacy/sup_privacy/train/model.py +++ b/mindarmour/privacy/sup_privacy/train/model.py @@ -65,41 +65,41 @@ class SuppressModel(Model): kwargs: Keyword parameters used for creating a suppress model. Examples: - networks_l5 = LeNet5() - masklayers = [] - masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) - suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", - end_epoch=10, - batch_num=(int)(10000/cfg.batch_size), - start_epoch=3, - mask_times=100, - networks=networks_l5, - lr=lr, - sparse_end=0.90, - sparse_start=0.0, - mask_layers=masklayers) - net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) - config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) - model_instance = SuppressModel(network=networks_l5, - loss_fn=net_loss, - optimizer=net_opt, - metrics={"Accuracy": Accuracy()}) - model_instance.link_suppress_ctrl(suppress_ctrl_instance) - ds_train = generate_mnist_dataset("./MNIST_unzip/train", - batch_size=cfg.batch_size, repeat_size=1, samples=samples) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", - directory="./trained_ckpt_file/", - config=config_ck) - model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], - dataset_sink_mode=False) + >>> networks_l5 = LeNet5() + >>> masklayers = [] + >>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) + >>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, + >>> mask_layers=masklayers, + >>> policy="local_train", + >>> end_epoch=10, + >>> batch_num=(int)(10000/cfg.batch_size), + >>> start_epoch=3, + >>> mask_times=1000, + >>> lr=lr, + >>> sparse_end=0.90, + >>> sparse_start=0.0) + >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + >>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) + >>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) + >>> model_instance = SuppressModel(network=networks_l5, + >>> loss_fn=net_loss, + >>> optimizer=net_opt, + >>> metrics={"Accuracy": Accuracy()}) + >>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) + >>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", + >>> batch_size=cfg.batch_size, repeat_size=1, samples=samples) + >>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + >>> directory="./trained_ckpt_file/", + >>> config=config_ck) + >>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + >>> dataset_sink_mode=False) """ def __init__(self, network, **kwargs): - check_param_type('networks', network, Cell) + check_param_type('network', network, Cell) self.network_end = None self._train_one_step = None