
loss.py
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine._internal as mgb

from ..core.tensor import Tensor
from .elemwise import abs, equal, log, maximum, power, relu
from .nn import assert_equal, indexing_one_hot
from .tensor import where
from .utils import zero_grad


def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left|x_n - y_n\right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.75]

    """
    diff = pred - label
    return abs(diff).mean()


def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left(x_n - y_n\right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``
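
    Examples:

    A minimal sketch reusing the :func:`~.l1_loss` inputs; the squared
    differences are 1, 25, 9 and 4, which average to 9.75:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.square_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [9.75]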
  62. """
  63. diff = pred - label
  64. return (diff ** 2).mean()
def cross_entropy(
    inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    Returns the cross entropy loss in a classification problem.

    .. math:: \textrm{CrossEntropy}(x, y) = -\sum_{i} y_i \log(x_i)

    :param inp: The input tensor representing the predicted probability.
    :param target: The input tensor representing the classification label.
    :param axis: An axis along which cross_entropy will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )

        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.69]

    """
    n0 = inp.ndim
    n1 = target.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    if ignore_index != -1:
        mask = 1 - equal(target, ignore_index)
        target = target * mask
        loss = -log(indexing_one_hot(inp, target, axis)) * mask
        return loss.sum() / maximum(mask.sum(), 1.0)
    else:
        return -log(indexing_one_hot(inp, target, axis)).mean()


def cross_entropy_with_softmax(
    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
    r"""
    Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.

    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and the original label
    distributions, respectively, :math:`k` is the index into the label
    distribution, :math:`\alpha` is ``label_smooth``, and :math:`K` is the
    number of classes.

    :param pred: The input tensor representing the predicted probability.
    :param label: The input tensor representing the classification label.
    :param axis: An axis along which softmax will be applied. Default: 1.
    :param label_smooth: A label smoothing parameter that re-distributes the target distribution. Default: 0.
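
    Examples:

    A minimal sketch: the softmax of two equal logits is uniform, so the loss
    for either class is :math:`\log 2 \approx 0.6931472` (up to float32
    rounding):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(np.array([[0.5, 0.5]], dtype=np.float32))
        label = tensor(np.ones((1, ), dtype=np.int32))

        loss = F.cross_entropy_with_softmax(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.6931472]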
  117. """
  118. n0 = pred.ndim
  119. n1 = label.ndim
  120. assert n0 == n1 + 1, (
  121. "target ndim must be one less than input ndim; input_ndim={} "
  122. "target_ndim={}".format(n0, n1)
  123. )
  124. num_classes = pred.shapeof(axis)
  125. # Denominator of the softmax
  126. offset = zero_grad(pred.max(axis=axis, keepdims=True))
  127. pred = pred - offset
  128. down = mgb.opr.elem.exp(pred).sum(axis=axis, keepdims=True)
  129. up = indexing_one_hot(pred, label, axis)
  130. if label_smooth != 0:
  131. factor = label_smooth / num_classes
  132. up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
  133. return (log(down) - up).mean()
def triplet_margin_loss(
    anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
) -> Tensor:
    r"""
    Creates a criterion that measures the triplet loss given the input tensors.

    .. math::

        L(a, p, n) = \max\left\{d\left(a_{i}, p_{i}\right) - d\left(a_{i}, n_{i}\right) + margin, 0\right\}, \quad
        d\left(x_{i}, y_{i}\right) = \left\|x_{i} - y_{i}\right\|_{p}

    :param anchor: The input tensor representing the anchor samples.
    :param positive: The input tensor representing the positive samples.
    :param negative: The input tensor representing the negative samples.
    :param margin: The margin value. Default: 1.0
    :param p: The norm degree for pairwise distance. Default: 2
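
    Examples:

    A deliberately simple sketch: the anchor-positive distance is exactly 1
    and the anchor-negative distance is exactly 0, so the loss is
    :math:`\max(1 - 0 + 1, 0) = 2`:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        anchor = tensor([[0., 0.]])
        positive = tensor([[1., 0.]])
        negative = tensor([[0., 0.]])

        loss = F.triplet_margin_loss(anchor, positive, negative)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.]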
  147. """
  148. s0 = anchor.shapeof()
  149. s1 = positive.shapeof()
  150. s2 = negative.shapeof()
  151. assert_equal(s0, s1)
  152. assert_equal(s1, s2)
  153. n0 = anchor.ndim
  154. n1 = positive.ndim
  155. n2 = negative.ndim
  156. assert n0 == 2 and n1 == 2 and n2 == 2, (
  157. "anchor ndim, positive ndim, and negative ndim must be 2; "
  158. "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
  159. )
  160. assert p > 0, "a margin with a value greater than 0; p={}".format(p)
  161. diff0 = abs(anchor - positive)
  162. diff1 = abs(anchor - negative)
  163. d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p)
  164. d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p)
  165. loss = maximum(d1 - d2 + margin, 0)
  166. return loss.mean()
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
    r"""Function that measures the Binary Cross Entropy between the target and the prediction.

    :param pred: :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
    :param label: :math:`(N, *)`, same shape as ``pred``.
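
    Examples:

    A minimal sketch: with uniform predictions of 0.5, every element
    contributes :math:`-\log 0.5`, so the mean is :math:`\log 2 \approx 0.6931472`
    (up to float32 rounding):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        pred = tensor(np.array([[0.5, 0.5]], dtype=np.float32))
        label = tensor(np.array([[1.0, 0.0]], dtype=np.float32))

        loss = F.binary_cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.6931472]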
  171. """
  172. s0 = pred.shapeof()
  173. s1 = label.shapeof()
  174. assert_equal(s0, s1)
  175. return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
def nll_loss(
    pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    The negative log likelihood loss.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.
    :param axis: An axis along which nll_loss will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (2, 2)
        label_shape = (2, )

        data = tensor(
            np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
        )
        label = tensor(
            np.ones(label_shape, dtype=np.int32)
        )
        pred = F.log(F.softmax(data))
        loss1 = F.nll_loss(pred, label)
        loss2 = F.cross_entropy_with_softmax(data, label)
        print(loss1.numpy(), loss2.numpy())

    Outputs:

    .. testoutput::

        [0.6576154] [0.6576154]

    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    mask = 1.0 - equal(label, ignore_index)
    label = label * mask
    loss = indexing_one_hot(pred, label, axis) * mask

    return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)


def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Calculates the hinge loss, which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j \max(0, 1 - x_{ij} y_{ij})

    :param pred: The input tensor representing the predicted probability, shape is (N, C).
    :param label: The input tensor representing the binary classification label, shape is (N, C).
    :param norm: Specify the norm to calculate the loss, should be "L1" or "L2".

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[1, -1, -1], [-1, 1, 1]])

        loss = F.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [1.5]

    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # Labels are expected to be -1 or 1; the per-element loss is max(0, 1 - pred * label).
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1).mean()
    else:
        return (loss ** 2).sum(axis=1).mean()


def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the smooth L1 loss proposed in the `Fast R-CNN paper by Ross Girshick`.

    The smooth L1 loss can be described as:

    .. math::

        \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i}

    where :math:`l_{i}` is given by:

    .. math::

        l_{i} =
        \begin{cases}
        0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
        |x_i - y_i| - 0.5, & \text{otherwise}
        \end{cases}

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])

        loss = F.smooth_l1_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.5608334]

    """
    diff = abs(pred - label)
    l2_loss = 0.5 * (diff ** 2)
    l1_loss = diff - 0.5
    mask = diff < 1
    loss = where(mask, l2_loss, l1_loss)
    return loss.mean()

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose from. To run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit MegStudio.
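
As a quick sanity check before running GPU programs, you can ask MegEngine how many GPUs it sees. This is a minimal sketch; it assumes `megengine.get_device_count("gpu")` is available in your installed version:

import megengine as mge

# Assumption: get_device_count("gpu") reports the number of visible CUDA devices.
num_gpus = mge.get_device_count("gpu")
if num_gpus > 0:
    print("CUDA is available: {} GPU(s) detected".format(num_gpus))
else:
    print("No GPU detected; computation will run on CPU")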