diff --git a/mindarmour/privacy/evaluation/membership_inference.py b/mindarmour/privacy/evaluation/membership_inference.py index d9289a4..b48ca58 100644 --- a/mindarmour/privacy/evaluation/membership_inference.py +++ b/mindarmour/privacy/evaluation/membership_inference.py @@ -71,24 +71,24 @@ def _eval_info(pred, truth, option): raise ValueError(msg) -def _softmax_cross_entropy(logits, labels): +def _softmax_cross_entropy(logits, labels, epsilon=1e-12): """ Calculate the SoftmaxCrossEntropy result between logits and labels. Args: logits (numpy.ndarray): Numpy array of shape(N, C). labels (numpy.ndarray): Numpy array of shape(N, ). + epsilon (float): The calculated value of softmax will be clipped into [epsilon, 1 - epsilon]. Default: 1e-12. Returns: numpy.ndarray: numpy array of shape(N, ), containing loss value for each vector in logits. """ labels = np.eye(logits.shape[1])[labels].astype(np.int32) - logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) - loss = -1*np.sum(labels*np.log(logits), axis=1) - nan_index = np.isnan(loss) - if np.any(nan_index): - loss[nan_index] = 0 + exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True)) + predictions = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True) + predictions = np.clip(predictions, epsilon, 1.0 - epsilon) + loss = -1 * np.sum(labels*np.log(predictions), axis=-1) return loss diff --git a/tests/ut/python/privacy/evaluation/test_membership_inference.py b/tests/ut/python/privacy/evaluation/test_membership_inference.py index 8ff0a09..8863790 100644 --- a/tests/ut/python/privacy/evaluation/test_membership_inference.py +++ b/tests/ut/python/privacy/evaluation/test_membership_inference.py @@ -69,7 +69,7 @@ def test_membership_inference_object_train(): loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(network=net, loss_fn=loss, optimizer=opt) - inference_model = MembershipInference(model, -1) + inference_model = MembershipInference(model, 2) assert isinstance(inference_model, MembershipInference) config = [{