Browse Source

!116 fixed config verify issue.

Merge pull request !116 from liuluobin/master
tags/v1.2.1
mindspore-ci-bot Gitee 4 years ago
parent
commit
af578fa91c
2 changed files with 11 additions and 7 deletions
  1. +3
    -2
      mindarmour/privacy/evaluation/_check_config.py
  2. +8
    -5
      mindarmour/privacy/evaluation/membership_inference.py

+ 3
- 2
mindarmour/privacy/evaluation/_check_config.py View File

@@ -15,6 +15,8 @@
Verify attack config
"""

import numpy as np

from mindarmour.utils._check_param import check_param_type
from mindarmour.utils.logger import LogUtil

@@ -138,7 +140,6 @@ _VALID_CONFIG_CHECKLIST = {
"min_impurity_decrease": [_is_non_negative_float],
"min_impurity_split": [{None}, _is_positive_float],
"bootstrap": [{True, False}],
"oob_scroe": [{True, False}],
"n_jobs": [_is_positive_int, {None}],
"random_state": None,
"verbose": [_is_non_negative_int],
@@ -182,7 +183,7 @@ def _check_config(attack_config, config_checklist):
for param_key in params.keys():
param_value = params[param_key]
candidate_values = config_checklist[method][param_key]
check_param_type('param_value', param_value, list)
check_param_type('param_value', param_value, (list, tuple, np.ndarray))

if candidate_values is None:
continue


+ 8
- 5
mindarmour/privacy/evaluation/membership_inference.py View File

@@ -39,7 +39,7 @@ def _eval_info(pred, truth, option):
Args:
pred (numpy.ndarray): Predictions for each sample.
truth (numpy.ndarray): Ground truth for each sample.
option(str): Type of evaluation indicators; Possible
option (str): Type of evaluation indicators; Possible
values are 'precision', 'accuracy' and 'recall'.

Returns:
@@ -77,7 +77,7 @@ def _softmax_cross_entropy(logits, labels):

Args:
logits (numpy.ndarray): Numpy array of shape(N, C).
labels (numpy.ndarray): Numpy array of shape(N, )
labels (numpy.ndarray): Numpy array of shape(N, ).

Returns:
numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits.
@@ -136,7 +136,7 @@ class MembershipInference:

def train(self, dataset_train, dataset_test, attack_config):
"""
Depending on the configuration, use the input data set to train the attack model.
Depending on the configuration, use the input dataset to train the attack model.
Save the attack model to self._attack_list.

Args:
@@ -148,10 +148,13 @@ class MembershipInference:
The support methods are knn, lr, mlp and rf, and the params of each method
must within the range of changeable parameters. Tips of params implement
can be found in
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPRegressor.html

Raises:
KeyError: If any config in attack_config doesn't have keys {"method", "params"}
KeyError: If any config in attack_config doesn't have keys {"method", "params"}.
NameError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
"""
check_param_type("dataset_train", dataset_train, Dataset)


Loading…
Cancel
Save