|
|
@@ -14,12 +14,15 @@ |
|
|
|
""" |
|
|
|
Attacker of Membership Inference. |
|
|
|
""" |
|
|
|
import warnings |
|
|
|
|
|
|
|
from sklearn.neighbors import KNeighborsClassifier |
|
|
|
from sklearn.linear_model import LogisticRegression |
|
|
|
from sklearn.neural_network import MLPClassifier |
|
|
|
from sklearn.ensemble import RandomForestClassifier |
|
|
|
from sklearn.model_selection import GridSearchCV |
|
|
|
from sklearn.model_selection import RandomizedSearchCV |
|
|
|
from sklearn.exceptions import ConvergenceWarning |
|
|
|
|
|
|
|
from mindarmour.utils.logger import LogUtil |
|
|
|
|
|
|
@@ -141,15 +144,16 @@ def get_attack_model(features, labels, config, n_jobs=-1): |
|
|
|
>>> attack_model = get_attack_model(features, labels, config) |
|
|
|
""" |
|
|
|
method = str.lower(config["method"]) |
|
|
|
|
|
|
|
if method == "knn": |
|
|
|
return _attack_knn(features, labels, config["params"], n_jobs) |
|
|
|
if method == "lr": |
|
|
|
return _attack_lr(features, labels, config["params"], n_jobs) |
|
|
|
if method == "mlp": |
|
|
|
return _attack_mlpc(features, labels, config["params"], n_jobs) |
|
|
|
if method == "rf": |
|
|
|
return _attack_rf(features, labels, config["params"], n_jobs) |
|
|
|
with warnings.catch_warnings(): |
|
|
|
warnings.filterwarnings('ignore', category=ConvergenceWarning) |
|
|
|
if method == "knn": |
|
|
|
return _attack_knn(features, labels, config["params"], n_jobs) |
|
|
|
if method == "lr": |
|
|
|
return _attack_lr(features, labels, config["params"], n_jobs) |
|
|
|
if method == "mlp": |
|
|
|
return _attack_mlpc(features, labels, config["params"], n_jobs) |
|
|
|
if method == "rf": |
|
|
|
return _attack_rf(features, labels, config["params"], n_jobs) |
|
|
|
|
|
|
|
msg = "Method {} is not supported.".format(config["method"]) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|