diff --git a/mindarmour/privacy/sup_privacy/mask_monitor/masker.py b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py index fd54ac2..0f67b12 100644 --- a/mindarmour/privacy/sup_privacy/mask_monitor/masker.py +++ b/mindarmour/privacy/sup_privacy/mask_monitor/masker.py @@ -26,10 +26,6 @@ TAG = 'suppress masker' class SuppressMasker(Callback): """ Args: - args (Union[int, float, numpy.ndarray, list, str]): Parameters - used for creating a suppress privacy monitor. - kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword - parameters used for creating a suppress privacy monitor. model (SuppressModel): SuppressModel instance. suppress_ctrl (SuppressCtrl): SuppressCtrl instance. diff --git a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py index b3ae7b8..2e6407b 100644 --- a/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py +++ b/mindarmour/privacy/sup_privacy/sup_ctrl/conctrl.py @@ -107,6 +107,37 @@ class SuppressCtrl(Cell): lr (Union[float, int]): Learning rate. sparse_end (float): The sparsity to reach. sparse_start (Union[float, int]): The sparsity to start. + + Examples: + >>> networks_l5 = LeNet5() + >>> masklayers = [] + >>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) + >>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, + >>> mask_layers=masklayers, + >>> policy="local_train", + >>> end_epoch=10, + >>> batch_num=(int)(10000/cfg.batch_size), + >>> start_epoch=3, + >>> mask_times=1000, + >>> lr=lr, + >>> sparse_end=0.90, + >>> sparse_start=0.0) + >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + >>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) + >>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), + >>> keep_checkpoint_max=10) + >>> model_instance = SuppressModel(network=networks_l5, + >>> loss_fn=net_loss, + >>> optimizer=net_opt, + >>> metrics={"Accuracy": Accuracy()}) + >>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) + >>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", + >>> batch_size=cfg.batch_size, repeat_size=1, samples=samples) + >>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + >>> directory="./trained_ckpt_file/", + >>> config=config_ck) + >>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + >>> dataset_sink_mode=False) """ def __init__(self, networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end, sparse_start): @@ -740,9 +771,13 @@ class MaskLayerDes: If False, the weights of this layer won't be clipped. If parameter num is greater than 100000, is_lower_clip has no effect. min_num (int): The number of weights left that not be suppressed. - If min_num is smaller than (parameter num*SupperssCtrl.sparse_end), min_num has no effect. - upper_bound (Union[float, int]): Max abs value of weight in this layer, default: 1.20. - If parameter num is greater than 100000, upper_bound has no effect. + If min_num is smaller than (parameter num*SupperssCtrl.sparse_end), min_num has not effect. + upper_bound (Union[float, int]): max abs value of weight in this layer, default: 1.20. + If parameter num is greater than 100000, upper_bound has not effect. + + Examples: + >>> masklayers = [] + >>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) """ def __init__(self, layer_name, grad_idx, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): self.layer_name = check_param_type('layer_name', layer_name, str) diff --git a/mindarmour/privacy/sup_privacy/train/model.py b/mindarmour/privacy/sup_privacy/train/model.py index a6513b8..b9b49bd 100644 --- a/mindarmour/privacy/sup_privacy/train/model.py +++ b/mindarmour/privacy/sup_privacy/train/model.py @@ -62,7 +62,6 @@ class SuppressModel(Model): network (Cell): The training network. loss_fn (Cell): Computes softmax cross entropy between logits and labels. optimizer (Optimizer): optimizer instance. - metrics (Union[dict, set]): Calculates the accuracy for classification and multilabel data. kwargs: Keyword parameters used for creating a suppress model. Examples: