Browse Source

Intercept the ConvergenceWarning from sklearn in some cases.

tags/v1.2.1
liuluobin 4 years ago
parent
commit
e5b7cfd7fd
2 changed files with 15 additions and 11 deletions
  1. +2
    -2
      examples/model_security/model_attacks/white_box/mnist_attack_pgd.py
  2. +13
    -9
      mindarmour/privacy/evaluation/attacker.py

+ 2
- 2
examples/model_security/model_attacks/white_box/mnist_attack_pgd.py View File

@@ -74,10 +74,10 @@ def test_projected_gradient_descent_method():
# attacking
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
attack = ProjectedGradientDescent(net, eps=0.3, loss_fn=loss)
start_time = time.clock()
start_time = time.process_time()
adv_data = attack.batch_generate(np.concatenate(test_images),
true_labels, batch_size=32)
stop_time = time.clock()
stop_time = time.process_time()
np.save('./adv_data', adv_data)
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1).


+ 13
- 9
mindarmour/privacy/evaluation/attacker.py View File

@@ -14,12 +14,15 @@
"""
Attacker of Membership Inference.
"""
import warnings

from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.exceptions import ConvergenceWarning

from mindarmour.utils.logger import LogUtil

@@ -141,15 +144,16 @@ def get_attack_model(features, labels, config, n_jobs=-1):
>>> attack_model = get_attack_model(features, labels, config)
"""
method = str.lower(config["method"])

if method == "knn":
return _attack_knn(features, labels, config["params"], n_jobs)
if method == "lr":
return _attack_lr(features, labels, config["params"], n_jobs)
if method == "mlp":
return _attack_mlpc(features, labels, config["params"], n_jobs)
if method == "rf":
return _attack_rf(features, labels, config["params"], n_jobs)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=ConvergenceWarning)
if method == "knn":
return _attack_knn(features, labels, config["params"], n_jobs)
if method == "lr":
return _attack_lr(features, labels, config["params"], n_jobs)
if method == "mlp":
return _attack_mlpc(features, labels, config["params"], n_jobs)
if method == "rf":
return _attack_rf(features, labels, config["params"], n_jobs)

msg = "Method {} is not supported.".format(config["method"])
LOGGER.error(TAG, msg)


Loading…
Cancel
Save