@@ -11,13 +11,13 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot
+from .nn import indexing_one_hot, logsigmoid, logsoftmax
 from .tensor import where

 __all__ = [
     "l1_loss",
     "square_loss",
-    "cross_entropy_with_softmax",
+    "cross_entropy",
     "binary_cross_entropy",
     "hinge_loss",
 ]
@@ -120,10 +120,16 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
     return (diff ** 2).mean()


-def cross_entropy_with_softmax(
-    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
+def cross_entropy(
+    pred: Tensor,
+    label: Tensor,
+    axis: int = 1,
+    with_logits: bool = True,
+    label_smooth: float = 0,
 ) -> Tensor:
-    r"""Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.
+    r"""Compute the multi-class cross entropy loss (using logits by default).
+
+    By default, prediction is assumed to be logits, whose softmax gives probabilities.

     It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`.

@@ -137,6 +143,7 @@ def cross_entropy_with_softmax(
     :param pred: input tensor representing the predicted probability.
     :param label: input tensor representing the classification label.
     :param axis: an axis along which softmax will be applied. Default: 1
+    :param with_logits: whether to apply softmax first. Default: True
     :param label_smooth: a label smoothing of parameter that can re-distribute target distribution. Default: 0
     :return: loss value.
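
# The label_smooth parameter documented above re-distributes the one-hot target.
# A minimal standalone NumPy sketch (not part of the patch), assuming the usual
# re-distribution y_ls = y * (1 - alpha) + alpha / num_classes:
import numpy as np

def smoothed_target(label_index, num_classes, alpha):
    # Start from a one-hot target and spread `alpha` of its mass uniformly.
    y = np.zeros(num_classes, dtype=np.float32)
    y[label_index] = 1.0
    return y * (1 - alpha) + alpha / num_classes

print(smoothed_target(1, 4, 0.1))  # [0.025 0.925 0.025 0.025]
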
@@ -150,9 +157,9 @@ def cross_entropy_with_softmax(

         data_shape = (1, 2)
         label_shape = (1, )
-        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
+        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
         label = tensor(np.ones(label_shape, dtype=np.int32))
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
         print(loss.numpy())

     Outputs:
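
# A standalone NumPy check (not part of the patch) of the doctest above, assuming
# pred holds logits: for pred = [0, 0] and label = 1 the loss reduces to
# log(sum(exp(pred))) - pred[label] = log(2) ~= 0.6931.
import numpy as np

pred = np.array([0.0, 0.0], dtype=np.float32)
label = 1

offset = pred.max()                   # max-subtraction keeps exp() well-behaved
shifted = pred - offset
down = np.log(np.exp(shifted).sum())  # log of the softmax denominator
up = shifted[label]                   # logit selected by the label
print(down - up)                      # ~0.6931, what the patched F.cross_entropy computes here
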
@@ -170,26 +177,43 @@ def cross_entropy_with_softmax(
     )

     num_classes = pred.shape[axis]
+    no_label_smooth = (
+        label_smooth is None or type(label_smooth) in (int, float) and label_smooth == 0
+    )
+
+    if not with_logits:
+        if no_label_smooth:
+            return -log(indexing_one_hot(pred, label, axis)).mean()
+        pred = log(pred)
+        return (
+            label_smooth * pred.mean()
+            - (1 - label_smooth) * indexing_one_hot(pred, label, axis).mean()
+        )

     # Denominator of the softmax
-    offset = pred.max(axis=axis, keepdims=True).detach()
+    offset = pred.detach().max(axis=axis, keepdims=True)
     pred = pred - offset
-    down = exp(pred).sum(axis=axis, keepdims=True)
+    down = log(exp(pred).sum(axis=axis, keepdims=True))

     up = indexing_one_hot(pred, label, axis)

-    if label_smooth != 0:
+    if not no_label_smooth:
         factor = label_smooth / num_classes
         up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor

-    return (log(down) - up).mean()
+    return (down - up).mean()


-def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
-    r"""Function that measures the Binary Cross Entropy between the target and the prediction.
+def binary_cross_entropy(
+    pred: Tensor, label: Tensor, with_logits: bool = True
+) -> Tensor:
+    r"""Compute the binary cross entropy loss (using logits by default).
+
+    By default, prediction is assumed to be logits, whose sigmoid gives probabilities.

     :param pred: `(N, *)`, where `*` means any number of additional dimensions.
     :param label: `(N, *)`, same shape as the input.
+    :param with_logits: bool, whether to apply sigmoid first. Default: True
     :return: loss value.

     Examples:
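
# A standalone NumPy sketch (not part of the patch) checking that the two branches of
# cross_entropy above agree: probabilities with with_logits=False should give the same
# loss as the corresponding logits with with_logits=True (label_smooth left at 0).
import numpy as np

logits = np.array([[1.0, 2.0, 0.5]], dtype=np.float32)
label = np.array([2])

# with_logits=True path: log-sum-exp minus the logit picked out by the label
offset = logits.max(axis=1, keepdims=True)
shifted = logits - offset
down = np.log(np.exp(shifted).sum(axis=1, keepdims=True))
up = np.take_along_axis(shifted, label[:, None], axis=1)
loss_from_logits = (down - up).mean()

# with_logits=False path: plain -log(p_label) on the softmax probabilities
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
loss_from_probs = -np.log(np.take_along_axis(probs, label[:, None], axis=1)).mean()

print(np.allclose(loss_from_logits, loss_from_probs))  # True
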
@@ -200,7 +224,7 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
         from megengine import tensor
         import megengine.functional as F

-        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(1, 2))
+        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
         label = tensor(np.ones((1, 2), dtype=np.float32))
         loss = F.binary_cross_entropy(pred, label)
         print(loss.numpy())
@@ -212,7 +236,11 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
         [0.6931]

     """
-    return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
+    if not with_logits:
+        return -(label * log(pred) + (1 - label) * log(1 - pred)).mean()
+    # logsigmoid(pred) and logsigmoid(-pred) has common sub-expression
+    # hopefully the backend would optimize this
+    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()


 def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
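
# A standalone NumPy check (not part of the patch) of the logsigmoid formulation above:
# with p = sigmoid(pred), -[label*logsigmoid(pred) + (1-label)*logsigmoid(-pred)] equals
# the naive -[label*log(p) + (1-label)*log(1-p)] but avoids evaluating log(0) when the
# logits are large in magnitude.
import numpy as np

def logsigmoid(x):
    # Stable log(sigmoid(x)) = -log(1 + exp(-x)), computed via logaddexp
    return -np.logaddexp(0.0, -x)

pred = np.array([0.0, 0.0], dtype=np.float32)   # logits, as in the doctest above
label = np.array([1.0, 1.0], dtype=np.float32)

loss = -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()
print(loss)  # ~0.6931, the value shown in the Outputs block above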