You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_loss.py 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine.functional as F
  12. from megengine import tensor
  13. def test_cross_entropy_with_logits():
  14. data = tensor([[0, 50], [0, -150]]).astype(np.float32)
  15. label = tensor([1, 0]).astype(np.int32)
  16. loss = F.nn.cross_entropy(data, label)
  17. np.testing.assert_allclose(loss.numpy(), 0.0)
  18. label = tensor([0, 1]).astype(np.int32)
  19. loss = F.nn.cross_entropy(data, label)
  20. np.testing.assert_allclose(loss.numpy(), 100)
  21. label = np.array([1, 0])
  22. loss = F.nn.cross_entropy(data, label)
  23. np.testing.assert_allclose(loss.numpy(), 0.0)
  24. def test_cross_entropy():
  25. def softmax(x):
  26. x = np.exp(x)
  27. x /= x.sum(1, keepdims=True)
  28. return x
  29. def ref(x, y):
  30. return np.mean([-np.log(x[i, y[i]]) for i in range(len(y))])
  31. x = (np.random.rand(5, 10) - 0.5) * 4
  32. y = np.random.randint(10, size=(5,))
  33. for i in range(len(x)):
  34. x[i, y[i]] += np.random.rand() * 2
  35. x = softmax(x)
  36. l_ref = ref(x, y)
  37. l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
  38. np.testing.assert_allclose(l.numpy(), l_ref)
  39. def test_cross_entropy_reduction():
  40. logits = np.random.randn(16, 10)
  41. label = np.random.randint(10, size=[16])
  42. logits = tensor(logits, dtype="float32")
  43. label = tensor(label, dtype="int32")
  44. perm = np.random.permutation(16)
  45. logits_perm = tensor(logits[perm], dtype="float32")
  46. label_perm = tensor(label[perm], dtype="int32")
  47. loss = F.nn.cross_entropy(logits, label, reduction="none")
  48. loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
  49. np.testing.assert_allclose(loss.numpy()[perm], loss_perm.numpy())
  50. loss_sum = F.nn.cross_entropy(logits, label, reduction="sum")
  51. np.testing.assert_allclose(loss.numpy().sum(), loss_sum.numpy(), rtol=2e-7)
  52. loss_mean = F.nn.cross_entropy(logits, label, reduction="mean")
  53. np.testing.assert_allclose(loss_mean.numpy(), loss_sum.numpy() / 16)
  54. loss_ls = F.nn.cross_entropy(logits, label, reduction="mean", label_smooth=0.1)
  55. loss_ls_none_reduce = F.nn.cross_entropy(
  56. logits, label, reduction="none", label_smooth=0.1
  57. )
  58. np.testing.assert_allclose(
  59. loss_ls.numpy(), loss_ls_none_reduce.numpy().mean(), rtol=2e-7
  60. )
  61. with pytest.raises(ValueError):
  62. F.nn.cross_entropy(logits, label, reduction="MEAN")
  63. with pytest.raises(ValueError):
  64. F.nn.cross_entropy(logits, label, reduction="max")

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。