
feat(mge): rename cross_entropy_with_softmax -> cross_entropy

GitOrigin-RevId: 9435c3260a
release-1.1
Megvii Engine Team, 4 years ago
commit 5b2f0129cb
7 changed files with 82 additions and 26 deletions
  1. imperative/python/megengine/functional/loss.py (+43, -15)
  2. imperative/python/test/integration/test_converge.py (+1, -1)
  3. imperative/python/test/integration/test_correctness.py (+1, -1)
  4. imperative/python/test/integration/test_dp_correctness.py (+1, -1)
  5. imperative/python/test/integration/test_trace_dump.py (+2, -2)
  6. imperative/python/test/unit/functional/test_functional.py (+11, -2)
  7. imperative/python/test/unit/functional/test_loss.py (+23, -4)

imperative/python/megengine/functional/loss.py (+43, -15)

@@ -11,13 +11,13 @@ import numpy as np
from ..core.tensor.utils import make_shape_tuple
from ..tensor import Tensor
from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot
+from .nn import indexing_one_hot, logsigmoid, logsoftmax
from .tensor import where


__all__ = [
    "l1_loss",
    "square_loss",
-    "cross_entropy_with_softmax",
+    "cross_entropy",
    "binary_cross_entropy",
    "hinge_loss",
]
@@ -120,10 +120,16 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    return (diff ** 2).mean()


-def cross_entropy_with_softmax(
-    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
+def cross_entropy(
+    pred: Tensor,
+    label: Tensor,
+    axis: int = 1,
+    with_logits: bool = True,
+    label_smooth: float = 0,
) -> Tensor:
-    r"""Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.
+    r"""Compute the multi-class cross entropy loss (using logits by default).
+
+    By default, prediction is assumed to be logits, whose softmax gives probabilities.

    It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`.


@@ -137,6 +143,7 @@ def cross_entropy_with_softmax(
    :param pred: input tensor representing the predicted probability.
    :param label: input tensor representing the classification label.
    :param axis: an axis along which softmax will be applied. Default: 1
+    :param with_logits: whether to apply softmax first. Default: True
    :param label_smooth: a label smoothing of parameter that can re-distribute target distribution. Default: 0
    :return: loss value.


@@ -150,9 +157,9 @@ def cross_entropy_with_softmax(


        data_shape = (1, 2)
        label_shape = (1, )
-        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
+        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:
@@ -170,26 +177,43 @@ def cross_entropy_with_softmax(
    )

    num_classes = pred.shape[axis]
+    no_label_smooth = (
+        label_smooth is None or type(label_smooth) in (int, float) and label_smooth == 0
+    )
+
+    if not with_logits:
+        if no_label_smooth:
+            return -log(indexing_one_hot(pred, label, axis)).mean()
+        pred = log(pred)
+        return (
+            label_smooth * pred.mean()
+            - (1 - label_smooth) * indexing_one_hot(pred, label, axis).mean()
+        )

    # Denominator of the softmax
-    offset = pred.max(axis=axis, keepdims=True).detach()
+    offset = pred.detach().max(axis=axis, keepdims=True)
    pred = pred - offset
-    down = exp(pred).sum(axis=axis, keepdims=True)
+    down = log(exp(pred).sum(axis=axis, keepdims=True))

    up = indexing_one_hot(pred, label, axis)

-    if label_smooth != 0:
+    if not no_label_smooth:
        factor = label_smooth / num_classes
        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor

-    return (log(down) - up).mean()
+    return (down - up).mean()


-def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
-    r"""Function that measures the Binary Cross Entropy between the target and the prediction.
+def binary_cross_entropy(
+    pred: Tensor, label: Tensor, with_logits: bool = True
+) -> Tensor:
+    r"""Compute the binary cross entropy loss (using logits by default).
+
+    By default, prediction is assumed to be logits, whose sigmoid gives probabilities.

    :param pred: `(N, *)`, where `*` means any number of additional dimensions.
    :param label: `(N, *)`, same shape as the input.
+    :param with_logits: bool, whether to apply sigmoid first. Default: True
    :return: loss value.

    Examples:
@@ -200,7 +224,7 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
        from megengine import tensor
        import megengine.functional as F

-        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(1, 2))
+        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
        label = tensor(np.ones((1, 2), dtype=np.float32))
        loss = F.binary_cross_entropy(pred, label)
        print(loss.numpy())
@@ -212,7 +236,11 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
        [0.6931]

    """
-    return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
+    if not with_logits:
+        return -(label * log(pred) + (1 - label) * log(1 - pred)).mean()
+    # logsigmoid(pred) and logsigmoid(-pred) has common sub-expression
+    # hopefully the backend would optimize this
+    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()


def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
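
Note on the new logits path: `cross_entropy(pred, label)` now works on raw logits and folds the softmax into a max-shifted log-sum-exp (the `offset` subtraction plus `down - up`), which avoids overflowing `exp` on large logits. The following is a minimal NumPy sketch of that formulation, checked against a naive softmax-then-log reference; the helper names (`cross_entropy_naive`, `cross_entropy_stable`) are illustrative and not part of this diff.

import numpy as np

def cross_entropy_naive(logits, label, axis=1):
    # Naive reference: softmax first, then negative log-likelihood.
    e = np.exp(logits)
    prob = e / e.sum(axis=axis, keepdims=True)
    return -np.log(np.take_along_axis(prob, label[:, None], axis)).mean()

def cross_entropy_stable(logits, label, axis=1):
    # Mirrors the patched code: subtract the per-row max, then
    # loss = log(sum(exp(shifted))) - shifted[label].
    offset = logits.max(axis=axis, keepdims=True)
    shifted = logits - offset
    down = np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    up = np.take_along_axis(shifted, label[:, None], axis)
    return (down - up).mean()

x = np.random.randn(4, 10).astype(np.float32)
y = np.random.randint(10, size=(4,))
np.testing.assert_allclose(
    cross_entropy_stable(x, y), cross_entropy_naive(x, y), rtol=1e-5
)
print(cross_entropy_stable(1000 * x, y))  # still finite; the naive form overflows here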


imperative/python/test/integration/test_converge.py (+1, -1)

@@ -80,7 +80,7 @@ def test_training_converge():
    def train(data, label):
        with gm:
            pred = net(data)
-            loss = F.cross_entropy_with_softmax(pred, label)
+            loss = F.cross_entropy(pred, label)
            gm.backward(loss)
        return loss




imperative/python/test/integration/test_correctness.py (+1, -1)

@@ -92,7 +92,7 @@ class MnistNet(Module):
def train(data, label, net, opt, gm):
    with gm:
        pred = net(data)
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
        gm.backward(loss)
    return loss




imperative/python/test/integration/test_dp_correctness.py (+1, -1)

@@ -98,7 +98,7 @@ def train(data, label, net, opt, gm):
    opt.clear_grad()
    with gm:
        pred = net(data)
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
        gm.backward(loss)
    opt.step()
    return loss


imperative/python/test/integration/test_trace_dump.py (+2, -2)

@@ -72,7 +72,7 @@ def test_xornet_trace_dump():
        with gm:
            net.train()
            pred = net(data)
-            loss = F.cross_entropy_with_softmax(pred, label)
+            loss = F.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss


@@ -80,7 +80,7 @@ def test_xornet_trace_dump():
    def val_fun(data, label):
        net.eval()
        pred = net(data)
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
        return pred, loss

    @trace(symbolic=True, capture_as_const=True)


imperative/python/test/unit/functional/test_functional.py (+11, -2)

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools
+from functools import partial

import numpy as np
import pytest
@@ -303,12 +304,12 @@ def test_binary_cross_entropy():
        np.testing.assert_allclose(x.numpy(), y, atol=5e-4)

    np.random.seed(123)
-    data1 = sigmoid(np.random.uniform(size=data1_shape).astype(np.float32))
+    data1 = np.random.uniform(size=data1_shape).astype(np.float32)
    label1 = np.random.uniform(size=label1_shape).astype(np.float32)
    expect1 = np.array([0.6361], dtype=np.float32)

    np.random.seed(123)
-    data2 = sigmoid(np.random.uniform(size=data2_shape).astype(np.float32))
+    data2 = np.random.uniform(size=data2_shape).astype(np.float32)
    label2 = np.random.uniform(size=label2_shape).astype(np.float32)
    expect2 = np.array([0.6750], dtype=np.float32)


@@ -318,6 +319,14 @@ def test_binary_cross_entropy():
    ]
    opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)

+    cases = [
+        {"input": [sigmoid(data1), label1], "output": expect1,},
+        {"input": [sigmoid(data2), label2], "output": expect2,},
+    ]
+    opr_test(
+        cases, partial(F.binary_cross_entropy, with_logits=False), compare_fn=compare_fn
+    )


def test_hinge_loss():
    np.random.seed(123)
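
The added cases exercise the same equivalence the new implementation relies on: feeding raw values as logits should give the same loss as feeding `sigmoid(data)` with `with_logits=False`. Below is a minimal NumPy sketch of that identity; the helpers `log_sigmoid`, `bce_with_logits` and `bce_from_probs` are illustrative names, not part of the diff.

import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)

def bce_with_logits(pred, label):
    # Same shape as the patched return: both terms reuse the log-sigmoid.
    return -(label * log_sigmoid(pred) + (1 - label) * log_sigmoid(-pred)).mean()

def bce_from_probs(prob, label):
    # The with_logits=False path: pred is already a probability.
    return -(label * np.log(prob) + (1 - label) * np.log(1 - prob)).mean()

pred = np.random.randn(8, 3)
label = (np.random.rand(8, 3) > 0.5).astype(np.float64)
np.testing.assert_allclose(
    bce_with_logits(pred, label),
    bce_from_probs(1 / (1 + np.exp(-pred)), label),
    rtol=1e-6,
)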


imperative/python/test/unit/functional/test_loss.py (+23, -4)

@@ -12,15 +12,34 @@ import megengine.functional as F
from megengine import tensor


-def test_cross_entropy_with_softmax():
+def test_cross_entropy_with_logits():
    data = tensor([1, 100]).astype(np.float32).reshape((1, 2))
    label = tensor([1]).astype(np.int32)
-    loss = F.cross_entropy_with_softmax(data, label)
+    loss = F.cross_entropy(data, label)
    np.testing.assert_allclose(loss.numpy(), 0.0)
    label = tensor([0]).astype(np.int32)
-    loss = F.cross_entropy_with_softmax(data, label)
+    loss = F.cross_entropy(data, label)
    np.testing.assert_allclose(loss.numpy(), 100 - 1)

    label = np.array([1])
-    loss = F.cross_entropy_with_softmax(data, label)
+    loss = F.cross_entropy(data, label)
    np.testing.assert_allclose(loss.numpy(), 0.0)


+def test_cross_entropy():
+    def softmax(x):
+        x = np.exp(x)
+        x /= x.sum(1, keepdims=True)
+        return x
+
+    def ref(x, y):
+        return np.mean([-np.log(x[i, y[i]]) for i in range(len(y))])
+
+    x = (np.random.rand(5, 10) - 0.5) * 4
+    y = np.random.randint(10, size=(5,))
+    for i in range(len(x)):
+        x[i, y[i]] += np.random.rand() * 2
+    x = softmax(x)
+    l_ref = ref(x, y)
+    l = F.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
+    np.testing.assert_allclose(l.numpy(), l_ref)
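
The new test covers the `with_logits=False` path without smoothing. For the `with_logits=True` label-smoothing branch (`factor = label_smooth / num_classes`), the closed form in the patch matches cross entropy against a smoothed target distribution, at least under that reading; a minimal NumPy sketch, with illustrative helper names (`smoothed_ce_ref`, `smoothed_ce_closed_form`) that are not part of this diff:

import numpy as np

def log_softmax(x, axis=1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def smoothed_ce_ref(logits, label, label_smooth, axis=1):
    # Cross entropy against q = (1 - label_smooth) * onehot + label_smooth / K.
    n, num_classes = logits.shape
    q = np.full_like(logits, label_smooth / num_classes)
    q[np.arange(n), label] += 1 - label_smooth
    return -(q * log_softmax(logits, axis)).sum(axis=axis).mean()

def smoothed_ce_closed_form(logits, label, label_smooth, axis=1):
    # Mirrors the patched with_logits path: mix the picked logit with the
    # per-row logit sum before subtracting from log-sum-exp.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    down = np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    up = np.take_along_axis(shifted, label[:, None], axis)
    factor = label_smooth / logits.shape[axis]
    up = up * (1 - label_smooth) + shifted.sum(axis=axis, keepdims=True) * factor
    return (down - up).mean()

x = np.random.randn(6, 5)
y = np.random.randint(5, size=(6,))
np.testing.assert_allclose(
    smoothed_ce_ref(x, y, 0.1), smoothed_ce_closed_form(x, y, 0.1), rtol=1e-6
)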
