From 5b2f0129cbe4ddafbfc28b6df3b2ee9fc2d8b934 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sat, 10 Oct 2020 19:16:33 +0800
Subject: [PATCH] feat(mge): rename cross_entropy_with_softmax -> cross_entropy

GitOrigin-RevId: 9435c3260a773d3734a938ff2ae62ed89d24b36c
---
 imperative/python/megengine/functional/loss.py      | 58 ++++++++++++++++------
 .../python/test/integration/test_converge.py        |  2 +-
 .../python/test/integration/test_correctness.py     |  2 +-
 .../python/test/integration/test_dp_correctness.py  |  2 +-
 .../python/test/integration/test_trace_dump.py      |  4 +-
 .../python/test/unit/functional/test_functional.py  | 13 ++++-
 .../python/test/unit/functional/test_loss.py        | 27 ++++++++--
 7 files changed, 82 insertions(+), 26 deletions(-)

diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index dd69c6b2..e5e1b6a5 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -11,13 +11,13 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot
+from .nn import indexing_one_hot, logsigmoid, logsoftmax
 from .tensor import where

 __all__ = [
     "l1_loss",
     "square_loss",
-    "cross_entropy_with_softmax",
+    "cross_entropy",
     "binary_cross_entropy",
     "hinge_loss",
 ]
@@ -120,10 +120,16 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
     return (diff ** 2).mean()


