From e5b7cfd7fdfbe5b0942ed46184159e35a491588c Mon Sep 17 00:00:00 2001 From: liuluobin Date: Wed, 23 Dec 2020 14:33:33 +0800 Subject: [PATCH] Intercept the ConvergenceWarning from sklearn in some cases. --- .../model_attacks/white_box/mnist_attack_pgd.py | 4 ++-- mindarmour/privacy/evaluation/attacker.py | 22 +++++++++++++--------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py b/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py index 1aff061..6a245d5 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py @@ -74,10 +74,10 @@ def test_projected_gradient_descent_method(): # attacking loss = SoftmaxCrossEntropyWithLogits(sparse=True) attack = ProjectedGradientDescent(net, eps=0.3, loss_fn=loss) - start_time = time.clock() + start_time = time.process_time() adv_data = attack.batch_generate(np.concatenate(test_images), true_labels, batch_size=32) - stop_time = time.clock() + stop_time = time.process_time() np.save('./adv_data', adv_data) pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy() # rescale predict confidences into (0, 1). diff --git a/mindarmour/privacy/evaluation/attacker.py b/mindarmour/privacy/evaluation/attacker.py index 5733a83..2d798db 100644 --- a/mindarmour/privacy/evaluation/attacker.py +++ b/mindarmour/privacy/evaluation/attacker.py @@ -14,12 +14,15 @@ """ Attacker of Membership Inference. """ +import warnings + from sklearn.neighbors import KNeighborsClassifier from sklearn.linear_model import LogisticRegression from sklearn.neural_network import MLPClassifier from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import GridSearchCV from sklearn.model_selection import RandomizedSearchCV +from sklearn.exceptions import ConvergenceWarning from mindarmour.utils.logger import LogUtil @@ -141,15 +144,16 @@ def get_attack_model(features, labels, config, n_jobs=-1): >>> attack_model = get_attack_model(features, labels, config) """ method = str.lower(config["method"]) - - if method == "knn": - return _attack_knn(features, labels, config["params"], n_jobs) - if method == "lr": - return _attack_lr(features, labels, config["params"], n_jobs) - if method == "mlp": - return _attack_mlpc(features, labels, config["params"], n_jobs) - if method == "rf": - return _attack_rf(features, labels, config["params"], n_jobs) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=ConvergenceWarning) + if method == "knn": + return _attack_knn(features, labels, config["params"], n_jobs) + if method == "lr": + return _attack_lr(features, labels, config["params"], n_jobs) + if method == "mlp": + return _attack_mlpc(features, labels, config["params"], n_jobs) + if method == "rf": + return _attack_rf(features, labels, config["params"], n_jobs) msg = "Method {} is not supported.".format(config["method"]) LOGGER.error(TAG, msg)