
loss.py 9.1 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools

import numpy as np

from ..core.tensor.array_method import _reduce
from ..tensor import Tensor
from .elemwise import abs, log
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
from .tensor import where

__all__ = [
    "l1_loss",
    "square_loss",
    "cross_entropy",
    "binary_cross_entropy",
    "hinge_loss",
]

def _reduce_output(loss_fn):
    r"""
    Wrapper to apply canonical reductions to loss outputs.
    """

    @functools.wraps(loss_fn)
    def reduced_loss_fn(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)
        if reduction == "none":
            return loss
        elif reduction in ("mean", "sum"):
            return _reduce(reduction)(loss)
        else:
            raise ValueError("{} is not a valid value for reduction".format(reduction))

    return reduced_loss_fn

@_reduce_output
def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left| x_n - y_n \right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: predicted result from model.
    :param label: ground truth to compare.
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.nn.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        2.75

    """
    diff = pred - label
    return abs(diff)

@_reduce_output
def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left( x_n - y_n \right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: predicted result from model.
    :param label: ground truth to compare.
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        - label: :math:`(N, *)`. Same shape as ``pred``.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.nn.square_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        9.75

    """
    diff = pred - label
    return diff ** 2

@_reduce_output
def cross_entropy(
    pred: Tensor,
    label: Tensor,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""
    Computes the multi-class cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits,
    and class probabilities are given by softmax. This has better numerical
    stability than sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and original label
    distributions respectively, :math:`k` is the index into the label
    distribution, :math:`\alpha` is ``label_smooth`` and :math:`K` is the
    number of classes.

    :param pred: input tensor representing the predicted logits
        (or probabilities when ``with_logits`` is False).
    :param label: input tensor representing the classification label.
    :param axis: an axis along which softmax will be applied. Default: 1
    :param with_logits: whether to apply softmax first. Default: True
    :param label_smooth: a label smoothing parameter that re-distributes the
        target distribution. Default: 0
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )
        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.nn.cross_entropy(pred, label)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.6931

    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    ls = label_smooth

    if with_logits:
        logZ = logsumexp(pred, axis)
        primary_term = indexing_one_hot(pred, label, axis)
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis))
    if ls is None or (type(ls) in (int, float) and ls == 0):
        return logZ - primary_term
    if not with_logits:
        pred = log(pred)
    return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term

@_reduce_output
def binary_cross_entropy(
    pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
) -> Tensor:
    r"""
    Computes the binary cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits,
    and class probabilities are given by sigmoid.

    :param pred: `(N, *)`, where `*` means any number of additional dimensions.
    :param label: `(N, *)`, same shape as the input.
    :param with_logits: bool, whether to apply sigmoid first. Default: True
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
        label = tensor(np.ones((1, 2), dtype=np.float32))
        loss = F.nn.binary_cross_entropy(pred, label)
        print(loss.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        0.6931

    """
    if not with_logits:
        return -(label * log(pred) + (1 - label) * log(1 - pred))
    # logsigmoid(pred) and logsigmoid(-pred) share a common sub-expression;
    # hopefully the backend will optimize this.
    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))

@_reduce_output
def hinge_loss(
    pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
) -> Tensor:
    r"""
    Calculates the hinge loss, which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j\max\left(0, 1 - x_{ij} y_{ij}\right)

    :param pred: input tensor representing the predicted score, shape is `(N, C)`.
    :param label: input tensor representing the binary classification label
        (-1 or 1 per element), shape is `(N, C)`.
    :param norm: the norm used to calculate the loss, should be "L1" or "L2".
    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
    :return: loss value.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
        loss = F.nn.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        1.5

    """
    norm = norm.upper()
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # Element-wise hinge term: max(0, 1 - pred * label), with -1/1 labels.
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1)
    else:
        return (loss ** 2).sum(axis=1)
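
As a quick end-to-end check of the API above, here is a hedged sketch (assuming a standard MegEngine install, where these functions are exposed as megengine.functional.nn.*) exercising the losses together with the reduction modes dispatched by _reduce_output:

import numpy as np
import megengine as mge
import megengine.functional as F

pred = mge.tensor(np.array([[2.0, -1.0, 3.0]], dtype=np.float32))
label = mge.tensor(np.array([[1.0, 0.0, 2.0]], dtype=np.float32))

# The three canonical reductions handled by _reduce_output.
print(F.nn.l1_loss(pred, label).numpy())                       # mean(|pred - label|) -> 1.0
print(F.nn.square_loss(pred, label, reduction="sum").numpy())  # sum((pred - label)**2) -> 3.0
print(F.nn.l1_loss(pred, label, reduction="none").numpy())     # per-element: [[1. 1. 1.]]

# cross_entropy consumes logits plus integer class indices along `axis`.
logits = mge.tensor(np.array([[0.0, 0.0]], dtype=np.float32))
cls = mge.tensor(np.array([1], dtype=np.int32))
print(F.nn.cross_entropy(logits, cls).numpy())                 # log(2) ~= 0.6931
print(F.nn.cross_entropy(logits, cls, label_smooth=0.1).numpy())  # smoothed per the docstring formula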

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you want to try deep-learning development on a cloud GPU platform, check out MegStudio.
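
A minimal sketch to verify locally that the bundled CUDA runtime can see a GPU (is_cuda_available is part of MegEngine's public API):

import megengine as mge

# True when the bundled CUDA runtime detects a usable GPU; with the default
# "xpux" device, MegEngine picks the GPU if available and falls back to CPU.
print("CUDA available:", mge.is_cuda_available())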