
loss.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core.tensor.utils import make_shape_tuple
from ..tensor import Tensor
from .elemwise import abs, exp, log, relu
from .nn import indexing_one_hot

__all__ = [
    "l1_loss",
    "square_loss",
    "cross_entropy_with_softmax",
    "binary_cross_entropy",
    "hinge_loss",
]


def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left|x_n - y_n\right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.75]

    """
    diff = pred - label
    return abs(diff).mean()


def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left(x_n - y_n\right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``
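
    Examples:

    .. testcode::

        # a minimal usage sketch, mirroring the l1_loss example above
        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        # squared differences are 1, 25, 9, 4, so the mean is 9.75
        loss = F.square_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [9.75]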
  70. """
  71. diff = pred - label
  72. return (diff ** 2).mean()
def cross_entropy_with_softmax(
    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
    r"""
    Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.

    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and the original label
    distribution respectively, :math:`k` is the index into the label
    distribution, :math:`\alpha` is ``label_smooth``, and :math:`K` is the
    number of classes.

    :param pred: The input tensor representing the predicted probability.
    :param label: The input tensor representing the classification label.
    :param axis: An axis along which softmax will be applied. Default: 1.
    :param label_smooth: A label-smoothing parameter that redistributes the
        target distribution. Default: 0.
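
    Examples:

    .. testcode::

        # a minimal usage sketch; with equal logits per row the softmax is
        # uniform over 2 classes, so the loss is ln(2) for every sample
        import numpy as np
        import megengine as mge
        import megengine.functional as F

        pred = mge.tensor(np.zeros((2, 2), dtype=np.float32))
        label = mge.tensor(np.array([0, 1], dtype=np.int32))
        loss = F.cross_entropy_with_softmax(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.6931472]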
  87. """
  88. n0 = pred.ndim
  89. n1 = label.ndim
  90. assert n0 == n1 + 1, (
  91. "target ndim must be one less than input ndim; input_ndim={} "
  92. "target_ndim={}".format(n0, n1)
  93. )
  94. num_classes = pred.shape[axis]
  95. # Denominator of the softmax
  96. offset = pred.max(axis=axis, keepdims=True).detach()
  97. pred = pred - offset
  98. down = exp(pred).sum(axis=axis, keepdims=True)
  99. up = indexing_one_hot(pred, label, axis)
  100. if label_smooth != 0:
  101. factor = label_smooth / num_classes
  102. up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
  103. return (log(down) - up).mean()
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
    r"""Function that measures the Binary Cross Entropy between the target and the prediction.

    :param pred: :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
    :param label: :math:`(N, *)`, same shape as ``pred``.
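
    Examples:

    .. testcode::

        # a minimal usage sketch; a uniform prediction of 0.5 gives a loss
        # of ln(2) for every element regardless of the label
        import numpy as np
        import megengine as mge
        import megengine.functional as F

        pred = mge.tensor(np.full((2, 2), 0.5, dtype=np.float32))
        label = mge.tensor(np.array([[1, 0], [0, 1]], dtype=np.float32))
        loss = F.binary_cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.6931472]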
  108. """
  109. assert make_shape_tuple(pred.shape) == make_shape_tuple(label.shape)
  110. return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Calculates the hinge loss, which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_{i}\sum_{j}\max(0, 1 - x_{ij} y_{ij})

    :param pred: The input tensor representing the predicted probability, shape is (N, C).
    :param label: The input tensor representing the binary classification label, shape is (N, C).
    :param norm: Specify the norm to calculate the loss, should be "L1" or "L2".

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
        loss = F.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [1.5]

    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # Margin term: zero when the prediction is on the correct side of the
    # margin, i.e. pred * label >= 1. Labels are expected to be -1 or 1.
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1).mean()
    else:
        return (loss ** 2).sum(axis=1).mean()
