|
|
@@ -19,10 +19,11 @@ import numpy as np |
|
|
|
|
|
|
|
import mindspore as ms |
|
|
|
from mindspore.train import Model |
|
|
|
from mindspore.dataset.engine import Dataset |
|
|
|
import mindspore.nn as nn |
|
|
|
import mindspore.context as context |
|
|
|
from mindspore import Tensor |
|
|
|
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model |
|
|
|
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model, method_list |
|
|
|
|
|
|
|
def _eval_info(pred, truth, option): |
|
|
|
""" |
|
|
@@ -89,7 +90,7 @@ class MembershipInference: |
|
|
|
|
|
|
|
def __init__(self, model): |
|
|
|
if not isinstance(model, Model): |
|
|
|
raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model))) |
|
|
|
raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model))) |
|
|
|
self.model = model |
|
|
|
self.attack_list = [] |
|
|
|
|
|
|
@@ -104,8 +105,24 @@ class MembershipInference: |
|
|
|
attack_config (list): Parameter setting for the attack model. |
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"]. |
|
|
|
KeyError: If each config in attack_config doesn't have keys {"method", "params"} |
|
|
|
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"]. |
|
|
|
""" |
|
|
|
if not isinstance(dataset_train, Dataset): |
|
|
|
raise TypeError("Type of parameter 'dataset_train' must be Dataset, " |
|
|
|
"but got {}".format(type(dataset_train))) |
|
|
|
|
|
|
|
if not isinstance(dataset_test, Dataset): |
|
|
|
raise TypeError("Type of parameter 'test_train' must be Dataset, " |
|
|
|
"but got {}".format(type(dataset_train))) |
|
|
|
|
|
|
|
for config in attack_config: |
|
|
|
if {"params", "method"} != set(config.keys()): |
|
|
|
raise KeyError("Each config in attack_config must have keys 'method' and 'params', " |
|
|
|
"but your key value is {}.".format(set(config.keys()))) |
|
|
|
if str.lower(config["method"]) not in method_list: |
|
|
|
raise ValueError("Method {} is not support.".format(config["method"])) |
|
|
|
|
|
|
|
features, labels = self._transform(dataset_train, dataset_test) |
|
|
|
for config in attack_config: |
|
|
|
self.attack_list.append(get_attack_model(features, labels, config)) |
|
|
@@ -124,6 +141,24 @@ class MembershipInference: |
|
|
|
Returns: |
|
|
|
list, Each element contains an evaluation indicator for the attack model. |
|
|
|
""" |
|
|
|
if not isinstance(dataset_train, Dataset): |
|
|
|
raise TypeError("Type of parameter 'dataset_train' must be Dataset, " |
|
|
|
"but got {}".format(type(dataset_train))) |
|
|
|
|
|
|
|
if not isinstance(dataset_test, Dataset): |
|
|
|
raise TypeError("Type of parameter 'test_train' must be Dataset, " |
|
|
|
"but got {}".format(type(dataset_train))) |
|
|
|
|
|
|
|
if not isinstance(metrics, (list, tuple)): |
|
|
|
raise TypeError("Type of parameter 'config' must be Union[list, tuple], but got " |
|
|
|
"{}.".format(type(metrics))) |
|
|
|
|
|
|
|
metrics = set(metrics) |
|
|
|
metrics_list = {"precision", "accruacy", "recall"} |
|
|
|
if metrics > metrics_list: |
|
|
|
raise ValueError("Element in 'metrics' must be in {}, but got " |
|
|
|
"{}.".format(metrics_list, metrics)) |
|
|
|
|
|
|
|
result = [] |
|
|
|
features, labels = self._transform(dataset_train, dataset_test) |
|
|
|
for attacker in self.attack_list: |
|
|
|