From ce15e7811e4927d8dcebb08cabe1dd7ef6a9cdf1 Mon Sep 17 00:00:00 2001 From: liuluobin Date: Mon, 24 Aug 2020 15:58:03 +0800 Subject: [PATCH] Append the parameter verification of class MembershipInference. --- example/membership_inference_demo/train.py | 5 +-- mindarmour/diff_privacy/evaluation/attacker.py | 10 ++++-- .../evaluation/membership_inference.py | 41 ++++++++++++++++++++-- 3 files changed, 48 insertions(+), 8 deletions(-) mode change 100644 => 100755 mindarmour/diff_privacy/evaluation/attacker.py mode change 100644 => 100755 mindarmour/diff_privacy/evaluation/membership_inference.py diff --git a/example/membership_inference_demo/train.py b/example/membership_inference_demo/train.py index 944da0b..f711448 100644 --- a/example/membership_inference_demo/train.py +++ b/example/membership_inference_demo/train.py @@ -27,7 +27,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore import context from mindspore.nn.optim.momentum import Momentum -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.model import Model from mindspore.train.serialization import load_param_into_net, load_checkpoint from mindarmour.utils import LogUtil @@ -187,12 +187,13 @@ if __name__ == '__main__': amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) # checkpoint save + callbacks = [LossMonitor()] if args.rank_save_ckpt_flag: ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch, keep_checkpoint_max=args.ckpt_save_max) ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.outputs_dir, prefix='{}'.format(args.rank)) - callbacks = ckpt_cb + callbacks.append(ckpt_cb) model.train(args.max_epoch, dataset, callbacks=callbacks) diff --git a/mindarmour/diff_privacy/evaluation/attacker.py b/mindarmour/diff_privacy/evaluation/attacker.py old mode 100644 new mode 100755 index f4c1c24..80722d1 --- a/mindarmour/diff_privacy/evaluation/attacker.py +++ b/mindarmour/diff_privacy/evaluation/attacker.py @@ -22,6 +22,9 @@ from sklearn.model_selection import GridSearchCV from sklearn.model_selection import RandomizedSearchCV +method_list = ["lr", "knn", "rf", "mlp"] + + def _attack_knn(features, labels, param_grid): """ Train and return a KNN model. @@ -119,12 +122,13 @@ def get_attack_model(features, labels, config): sklearn.BaseEstimator, trained model specify by config["method"]. """ method = str.lower(config["method"]) + if method == "knn": return _attack_knn(features, labels, config["params"]) - if method in ["lr", "logitic regression"]: + if method == "lr": return _attack_lr(features, labels, config["params"]) if method == "mlp": return _attack_mlpc(features, labels, config["params"]) - if method in ["rf", "random forest"]: + if method == "rf": return _attack_rf(features, labels, config["params"]) - raise ValueError("Method {} is not support.".format(config["method"])) + return None diff --git a/mindarmour/diff_privacy/evaluation/membership_inference.py b/mindarmour/diff_privacy/evaluation/membership_inference.py old mode 100644 new mode 100755 index c0f802e..4ff0ce0 --- a/mindarmour/diff_privacy/evaluation/membership_inference.py +++ b/mindarmour/diff_privacy/evaluation/membership_inference.py @@ -19,10 +19,11 @@ import numpy as np import mindspore as ms from mindspore.train import Model +from mindspore.dataset.engine import Dataset import mindspore.nn as nn import mindspore.context as context from mindspore import Tensor -from mindarmour.diff_privacy.evaluation.attacker import get_attack_model +from mindarmour.diff_privacy.evaluation.attacker import get_attack_model, method_list def _eval_info(pred, truth, option): """ @@ -89,7 +90,7 @@ class MembershipInference: def __init__(self, model): if not isinstance(model, Model): - raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model))) + raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model))) self.model = model self.attack_list = [] @@ -104,8 +105,24 @@ class MembershipInference: attack_config (list): Parameter setting for the attack model. Raises: - ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"]. + KeyError: If each config in attack_config doesn't have keys {"method", "params"} + ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"]. """ + if not isinstance(dataset_train, Dataset): + raise TypeError("Type of parameter 'dataset_train' must be Dataset, " + "but got {}".format(type(dataset_train))) + + if not isinstance(dataset_test, Dataset): + raise TypeError("Type of parameter 'test_train' must be Dataset, " + "but got {}".format(type(dataset_train))) + + for config in attack_config: + if {"params", "method"} != set(config.keys()): + raise KeyError("Each config in attack_config must have keys 'method' and 'params', " + "but your key value is {}.".format(set(config.keys()))) + if str.lower(config["method"]) not in method_list: + raise ValueError("Method {} is not support.".format(config["method"])) + features, labels = self._transform(dataset_train, dataset_test) for config in attack_config: self.attack_list.append(get_attack_model(features, labels, config)) @@ -124,6 +141,24 @@ class MembershipInference: Returns: list, Each element contains an evaluation indicator for the attack model. """ + if not isinstance(dataset_train, Dataset): + raise TypeError("Type of parameter 'dataset_train' must be Dataset, " + "but got {}".format(type(dataset_train))) + + if not isinstance(dataset_test, Dataset): + raise TypeError("Type of parameter 'test_train' must be Dataset, " + "but got {}".format(type(dataset_train))) + + if not isinstance(metrics, (list, tuple)): + raise TypeError("Type of parameter 'config' must be Union[list, tuple], but got " + "{}.".format(type(metrics))) + + metrics = set(metrics) + metrics_list = {"precision", "accruacy", "recall"} + if metrics > metrics_list: + raise ValueError("Element in 'metrics' must be in {}, but got " + "{}.".format(metrics_list, metrics)) + result = [] features, labels = self._transform(dataset_train, dataset_test) for attacker in self.attack_list: