|
|
@@ -39,7 +39,7 @@ def _eval_info(pred, truth, option): |
|
|
|
Args: |
|
|
|
pred (numpy.ndarray): Predictions for each sample. |
|
|
|
truth (numpy.ndarray): Ground truth for each sample. |
|
|
|
option(str): Type of evaluation indicators; Possible |
|
|
|
option (str): Type of evaluation indicators; Possible |
|
|
|
values are 'precision', 'accuracy' and 'recall'. |
|
|
|
|
|
|
|
Returns: |
|
|
@@ -77,7 +77,7 @@ def _softmax_cross_entropy(logits, labels): |
|
|
|
|
|
|
|
Args: |
|
|
|
logits (numpy.ndarray): Numpy array of shape(N, C). |
|
|
|
labels (numpy.ndarray): Numpy array of shape(N, ) |
|
|
|
labels (numpy.ndarray): Numpy array of shape(N, ). |
|
|
|
|
|
|
|
Returns: |
|
|
|
numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits. |
|
|
@@ -136,7 +136,7 @@ class MembershipInference: |
|
|
|
|
|
|
|
def train(self, dataset_train, dataset_test, attack_config): |
|
|
|
""" |
|
|
|
Depending on the configuration, use the input data set to train the attack model. |
|
|
|
Depending on the configuration, use the input dataset to train the attack model. |
|
|
|
Save the attack model to self._attack_list. |
|
|
|
|
|
|
|
Args: |
|
|
@@ -148,10 +148,13 @@ class MembershipInference: |
|
|
|
The support methods are knn, lr, mlp and rf, and the params of each method |
|
|
|
must within the range of changeable parameters. Tips of params implement |
|
|
|
can be found in |
|
|
|
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". |
|
|
|
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html |
|
|
|
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html |
|
|
|
https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html |
|
|
|
https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPRegressor.html |
|
|
|
|
|
|
|
Raises: |
|
|
|
KeyError: If any config in attack_config doesn't have keys {"method", "params"} |
|
|
|
KeyError: If any config in attack_config doesn't have keys {"method", "params"}. |
|
|
|
NameError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"]. |
|
|
|
""" |
|
|
|
check_param_type("dataset_train", dataset_train, Dataset) |
|
|
|