
loss.py 7.9 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..tensor import Tensor
from .elemwise import abs, log
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
from .tensor import where

__all__ = [
    "l1_loss",
    "square_loss",
    "cross_entropy",
    "binary_cross_entropy",
    "hinge_loss",
]


def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math:: \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left| x_n - y_n \right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: predicted result from model.
    :param label: ground truth to compare.
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.nn.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        2.75

    """
    diff = pred - label
    return abs(diff).mean()


def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math:: \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left( x_n - y_n \right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: predicted result from model.
    :param label: ground truth to compare.
    :return: loss value.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        - label: :math:`(N, *)`. Same shape as ``pred``.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.nn.square_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        9.75

    """
    diff = pred - label
    return (diff ** 2).mean()


def cross_entropy(
    pred: Tensor,
    label: Tensor,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
) -> Tensor:
    r"""
    Computes the multi-class cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits,
    and class probabilities are given by softmax.
    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k} = y_{k} \left( 1 - \alpha \right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and original label
    distributions respectively, :math:`k` indexes the label distribution,
    :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes.

    :param pred: input tensor representing the predicted probability.
    :param label: input tensor representing the classification label.
    :param axis: an axis along which softmax will be applied. Default: 1
    :param with_logits: whether to apply softmax first. Default: True
    :param label_smooth: a label smoothing parameter that re-distributes the target distribution. Default: 0
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )
        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.nn.cross_entropy(pred, label)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.6931

    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    ls = label_smooth

    if with_logits:
        logZ = logsumexp(pred, axis).mean()
        primary_term = indexing_one_hot(pred, label, axis).mean()
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis)).mean()
    if ls is None or (type(ls) in (int, float) and ls == 0):
        return logZ - primary_term
    if not with_logits:
        pred = log(pred)
    return logZ - ls * pred.mean() - (1 - ls) * primary_term


def binary_cross_entropy(
    pred: Tensor, label: Tensor, with_logits: bool = True
) -> Tensor:
    r"""
    Computes the binary cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits,
    and class probabilities are given by sigmoid.

    :param pred: `(N, *)`, where `*` means any number of additional dimensions.
    :param label: `(N, *)`, same shape as the input.
    :param with_logits: bool, whether to apply sigmoid first. Default: True
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
        label = tensor(np.ones((1, 2), dtype=np.float32))
        loss = F.nn.binary_cross_entropy(pred, label)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.6931

    """
    if not with_logits:
        return -(label * log(pred) + (1 - label) * log(1 - pred)).mean()
    # logsigmoid(pred) and logsigmoid(-pred) have a common sub-expression;
    # hopefully the backend will optimize this.
    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()


def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Calculates the hinge loss, which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij}))

    :param pred: input tensor representing the predicted probability, shape is `(N, C)`.
    :param label: input tensor representing the binary classification label, shape is `(N, C)`.
    :param norm: specify the norm to calculate the loss, should be "L1" or "L2".
    :return: loss value.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
        loss = F.nn.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        1.5

    """
    norm = norm.upper()
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # Labels are expected to be -1/1; the hinge term is max(0, 1 - pred * label).
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1).mean()
    else:
        return (loss ** 2).sum(axis=1).mean()
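The final branch of cross_entropy folds label smoothing into the loss as logZ - ls * pred.mean() - (1 - ls) * primary_term, i.e. the one-hot target is mixed with a uniform distribution over the K classes. A minimal sketch of the effect (input values chosen for illustration; the rounded outputs are hand-computed, not taken from the test suite):

import numpy as np
from megengine import tensor
import megengine.functional as F

# Two classes (K = 2); the target class 1 has the smaller logit.
pred = tensor(np.array([[2.0, 0.0]], dtype=np.float32))
label = tensor(np.array([1], dtype=np.int32))

hard = F.nn.cross_entropy(pred, label)                      # logZ - pred[label], ~2.1269
smooth = F.nn.cross_entropy(pred, label, label_smooth=0.1)  # adds 0.1 * pred.mean() weighting, ~2.0269
print(hard.numpy().round(4), smooth.numpy().round(4))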
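Similarly, hinge_loss with norm="L2" squares each hinge term before summing over classes, which penalizes large margin violations more heavily. Reusing the inputs from the docstring example above (the expected output 1.0 is hand-computed):

from megengine import tensor
import megengine.functional as F

pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")

# Per-element hinge terms: [[0.5, 0.5, 1.1], [0.4, 0.3, 0.2]];
# squared row sums are 1.71 and 0.29, whose mean is 1.0.
loss = F.nn.hinge_loss(pred, label, norm="L2")
print(loss.numpy())  # 1.0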

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware installed along with its driver. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
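Since one package serves both CPU and GPU, a runtime check is the easiest way to confirm the bundled CUDA environment can actually see a device. A minimal sketch (is_cuda_available is assumed from the public megengine API):

import megengine as mge

# MegEngine falls back to CPU execution when no usable GPU/driver is present.
if mge.is_cuda_available():
    print("CUDA device detected; GPU execution is available.")
else:
    print("No usable CUDA device; running on CPU.")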