diff --git a/mindarmour/privacy/evaluation/_check_config.py b/mindarmour/privacy/evaluation/_check_config.py index 2c65f61..48bd6f8 100644 --- a/mindarmour/privacy/evaluation/_check_config.py +++ b/mindarmour/privacy/evaluation/_check_config.py @@ -15,6 +15,8 @@ Verify attack config """ +import numpy as np + from mindarmour.utils._check_param import check_param_type from mindarmour.utils.logger import LogUtil @@ -138,7 +140,6 @@ _VALID_CONFIG_CHECKLIST = { "min_impurity_decrease": [_is_non_negative_float], "min_impurity_split": [{None}, _is_positive_float], "bootstrap": [{True, False}], - "oob_scroe": [{True, False}], "n_jobs": [_is_positive_int, {None}], "random_state": None, "verbose": [_is_non_negative_int], @@ -182,7 +183,7 @@ def _check_config(attack_config, config_checklist): for param_key in params.keys(): param_value = params[param_key] candidate_values = config_checklist[method][param_key] - check_param_type('param_value', param_value, list) + check_param_type('param_value', param_value, (list, tuple, np.ndarray)) if candidate_values is None: continue diff --git a/mindarmour/privacy/evaluation/membership_inference.py b/mindarmour/privacy/evaluation/membership_inference.py index 3d4f88a..d395cb7 100644 --- a/mindarmour/privacy/evaluation/membership_inference.py +++ b/mindarmour/privacy/evaluation/membership_inference.py @@ -39,7 +39,7 @@ def _eval_info(pred, truth, option): Args: pred (numpy.ndarray): Predictions for each sample. truth (numpy.ndarray): Ground truth for each sample. - option(str): Type of evaluation indicators; Possible + option (str): Type of evaluation indicators; Possible values are 'precision', 'accuracy' and 'recall'. Returns: @@ -77,7 +77,7 @@ def _softmax_cross_entropy(logits, labels): Args: logits (numpy.ndarray): Numpy array of shape(N, C). - labels (numpy.ndarray): Numpy array of shape(N, ) + labels (numpy.ndarray): Numpy array of shape(N, ). Returns: numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits. @@ -136,7 +136,7 @@ class MembershipInference: def train(self, dataset_train, dataset_test, attack_config): """ - Depending on the configuration, use the input data set to train the attack model. + Depending on the configuration, use the input dataset to train the attack model. Save the attack model to self._attack_list. Args: @@ -148,10 +148,13 @@ class MembershipInference: The support methods are knn, lr, mlp and rf, and the params of each method must within the range of changeable parameters. Tips of params implement can be found in - "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". + https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html + https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html + https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html + https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPRegressor.html Raises: - KeyError: If any config in attack_config doesn't have keys {"method", "params"} + KeyError: If any config in attack_config doesn't have keys {"method", "params"}. NameError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"]. """ check_param_type("dataset_train", dataset_train, Dataset)