-def cross_entropy_with_softmax(
-    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
+def cross_entropy(
+    pred: Tensor,
+    label: Tensor,
+    axis: int = 1,
+    with_logits: bool = True,
+    label_smooth: float = 0,
 ) -> Tensor:
-    r"""Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.
+    r"""Compute the multi-class cross entropy loss (using logits by default).
+
+    By default, prediction is assumed to be logits, whose softmax gives probabilities.

     It has better numerical stability compared with sequential calls to :func:`~.softmax`
     and :func:`~.cross_entropy`.
@@ -137,6 +143,7 @@
     :param pred: input tensor representing the predicted probability.
     :param label: input tensor representing the classification label.
     :param axis: an axis along which softmax will be applied. Default: 1
+    :param with_logits: whether to apply softmax first. Default: True
     :param label_smooth: a label smoothing of parameter that can re-distribute target distribution. Default: 0
     :return: loss value.

@@ -150,9 +157,9 @@ def cross_entropy_with_softmax(

         data_shape = (1, 2)
         label_shape = (1, )
-        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
+        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
         label = tensor(np.ones(label_shape, dtype=np.int32))
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
         print(loss.numpy())

     Outputs:
@@ -170,26 +177,43 @@ def cross_entropy_with_softmax(
     )

     num_classes = pred.shape[axis]
+    no_label_smooth = (
+        label_smooth is None or type(label_smooth) in (int, float) and label_smooth == 0
+    )
+
+    if not with_logits:
+        if no_label_smooth:
+            return -log(indexing_one_hot(pred, label, axis)).mean()
+        pred = log(pred)
+        return (
+            label_smooth * pred.mean()
+            - (1 - label_smooth) * indexing_one_hot(pred, label, axis).mean()
+        )

     # Denominator of the softmax
-    offset = pred.max(axis=axis, keepdims=True).detach()
+    offset = pred.detach().max(axis=axis, keepdims=True)
     pred = pred - offset
-    down = exp(pred).sum(axis=axis, keepdims=True)
+    down = log(exp(pred).sum(axis=axis, keepdims=True))

     up = indexing_one_hot(pred, label, axis)

-    if label_smooth != 0:
+    if not no_label_smooth:
         factor = label_smooth / num_classes
         up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor

-    return (log(down) - up).mean()
+    return (down - up).mean()


-def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
-    r"""Function that measures the Binary Cross Entropy between the target and the prediction.
+def binary_cross_entropy(
+    pred: Tensor, label: Tensor, with_logits: bool = True
+) -> Tensor:
+    r"""Compute the binary cross entropy loss (using logits by default).
+
+    By default, prediction is assumed to be logits, whose sigmoid gives probabilities.

     :param pred: `(N, *)`, where `*` means any number of additional dimensions.
     :param label: `(N, *)`, same shape as the input.
+    :param with_logits: bool, whether to apply sigmoid first. Default: True
     :return: loss value.

     Examples:
@@ -200,7 +224,7 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
         from megengine import tensor
         import megengine.functional as F

-        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(1, 2))
+        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
         label = tensor(np.ones((1, 2), dtype=np.float32))
         loss = F.binary_cross_entropy(pred, label)
         print(loss.numpy())
@@ -212,7 +236,11 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
         [0.6931]

     """
-    return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
+    if not with_logits:
+        return -(label * log(pred) + (1 - label) * log(1 - pred)).mean()
+    # logsigmoid(pred) and logsigmoid(-pred) have a common sub-expression;
+    # hopefully the backend can optimize this
+    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()


 def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
diff --git a/imperative/python/test/integration/test_converge.py b/imperative/python/test/integration/test_converge.py
index 1beded21..0815ca63 100644
--- a/imperative/python/test/integration/test_converge.py
+++ b/imperative/python/test/integration/test_converge.py
@@ -80,7 +80,7 @@ def test_training_converge():
     def train(data, label):
         with gm:
             pred = net(data)
-            loss = F.cross_entropy_with_softmax(pred, label)
+            loss = F.cross_entropy(pred, label)
             gm.backward(loss)
         return loss

diff --git a/imperative/python/test/integration/test_correctness.py b/imperative/python/test/integration/test_correctness.py
index 67e69dcf..9dd31cc0 100644
--- a/imperative/python/test/integration/test_correctness.py
+++ b/imperative/python/test/integration/test_correctness.py
@@ -92,7 +92,7 @@ class MnistNet(Module):
 def train(data, label, net, opt, gm):
     with gm:
         pred = net(data)
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
         gm.backward(loss)
     return loss

diff --git a/imperative/python/test/integration/test_dp_correctness.py b/imperative/python/test/integration/test_dp_correctness.py
index fcd8eaf2..c4774272 100644
--- a/imperative/python/test/integration/test_dp_correctness.py
+++ b/imperative/python/test/integration/test_dp_correctness.py
@@ -98,7 +98,7 @@ def train(data, label, net, opt, gm):
     opt.clear_grad()
     with gm:
         pred = net(data)
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
         gm.backward(loss)
     opt.step()
     return loss
diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py
index f6e27398..742803b4 100644
--- a/imperative/python/test/integration/test_trace_dump.py
+++ b/imperative/python/test/integration/test_trace_dump.py
@@ -72,7 +72,7 @@ def test_xornet_trace_dump():
         with gm:
             net.train()
             pred = net(data)
-            loss = F.cross_entropy_with_softmax(pred, label)
+            loss = F.cross_entropy(pred, label)
             gm.backward(loss)
         return pred, loss

@@ -80,7 +80,7 @@
     def val_fun(data, label):
         net.eval()
         pred = net(data)
-        loss = F.cross_entropy_with_softmax(pred, label)
+        loss = F.cross_entropy(pred, label)
         return pred, loss

     @trace(symbolic=True, capture_as_const=True)
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 9165c38b..09ccdc3b 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import itertools +from functools import partial import numpy as np import pytest @@ -303,12 +304,12 @@ def test_binary_cross_entropy(): np.testing.assert_allclose(x.numpy(), y, atol=5e-4) np.random.seed(123) - data1 = sigmoid(np.random.uniform(size=data1_shape).astype(np.float32)) + data1 = np.random.uniform(size=data1_shape).astype(np.float32) label1 = np.random.uniform(size=label1_shape).astype(np.float32) expect1 = np.array([0.6361], dtype=np.float32) np.random.seed(123) - data2 = sigmoid(np.random.uniform(size=data2_shape).astype(np.float32)) + data2 = np.random.uniform(size=data2_shape).astype(np.float32) label2 = np.random.uniform(size=label2_shape).astype(np.float32) expect2 = np.array([0.6750], dtype=np.float32) @@ -318,6 +319,14 @@ def test_binary_cross_entropy(): ] opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn) + cases = [ + {"input": [sigmoid(data1), label1], "output": expect1,}, + {"input": [sigmoid(data2), label2], "output": expect2,}, + ] + opr_test( + cases, partial(F.binary_cross_entropy, with_logits=False), compare_fn=compare_fn + ) + def test_hinge_loss(): np.random.seed(123) diff --git a/imperative/python/test/unit/functional/test_loss.py b/imperative/python/test/unit/functional/test_loss.py index 50b9a7dc..8464a3a2 100644 --- a/imperative/python/test/unit/functional/test_loss.py +++ b/imperative/python/test/unit/functional/test_loss.py @@ -12,15 +12,34 @@ import megengine.functional as F from megengine import tensor -def test_cross_entropy_with_softmax(): +def test_cross_entropy_with_logits(): data = tensor([1, 100]).astype(np.float32).reshape((1, 2)) label = tensor([1]).astype(np.int32) - loss = F.cross_entropy_with_softmax(data, label) + loss = F.cross_entropy(data, label) np.testing.assert_allclose(loss.numpy(), 0.0) label = tensor([0]).astype(np.int32) - loss = F.cross_entropy_with_softmax(data, label) + loss = F.cross_entropy(data, label) np.testing.assert_allclose(loss.numpy(), 100 - 1) label = np.array([1]) - loss = F.cross_entropy_with_softmax(data, label) + loss = F.cross_entropy(data, label) np.testing.assert_allclose(loss.numpy(), 0.0) + + +def test_cross_entropy(): + def softmax(x): + x = np.exp(x) + x /= x.sum(1, keepdims=True) + return x + + def ref(x, y): + return np.mean([-np.log(x[i, y[i]]) for i in range(len(y))]) + + x = (np.random.rand(5, 10) - 0.5) * 4 + y = np.random.randint(10, size=(5,)) + for i in range(len(x)): + x[i, y[i]] += np.random.rand() * 2 + x = softmax(x) + l_ref = ref(x, y) + l = F.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False) + np.testing.assert_allclose(l.numpy(), l_ref)