
loss.py 9.1 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine._internal as mgb

from ..core.tensor import Tensor
from .elemwise import abs, equal, log, maximum, power
from .nn import assert_equal, indexing_one_hot
from .utils import zero_grad

def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math::

        \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left|x_n - y_n\right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.75]

    """
    diff = pred - label
    return abs(diff).mean()

def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math::

        \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left(x_n - y_n\right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``
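
    Examples:

    A minimal illustrative example (the input values below are arbitrary):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        # element-wise differences are [1, -5, -3, 2],
        # so the squared errors are [1, 25, 9, 4]
        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.square_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [9.75]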
  67. """
  68. diff = pred - label
  69. return (diff ** 2).mean()
def cross_entropy(
    inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""Returns the cross entropy loss in a classification problem.

    .. math::

        \textrm{CrossEntropy}(x, y) = -\sum_{i} y_i \log(x_i)

    :param inp: The input tensor representing the predicted probability.
    :param target: The input tensor representing the classification label.
    :param axis: An axis along which cross_entropy will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )

        pred = tensor(
            np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape)
        )
        label = tensor(
            np.ones(label_shape, dtype=np.int32)
        )
        loss = F.cross_entropy(pred, label)
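
    An illustrative sketch of ``ignore_index`` (the values below are
    arbitrary): samples whose label equals ``ignore_index`` are masked
    out of the reduction.

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(
            np.array([[0.3, 0.7], [0.6, 0.4]], dtype=np.float32)
        )
        # the second sample's label equals ignore_index, so only the
        # first sample contributes: the loss is -log(0.7)
        label = tensor(np.array([1, 2], dtype=np.int32))
        loss = F.cross_entropy(pred, label, ignore_index=2)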
  94. """
  95. n0 = inp.ndim
  96. n1 = target.ndim
  97. assert n0 == n1 + 1, (
  98. "target ndim must be one less than input ndim; input_ndim={} "
  99. "target_ndim={}".format(n0, n1)
  100. )
  101. if ignore_index != -1:
  102. mask = 1 - equal(target, ignore_index)
  103. target = target * mask
  104. loss = -log(indexing_one_hot(inp, target, axis)) * mask
  105. return loss.sum() / maximum(mask.sum(), 1.0)
  106. else:
  107. return -log(indexing_one_hot(inp, target, axis)).mean()
def cross_entropy_with_softmax(
    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
    r"""
    Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.

    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math::

        y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and the original label
    distributions respectively, :math:`k` is the index into the label
    distribution, :math:`\alpha` is ``label_smooth``, and :math:`K` is the
    number of classes.

    :param pred: The input tensor representing the predicted probability.
    :param label: The input tensor representing the classification label.
    :param axis: An axis along which softmax will be applied. Default: 1.
    :param label_smooth: A label smoothing parameter that re-distributes the target distribution. Default: 0.
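
    Examples:

    A minimal illustrative example (the values below are arbitrary):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([[1.0, 0.5]], dtype=np.float32))
        label = tensor(np.ones((1,), dtype=np.int32))

        loss = F.cross_entropy_with_softmax(data, label)
        # with label smoothing, a fraction alpha of the probability mass
        # is spread uniformly over all K classes
        loss_smoothed = F.cross_entropy_with_softmax(data, label, label_smooth=0.1)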
  123. """
  124. n0 = pred.ndim
  125. n1 = label.ndim
  126. assert n0 == n1 + 1, (
  127. "target ndim must be one less than input ndim; input_ndim={} "
  128. "target_ndim={}".format(n0, n1)
  129. )
  130. num_classes = pred.shapeof(axis)
  131. # Denominator of the softmax
  132. offset = zero_grad(pred.max(axis=axis, keepdims=True))
  133. pred = pred - offset
  134. down = mgb.opr.elem.exp(pred).sum(axis=axis, keepdims=True)
  135. up = indexing_one_hot(pred, label, axis)
  136. if label_smooth != 0:
  137. factor = label_smooth / num_classes
  138. up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
  139. return (log(down) - up).mean()
def triplet_margin_loss(
    anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
) -> Tensor:
    r"""
    Creates a criterion that measures the triplet loss given the input tensors.

    .. math::

        L(a, p, n) = \max\left\{d\left(a_{i}, p_{i}\right) - d\left(a_{i}, n_{i}\right) + margin, 0\right\}, \\
        d\left(x_{i}, y_{i}\right) = \left\|x_{i} - y_{i}\right\|_{p}

    :param anchor: The input tensor representing the anchor samples.
    :param positive: The input tensor representing the positive samples.
    :param negative: The input tensor representing the negative samples.
    :param margin: The margin between the positive and negative distances. Default: 1.0
    :param p: The norm degree for pairwise distance. Default: 2.0
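
    Examples:

    A minimal illustrative example (the values below are arbitrary):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        # d(anchor, positive) = 3 and d(anchor, negative) = 1,
        # so the loss is max(3 - 1 + 1, 0) = 3
        anchor = tensor(np.array([[0.0, 0.0]], dtype=np.float32))
        positive = tensor(np.array([[3.0, 0.0]], dtype=np.float32))
        negative = tensor(np.array([[1.0, 0.0]], dtype=np.float32))
        loss = F.triplet_margin_loss(anchor, positive, negative)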
  153. """
  154. s0 = anchor.shapeof()
  155. s1 = positive.shapeof()
  156. s2 = negative.shapeof()
  157. assert_equal(s0, s1)
  158. assert_equal(s1, s2)
  159. n0 = anchor.ndim
  160. n1 = positive.ndim
  161. n2 = negative.ndim
  162. assert n0 == 2 and n1 == 2 and n2 == 2, (
  163. "anchor ndim, positive ndim, and negative ndim must be 2; "
  164. "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
  165. )
  166. assert p > 0, "a margin with a value greater than 0; p={}".format(p)
  167. diff0 = abs(anchor - positive)
  168. diff1 = abs(anchor - negative)
  169. d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p)
  170. d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p)
  171. loss = maximum(d1 - d2 + margin, 0)
  172. return loss.mean()
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
    r"""Function that measures the Binary Cross Entropy between the target and the prediction.

    :param pred: :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
    :param label: :math:`(N, *)`, same shape as the input.
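
    Examples:

    A minimal illustrative example (the values below are arbitrary):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        # both elements contribute -log(0.9), so the mean is about 0.105
        pred = tensor(np.array([0.9, 0.1], dtype=np.float32))
        label = tensor(np.array([1.0, 0.0], dtype=np.float32))
        loss = F.binary_cross_entropy(pred, label)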
  177. """
  178. s0 = pred.shapeof()
  179. s1 = label.shapeof()
  180. assert_equal(s0, s1)
  181. return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
def nll_loss(
    pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    The negative log likelihood loss.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`, with one dimension fewer than ``pred``

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        from megengine.test.utils import assertTensorClose

        data_shape = (2, 2)
        label_shape = (2, )

        data = tensor(
            np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
        )
        label = tensor(
            np.ones(label_shape, dtype=np.int32)
        )

        pred = F.log(F.softmax(data))
        loss1 = F.nll_loss(pred, label)
        loss2 = F.cross_entropy_with_softmax(data, label)
        assertTensorClose(loss1.numpy(), loss2.numpy(), max_err=5e-6)

    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    mask = 1.0 - equal(label, ignore_index)
    label = label * mask

    loss = indexing_one_hot(pred, label, axis) * mask

    return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
