diff --git a/example/mnist_demo/dp_ada_sgd_graph_config.py b/example/mnist_demo/dp_ada_sgd_graph_config.py
new file mode 100644
index 0000000..82e6b13
--- /dev/null
+++ b/example/mnist_demo/dp_ada_sgd_graph_config.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Network configuration settings, used by lenet5_dp_ada_sgd_graph.py.
+"""
+
+from easydict import EasyDict as edict
+
+mnist_cfg = edict({
+    'num_classes': 10,  # the number of classes in the model's output
+    'lr': 0.001,  # the learning rate of the model's optimizer
+    'momentum': 0.9,  # the momentum value of the model's optimizer
+    'epoch_size': 20,  # training epochs
+    'batch_size': 256,  # batch size for training
+    'image_height': 32,  # the height of training samples
+    'image_width': 32,  # the width of training samples
+    'save_checkpoint_steps': 234,  # the interval, in steps, for saving model checkpoint files
+    'keep_checkpoint_max': 10,  # the maximum number of checkpoint files to keep
+    'device_target': 'Ascend',  # device used
+    'data_path': './MNIST_unzip',  # the path of the training and testing datasets
+    'dataset_sink_mode': False,  # whether to deliver all training data to the device at once
+    'micro_batches': 16,  # the number of small batches split from an original batch
+    'norm_bound': 1.0,  # the clipping bound for the gradients of the model's training parameters
+    'initial_noise_multiplier': 0.05,  # the initial multiplication coefficient of the noise added
+                                       # to the training parameters' gradients
+    'decay_policy': 'Step',  # the decay policy applied to the noise multiplier
+    'noise_mechanisms': 'AdaGaussian',  # the method of adding noise to gradients during training
+    'optimizer': 'SGD'  # the base optimizer used for differentially private training
+})
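The value 234 recurs throughout this patch (save_checkpoint_steps above; per_print_times and the checkpoint filename in the training script below). It is simply the number of full batches per epoch; a quick sketch of the arithmetic:

    # MNIST has 60,000 training images, and batch() is called with
    # drop_remainder=True, so each epoch yields floor(60000 / 256) = 234
    # complete steps. Checkpointing and privacy-budget printing therefore
    # fire once per epoch, and the final checkpoint is named
    # 'checkpoint_lenet-20_234.ckpt' (prefix-epoch_step, MindSpore's
    # default naming scheme).
    steps_per_epoch = 60000 // 256  # 234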
diff --git a/example/mnist_demo/lenet5_dp_ada_sgd_graph.py b/example/mnist_demo/lenet5_dp_ada_sgd_graph.py
new file mode 100644
index 0000000..f72e190
--- /dev/null
+++ b/example/mnist_demo/lenet5_dp_ada_sgd_graph.py
@@ -0,0 +1,150 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Training example of AdaGaussian-mechanism differential privacy.
+"""
+import os
+
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.callback import ModelCheckpoint
+from mindspore.train.callback import CheckpointConfig
+from mindspore.train.callback import LossMonitor
+from mindspore.nn.metrics import Accuracy
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as CV
+import mindspore.dataset.transforms.c_transforms as C
+from mindspore.dataset.transforms.vision import Inter
+import mindspore.common.dtype as mstype
+
+from mindarmour.diff_privacy import DPModel
+from mindarmour.diff_privacy import PrivacyMonitorFactory
+from mindarmour.diff_privacy import NoiseMechanismsFactory
+from mindarmour.utils.logger import LogUtil
+from lenet5_net import LeNet5
+from dp_ada_sgd_graph_config import mnist_cfg as cfg
+
+LOGGER = LogUtil.get_instance()
+LOGGER.set_level('INFO')
+TAG = 'Lenet5_train'
+
+
+def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1,
+                           num_parallel_workers=1, sparse=True):
+    """
+    Create a dataset for training or testing.
+    """
+    # define dataset
+    ds1 = ds.MnistDataset(data_path)
+
+    # define operation parameters
+    resize_height, resize_width = 32, 32
+    rescale = 1.0 / 255.0
+    shift = 0.0
+
+    # define map operations
+    resize_op = CV.Resize((resize_height, resize_width),
+                          interpolation=Inter.LINEAR)
+    rescale_op = CV.Rescale(rescale, shift)
+    hwc2chw_op = CV.HWC2CHW()
+    type_cast_op = C.TypeCast(mstype.int32)
+
+    # apply map operations on images
+    if not sparse:
+        one_hot_enco = C.OneHot(10)
+        ds1 = ds1.map(input_columns="label", operations=one_hot_enco,
+                      num_parallel_workers=num_parallel_workers)
+        type_cast_op = C.TypeCast(mstype.float32)
+    ds1 = ds1.map(input_columns="label", operations=type_cast_op,
+                  num_parallel_workers=num_parallel_workers)
+    ds1 = ds1.map(input_columns="image", operations=resize_op,
+                  num_parallel_workers=num_parallel_workers)
+    ds1 = ds1.map(input_columns="image", operations=rescale_op,
+                  num_parallel_workers=num_parallel_workers)
+    ds1 = ds1.map(input_columns="image", operations=hwc2chw_op,
+                  num_parallel_workers=num_parallel_workers)
+
+    # apply DatasetOps
+    buffer_size = 10000
+    ds1 = ds1.shuffle(buffer_size=buffer_size)
+    ds1 = ds1.batch(batch_size, drop_remainder=True)
+    ds1 = ds1.repeat(repeat_size)
+
+    return ds1
+
+
+if __name__ == "__main__":
+    # This configuration can run in both PyNative mode and graph mode
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=cfg.device_target)
+    network = LeNet5()
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
+                                                reduction="mean")
+    config_ck = CheckpointConfig(
+        save_checkpoint_steps=cfg.save_checkpoint_steps,
+        keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
+                                 directory='./trained_ckpt_file/',
+                                 config=config_ck)
+
+    # get training dataset
+    ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"),
+                                      cfg.batch_size)
+
+    if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0:
+        raise ValueError(
+            "Number of micro_batches should evenly divide batch_size")
+    # Create a noise mechanism from the DP noise-mechanism factory; the
+    # mechanism adds noise to gradients during training. initial_noise_multiplier
+    # is suggested to be greater than 1.0; otherwise the privacy budget would be
+    # huge, which means the privacy protection effect is weak. The mechanism can
+    # be 'Gaussian' or 'AdaGaussian': the noise decays during training with
+    # 'AdaGaussian', while it stays constant with 'Gaussian'.
+    noise_mech = NoiseMechanismsFactory().create(cfg.noise_mechanisms,
                                                 norm_bound=cfg.norm_bound,
                                                 initial_noise_multiplier=cfg.initial_noise_multiplier,
                                                 decay_policy=cfg.decay_policy)
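+    # Illustrative note (an assumption about the internal update, not a
+    # documented MindArmour formula): with decay_policy='Step' one can
+    # picture the noise multiplier shrinking multiplicatively each step,
+    #     sigma_t = initial_noise_multiplier * (1 - decay_rate) ** t
+    # where decay_rate is a hypothetical internal constant, so later
+    # updates receive less noise; 'Gaussian' keeps sigma_t = sigma_0.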
+
+    net_opt = nn.SGD(params=network.trainable_params(),
+                     learning_rate=cfg.lr, momentum=cfg.momentum)
+    # Create a monitor for DP training. The monitor computes and prints the
+    # privacy budget (eps and delta) during training.
+    rdp_monitor = PrivacyMonitorFactory.create('rdp',
+                                               num_samples=60000,
+                                               batch_size=cfg.batch_size,
+                                               initial_noise_multiplier=cfg.initial_noise_multiplier,
+                                               per_print_times=234)
+    # Create the DP model for training.
+    model = DPModel(micro_batches=cfg.micro_batches,
+                    norm_bound=cfg.norm_bound,
+                    noise_mech=noise_mech,
+                    network=network,
+                    loss_fn=net_loss,
+                    optimizer=net_opt,
+                    metrics={"Accuracy": Accuracy()})
+
+    LOGGER.info(TAG, "============== Starting Training ==============")
+    model.train(cfg['epoch_size'], ds_train,
+                callbacks=[ckpoint_cb, LossMonitor(), rdp_monitor],
+                dataset_sink_mode=cfg.dataset_sink_mode)
+
+    LOGGER.info(TAG, "============== Starting Testing ==============")
+    ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-' + str(cfg.epoch_size) + '_234.ckpt'
+    param_dict = load_checkpoint(ckpt_file_name)
+    load_param_into_net(network, param_dict)
+    ds_eval = generate_mnist_dataset(os.path.join(cfg.data_path, 'test'),
+                                     batch_size=cfg.batch_size)
+    acc = model.eval(ds_eval, dataset_sink_mode=False)
+    LOGGER.info(TAG, "============== Accuracy: %s ==============", acc)
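For readers new to DP-SGD, here is a minimal NumPy sketch of what DPModel's micro_batches and norm_bound conceptually do: each 256-sample batch is split into 16 micro-batches, each micro-batch gradient is clipped to an L2 norm of norm_bound, and Gaussian noise from the mechanism is added before averaging. This is a simplified illustration under those assumptions, not MindArmour's actual implementation (which operates on per-parameter gradient tensors inside the training graph):

    import numpy as np

    def dp_sgd_step(micro_grads, norm_bound=1.0, noise_multiplier=0.05):
        """Clip each micro-batch gradient, add Gaussian noise, and average.

        micro_grads: shape (micro_batches, num_params), one flattened
        gradient vector per micro-batch.
        """
        clipped = []
        for g in micro_grads:
            norm = np.linalg.norm(g)
            # Scale the gradient down only when its L2 norm exceeds norm_bound.
            clipped.append(g * min(1.0, norm_bound / (norm + 1e-12)))
        total = np.sum(clipped, axis=0)
        # Standard DP-SGD recipe: noise stddev = noise_multiplier * norm_bound.
        noise = np.random.normal(0.0, noise_multiplier * norm_bound,
                                 size=total.shape)
        return (total + noise) / len(micro_grads)

    # Example: 16 micro-batch gradients for a toy model with 10 parameters.
    update = dp_sgd_step(np.random.randn(16, 10))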