@@ -24,6 +24,11 @@ import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
 from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
+from mindarmour.utils.logger import LogUtil
+
+LOGGER = LogUtil.get_instance()
+TAG = "MembershipInference"
+

 def _eval_info(pred, truth, option):
     """
@@ -43,7 +48,9 @@ def _eval_info(pred, truth, option):
         ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
     """
     if pred.size == 0 or truth.size == 0:
-        raise ValueError("Size of pred or truth is 0.")
+        msg = "Size of pred or truth is 0."
+        LOGGER.error(TAG, msg)
+        raise ValueError(msg)

     if option == "accuracy":
         count = np.sum(pred == truth)
@@ -59,7 +66,9 @@ def _eval_info(pred, truth, option):
             return -1
         return count / np.sum(truth)

-    raise ValueError("The metric value {} is undefined.".format(option))
+    msg = "The metric value {} is undefined.".format(option)
+    LOGGER.error(TAG, msg)
+    raise ValueError(msg)


 class MembershipInference:
@@ -91,7 +100,10 @@

     def __init__(self, model):
         if not isinstance(model, Model):
-            raise TypeError("Type of parameter 'model' must be Model, but got {}.".format(type(model)))
+            msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
+
         self.model = model
         self.method_list = ["knn", "lr", "mlp", "rf"]
         self.attack_list = []
@@ -117,26 +129,34 @@
             ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
         """
         if not isinstance(dataset_train, Dataset):
-            raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
-                            "but got {}".format(type(dataset_train)))
+            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)

         if not isinstance(dataset_test, Dataset):
-            raise TypeError("Type of parameter 'test_train' must be Dataset, "
-                            "but got {}".format(type(dataset_train)))
+            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)

         if not isinstance(attack_config, list):
-            raise TypeError("Type of parameter 'attack_config' must be list, "
-                            "but got {}.".format(type(attack_config)))
+            msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)

         for config in attack_config:
             if not isinstance(config, dict):
-                raise TypeError("Type of each config in 'attack_config' must be dict, "
-                                "but got {}.".format(type(config)))
+                msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config))
+                LOGGER.error(TAG, msg)
+                raise TypeError(msg)
             if {"params", "method"} != set(config.keys()):
-                raise KeyError("Each config in attack_config must have keys 'method' and 'params', "
-                               "but your key value is {}.".format(set(config.keys())))
+                msg = "Each config in attack_config must have keys 'method' and 'params'," \
+                      "but your key value is {}.".format(set(config.keys()))
+                LOGGER.error(TAG, msg)
+                raise KeyError(msg)
             if str.lower(config["method"]) not in self.method_list:
-                raise ValueError("Method {} is not support.".format(config["method"]))
+                msg = "Method {} is not support.".format(config["method"])
+                LOGGER.error(TAG, msg)
+                raise ValueError(msg)

         features, labels = self._transform(dataset_train, dataset_test)
         for config in attack_config:
@@ -157,22 +177,26 @@
             list, Each element contains an evaluation indicator for the attack model.
         """
         if not isinstance(dataset_train, Dataset):
-            raise TypeError("Type of parameter 'dataset_train' must be Dataset, "
-                            "but got {}".format(type(dataset_train)))
+            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)

         if not isinstance(dataset_test, Dataset):
-            raise TypeError("Type of parameter 'test_train' must be Dataset, "
-                            "but got {}".format(type(dataset_train)))
+            msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)

         if not isinstance(metrics, (list, tuple)):
-            raise TypeError("Type of parameter 'config' must be Union[list, tuple], but got "
-                            "{}.".format(type(metrics)))
+            msg = "Type of parameter 'config' must be Union[list, tuple], but got {}.".format(type(metrics))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)

         metrics = set(metrics)
         metrics_list = {"precision", "accruacy", "recall"}
         if metrics > metrics_list:
-            raise ValueError("Element in 'metrics' must be in {}, but got "
-                             "{}.".format(metrics_list, metrics))
+            msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)

         result = []
         features, labels = self._transform(dataset_train, dataset_test)
@@ -221,8 +245,10 @@
             - numpy.ndarray, Labels for each sample, Shape is (N,).
         """
         if context.get_context("device_target") != "Ascend":
-            raise RuntimeError("The target device must be Ascend, "
-                               "but current is {}.".format(context.get_context("device_target")))
+            msg = "The target device must be Ascend, " \
+                  "but current is {}.".format(context.get_context("device_target"))
+            LOGGER.error(TAG, msg)
+            raise RuntimeError(msg)
         loss_logits = np.array([])
         for batch in dataset_x.create_dict_iterator():
             batch_data = Tensor(batch['image'], ms.float32)
@@ -243,5 +269,7 @@
         elif label == 0:
             labels = np.zeros(len(loss_logits), np.int32)
         else:
-            raise ValueError("The value of label must be 0 or 1, but got {}.".format(label))
+            msg = "The value of label must be 0 or 1, but got {}.".format(label)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
         return loss_logits, labels
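
Every hunk above applies the same refactor: instead of raising with the message inlined in the exception, the message is built once, reported through MindArmour's LogUtil under the module TAG, and then raised. A minimal sketch of that pattern in isolation is shown below; the _check_param_type helper, its parameters, and the usage line are illustrative only and are not part of the patch.

from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = "MembershipInference"


def _check_param_type(name, value, expected_type):
    # Hypothetical helper showing the log-then-raise pattern the diff applies at
    # every validation site: build the message once, log it via
    # LOGGER.error(TAG, msg), then raise it, so the failure reason appears both
    # in MindArmour's log and in the Python traceback.
    if not isinstance(value, expected_type):
        msg = "Type of parameter '{}' must be {}, but got {}.".format(
            name, expected_type.__name__, type(value))
        LOGGER.error(TAG, msg)
        raise TypeError(msg)


# Usage sketch, assuming `config` is one entry of an attack_config list as in train():
# _check_param_type("config", config, dict)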