- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import numpy as np
- import pytest
-
- import megengine.functional as F
- from megengine import tensor
-
-
def test_cross_entropy_with_logits():
    """Check ``cross_entropy`` on logits whose softmax is (nearly) one-hot.

    Row 0 strongly favours class 1 and row 1 strongly favours class 0, so
    choosing those classes gives a loss that underflows to zero, while choosing
    the other class costs 50 and 150 respectively (mean 100).
    """
    logits = tensor([[0, 50], [0, -150]]).astype(np.float32)

    # targets agree with the dominant logit of each row
    matched = F.nn.cross_entropy(logits, tensor([1, 0]).astype(np.int32))
    np.testing.assert_allclose(matched.numpy(), 0.0)

    # targets disagree with the dominant logit of each row
    mismatched = F.nn.cross_entropy(logits, tensor([0, 1]).astype(np.int32))
    np.testing.assert_allclose(mismatched.numpy(), 100)

    # a plain NumPy label array is accepted as well
    from_numpy_label = F.nn.cross_entropy(logits, np.array([1, 0]))
    np.testing.assert_allclose(from_numpy_label.numpy(), 0.0)
-
-
def test_cross_entropy():
    """Compare ``cross_entropy(..., with_logits=False)`` with a NumPy reference.

    Random logits get an extra push towards each row's target class, are turned
    into probabilities with a softmax, and the reference loss is the mean over
    rows of ``-log p[i, y[i]]``.
    """

    def softmax(x):
        x = np.exp(x)
        x /= x.sum(1, keepdims=True)
        return x

    def ref(x, y):
        # mean negative log-probability of each row's target class
        return np.mean(-np.log(x[np.arange(len(y)), y]))

    x = (np.random.rand(5, 10) - 0.5) * 4
    y = np.random.randint(10, size=(5,))
    # raise the target logit of every row (one distinct (row, col) pair per row)
    x[np.arange(len(y)), y] += np.random.rand(len(y)) * 2

    # Quantize once so the reference sees exactly the float32 values that are
    # fed to MegEngine; otherwise the input rounding alone can exceed rtol.
    probs = softmax(x).astype(np.float32)
    l_ref = ref(probs.astype(np.float64), y)
    l = F.nn.cross_entropy(
        tensor(probs, "float32"), tensor(y, "int32"), with_logits=False
    )
    # float32 op vs float64 reference: the default rtol=1e-7 is below float32
    # precision and makes the test flaky, so allow a few ulps of slack.
    np.testing.assert_allclose(l.numpy(), l_ref, rtol=1e-6, atol=1e-6)
-
-
def test_cross_entropy_reduction():
    """Check the ``reduction`` argument of ``cross_entropy``.

    * ``"none"`` yields per-sample losses that follow a permutation of the batch;
    * ``"sum"`` and ``"mean"`` agree with reducing the ``"none"`` output;
    * the ``"mean"`` / ``"none"`` agreement also holds with ``label_smooth``;
    * unrecognised reduction names (including a wrong case) raise ``ValueError``.
    """
    batch_size, num_classes = 16, 10
    logits_np = np.random.randn(batch_size, num_classes).astype("float32")
    label_np = np.random.randint(num_classes, size=[batch_size]).astype("int32")
    logits = tensor(logits_np, dtype="float32")
    label = tensor(label_np, dtype="int32")

    # permute the NumPy source arrays instead of fancy-indexing the tensors
    perm = np.random.permutation(batch_size)
    logits_perm = tensor(logits_np[perm], dtype="float32")
    label_perm = tensor(label_np[perm], dtype="int32")

    loss = F.nn.cross_entropy(logits, label, reduction="none")
    loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
    np.testing.assert_allclose(loss.numpy()[perm], loss_perm.numpy())

    # The framework's float32 reduction may accumulate in a different order
    # than NumPy's pairwise sum; ~2e-7 (1-2 ulp) is too tight to be reliable.
    reduce_rtol = 1e-6

    loss_sum = F.nn.cross_entropy(logits, label, reduction="sum")
    np.testing.assert_allclose(
        loss.numpy().sum(), loss_sum.numpy(), rtol=reduce_rtol
    )

    loss_mean = F.nn.cross_entropy(logits, label, reduction="mean")
    np.testing.assert_allclose(loss_mean.numpy(), loss_sum.numpy() / batch_size)

    loss_ls = F.nn.cross_entropy(logits, label, reduction="mean", label_smooth=0.1)
    loss_ls_none_reduce = F.nn.cross_entropy(
        logits, label, reduction="none", label_smooth=0.1
    )
    np.testing.assert_allclose(
        loss_ls.numpy(), loss_ls_none_reduce.numpy().mean(), rtol=reduce_rtol
    )

    # reduction names are case-sensitive, and unknown names are rejected
    for bad_reduction in ("MEAN", "max"):
        with pytest.raises(ValueError):
            F.nn.cross_entropy(logits, label, reduction=bad_reduction)
|