Browse Source

Intercept the ConvergenceWarning from sklearn in some cases.

tags/v1.2.1
liuluobin 4 years ago
parent
commit
e5b7cfd7fd
2 changed files with 15 additions and 11 deletions
  1. +2
    -2
      examples/model_security/model_attacks/white_box/mnist_attack_pgd.py
  2. +13
    -9
      mindarmour/privacy/evaluation/attacker.py

+ 2
- 2
examples/model_security/model_attacks/white_box/mnist_attack_pgd.py View File

@@ -74,10 +74,10 @@ def test_projected_gradient_descent_method():
# attacking # attacking
loss = SoftmaxCrossEntropyWithLogits(sparse=True) loss = SoftmaxCrossEntropyWithLogits(sparse=True)
attack = ProjectedGradientDescent(net, eps=0.3, loss_fn=loss) attack = ProjectedGradientDescent(net, eps=0.3, loss_fn=loss)
start_time = time.clock()
start_time = time.process_time()
adv_data = attack.batch_generate(np.concatenate(test_images), adv_data = attack.batch_generate(np.concatenate(test_images),
true_labels, batch_size=32) true_labels, batch_size=32)
stop_time = time.clock()
stop_time = time.process_time()
np.save('./adv_data', adv_data) np.save('./adv_data', adv_data)
pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy() pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy()
# rescale predict confidences into (0, 1). # rescale predict confidences into (0, 1).


+ 13
- 9
mindarmour/privacy/evaluation/attacker.py View File

@@ -14,12 +14,15 @@
""" """
Attacker of Membership Inference. Attacker of Membership Inference.
""" """
import warnings

from sklearn.neighbors import KNeighborsClassifier from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV from sklearn.model_selection import RandomizedSearchCV
from sklearn.exceptions import ConvergenceWarning


from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil


@@ -141,15 +144,16 @@ def get_attack_model(features, labels, config, n_jobs=-1):
>>> attack_model = get_attack_model(features, labels, config) >>> attack_model = get_attack_model(features, labels, config)
""" """
method = str.lower(config["method"]) method = str.lower(config["method"])

if method == "knn":
return _attack_knn(features, labels, config["params"], n_jobs)
if method == "lr":
return _attack_lr(features, labels, config["params"], n_jobs)
if method == "mlp":
return _attack_mlpc(features, labels, config["params"], n_jobs)
if method == "rf":
return _attack_rf(features, labels, config["params"], n_jobs)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=ConvergenceWarning)
if method == "knn":
return _attack_knn(features, labels, config["params"], n_jobs)
if method == "lr":
return _attack_lr(features, labels, config["params"], n_jobs)
if method == "mlp":
return _attack_mlpc(features, labels, config["params"], n_jobs)
if method == "rf":
return _attack_rf(features, labels, config["params"], n_jobs)


msg = "Method {} is not supported.".format(config["method"]) msg = "Method {} is not supported.".format(config["method"])
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)


Loading…
Cancel
Save