@@ -20,8 +20,6 @@ import numpy as np
 import mindspore as ms
 from mindspore.train import Model
 from mindspore.dataset.engine import Dataset
-import mindspore.nn as nn
-import mindspore.context as context
 from mindspore import Tensor
 from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
 from mindarmour.utils.logger import LogUtil
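
Note: within this patch, the two deleted imports are used only by the Ascend-only code path that the final hunk removes: mindspore.nn supplied SoftmaxCrossEntropyWithLogits and mindspore.context supplied the device_target check. The ms and Tensor imports stay because the retained batch_data = Tensor(batch['image'], ms.float32) line still needs them.
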
@@ -71,6 +69,22 @@ def _eval_info(pred, truth, option):
         raise ValueError(msg)
 
 
+def _softmax_cross_entropy(logits, labels):
+    """
+    Calculate the SoftmaxCrossEntropy result between logits and labels.
+
+    Args:
+        logits (numpy.ndarray): Numpy array of shape(N, C).
+        labels (numpy.ndarray): Numpy array of shape(N, ).
+
+    Returns:
+        numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits.
+    """
+    labels = np.eye(logits.shape[1])[labels].astype(np.int32)
+    logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
+    return -1*np.sum(labels*np.log(logits), axis=1)
+
+
 class MembershipInference:
     """
     Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
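
Note: the _softmax_cross_entropy helper added above reimplements in plain NumPy the per-sample loss that was previously computed on device by nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None) (removed in the final hunk): one-hot encode the labels, softmax the logits, and take the negative log-likelihood of the true class. A minimal sanity check, using hypothetical logits and labels that are not part of the patch:

    import numpy as np

    def _softmax_cross_entropy(logits, labels):
        # One-hot encode the labels, softmax the logits, then return the
        # per-sample negative log-likelihood of the true class.
        labels = np.eye(logits.shape[1])[labels].astype(np.int32)
        logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
        return -1*np.sum(labels*np.log(logits), axis=1)

    logits = np.array([[2.0, 1.0, 0.1]])  # N=1 sample, C=3 classes
    labels = np.array([0])                # index of the true class
    print(_softmax_cross_entropy(logits, labels))  # ~[0.417], i.e. -log(softmax(logits)[0][0])
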
@@ -192,8 +206,8 @@ class MembershipInference:
             raise TypeError(msg)
 
         metrics = set(metrics)
-        metrics_list = {"precision", "accruacy", "recall"}
-        if metrics > metrics_list:
+        metrics_list = {"precision", "accuracy", "recall"}
+        if not metrics <= metrics_list:
             msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
             LOGGER.error(TAG, msg)
             raise ValueError(msg)
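
Note: the operator change above is a behavioral fix on top of the "accruacy" spelling fix. set_a > set_b is a strict-superset test, so the old check raised only when metrics contained every allowed name plus at least one extra; a single unsupported name slipped through silently. not metrics <= metrics_list rejects any element outside the allowed set. A quick illustration with hypothetical input:

    metrics = {"precision", "f1"}  # "f1" is not a supported metric
    metrics_list = {"precision", "accuracy", "recall"}

    print(metrics > metrics_list)       # False: the old check let "f1" through
    print(not metrics <= metrics_list)  # True: the fixed check raises ValueError
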
@@ -244,19 +258,12 @@ class MembershipInference:
               N is the number of sample. C = 1 + dim(logits).
             - numpy.ndarray, Labels for each sample, Shape is (N,).
         """
-        if context.get_context("device_target") != "Ascend":
-            msg = "The target device must be Ascend, " \
-                  "but current is {}.".format(context.get_context("device_target"))
-            LOGGER.error(TAG, msg)
-            raise RuntimeError(msg)
         loss_logits = np.array([])
         for batch in dataset_x.create_dict_iterator():
             batch_data = Tensor(batch['image'], ms.float32)
-            batch_labels = Tensor(batch['label'], ms.int32)
-            batch_logits = self.model.predict(batch_data)
-            loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None)
-            batch_loss = loss(batch_logits, batch_labels).asnumpy()
-            batch_logits = batch_logits.asnumpy()
+            batch_labels = batch['label'].astype(np.int32)
+            batch_logits = self.model.predict(batch_data).asnumpy()
+            batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
 
             batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
             if loss_logits.size == 0:
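
Note: with the new loop body, each batch contributes a feature matrix whose first column is the per-sample loss and whose remaining columns are the raw logits, matching the documented shape (N, C) with C = 1 + dim(logits); the if/else truncated here presumably stacks the batches into loss_logits. A shape sketch with hypothetical dimensions:

    import numpy as np

    batch_logits = np.random.randn(32, 10)  # e.g. 32 samples, 10 classes (hypothetical)
    batch_loss = np.random.rand(32)         # stand-in for _softmax_cross_entropy output
    batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
    print(batch_feature.shape)              # (32, 11), i.e. (N, 1 + dim(logits))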