
loss.py 11 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..core.tensor.utils import make_shape_tuple
from ..tensor import Tensor
from .elemwise import abs, eq, exp, log, maximum, pow, relu
from .nn import assert_equal, indexing_one_hot
from .tensor import where
from .utils import zero_grad


def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left|x_n - y_n\right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.75]

    """
    diff = pred - label
    return abs(diff).mean()


def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math:: \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left(x_n - y_n\right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``
    """
    diff = pred - label
    return (diff ** 2).mean()
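

# A minimal usage sketch for square_loss, mirroring the l1_loss docstring
# example above (the sample values are illustrative only):
#
#   import numpy as np
#   import megengine as mge
#   import megengine.functional as F
#
#   ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
#   tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
#   loss = F.square_loss(ipt, tgt)
#   print(loss.numpy())  # mean of [1, 25, 9, 4] -> [9.75]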


def cross_entropy(
    inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    Returns the cross entropy loss in a classification problem.

    .. math:: \textrm{CrossEntropy}(x, y) = -\sum_{i} y_i\log(x_i)

    :param inp: The input tensor representing the predicted probability.
    :param target: The input tensor representing the classification label.
    :param axis: An axis along which cross_entropy will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )
        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.69]

    """
    raise NotImplementedError
    # n0 = inp.ndim
    # n1 = target.ndim
    # assert n0 == n1 + 1, (
    #     "target ndim must be one less than input ndim; input_ndim={} "
    #     "target_ndim={}".format(n0, n1)
    # )
    # if ignore_index != -1:
    #     mask = 1 - eq(target, ignore_index)
    #     target = target * mask
    #     loss = -log(indexing_one_hot(inp, target, axis)) * mask
    #     return loss.sum() / maximum(mask.sum(), 1.0)
    # else:
    #     return -log(indexing_one_hot(inp, target, axis)).mean()


def cross_entropy_with_softmax(
    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
    r"""
    Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.

    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and the original label
    distribution respectively, :math:`k` is the index into the label
    distribution, :math:`\alpha` is ``label_smooth``, and :math:`K` is the
    number of classes.

    :param pred: The input tensor representing the predicted probability.
    :param label: The input tensor representing the classification label.
    :param axis: An axis along which softmax will be applied. Default: 1
    :param label_smooth: A label-smoothing parameter that re-distributes the target distribution. Default: 0
    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )
    num_classes = pred.shape[axis]
    # Denominator of the softmax
    offset = pred.max(axis=axis).detach()
    pred = pred - offset
    down = exp(pred).sum(axis=axis)
    up = indexing_one_hot(pred, label, axis)
    if label_smooth != 0:
        factor = label_smooth / num_classes
        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor
    return (log(down) - up).mean()
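

# A minimal usage sketch for cross_entropy_with_softmax, reusing the data from
# the nll_loss docstring example below; both paths produce the same value:
#
#   import numpy as np
#   from megengine import tensor
#   import megengine.functional as F
#
#   data = tensor(np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32))
#   label = tensor(np.ones((2,), dtype=np.int32))
#   loss = F.cross_entropy_with_softmax(data, label)
#   print(loss.numpy())  # [0.6576154]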


def triplet_margin_loss(
    anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
) -> Tensor:
    r"""
    Creates a criterion that measures the triplet loss given the input tensors.

    .. math::

        L(a, p, n) = \max\left\{d\left(a_{i}, p_{i}\right) - d\left(a_{i}, n_{i}\right) + margin, 0\right\}, \quad
        d\left(x_{i}, y_{i}\right) = \left\|x_{i} - y_{i}\right\|_{p}

    :param anchor: The input tensor representing the anchor samples.
    :param positive: The input tensor representing the positive samples.
    :param negative: The input tensor representing the negative samples.
    :param margin: The margin between the positive and negative distances. Default: 1.0
    :param p: The norm degree for pairwise distance. Default: 2
    """
    s0 = anchor.shapeof()
    s1 = positive.shapeof()
    s2 = negative.shapeof()
    assert_equal(s0, s1)
    assert_equal(s1, s2)
    n0 = anchor.ndim
    n1 = positive.ndim
    n2 = negative.ndim
    assert n0 == 2 and n1 == 2 and n2 == 2, (
        "anchor ndim, positive ndim, and negative ndim must be 2; "
        "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
    )
    assert p > 0, "p must be greater than 0; p={}".format(p)
    diff0 = abs(anchor - positive)
    diff1 = abs(anchor - negative)
    d1 = pow(pow(diff0, p).sum(axis=1, keepdims=True), 1 / p)
    d2 = pow(pow(diff1, p).sum(axis=1, keepdims=True), 1 / p)
    loss = maximum(d1 - d2 + margin, 0)
    return loss.mean()
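

# A minimal usage sketch for triplet_margin_loss with the default margin=1.0
# and p=2 (the sample values are illustrative only):
#
#   import numpy as np
#   from megengine import tensor
#   import megengine.functional as F
#
#   anchor = tensor(np.array([[0.0, 0.0]], dtype=np.float32))
#   positive = tensor(np.array([[0.0, 1.0]], dtype=np.float32))
#   negative = tensor(np.array([[3.0, 4.0]], dtype=np.float32))
#   loss = F.triplet_margin_loss(anchor, positive, negative)
#   print(loss.numpy())  # max(d(a, p) - d(a, n) + 1, 0) = max(1 - 5 + 1, 0) -> [0.]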


def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
    r"""Function that measures the binary cross entropy between the label and the prediction.

    :param pred: :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
    :param label: :math:`(N, *)`, same shape as the input.
    """
    assert make_shape_tuple(pred.shape) == make_shape_tuple(label.shape)
    return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
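

# A minimal usage sketch for binary_cross_entropy; pred is assumed to hold
# probabilities in (0, 1), e.g. the output of a sigmoid:
#
#   import numpy as np
#   from megengine import tensor
#   import megengine.functional as F
#
#   pred = tensor(np.array([[0.9, 0.2]], dtype=np.float32))
#   label = tensor(np.array([[1.0, 0.0]], dtype=np.float32))
#   loss = F.binary_cross_entropy(pred, label)
#   print(loss.numpy())  # mean of [-log(0.9), -log(0.8)] ~= [0.16425]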


def nll_loss(
    pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    The negative log likelihood loss.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (2, 2)
        label_shape = (2, )
        data = tensor(
            np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
        )
        label = tensor(
            np.ones(label_shape, dtype=np.int32)
        )
        pred = F.log(F.softmax(data))
        loss1 = F.nll_loss(pred, label)
        loss2 = F.cross_entropy_with_softmax(data, label)
        print(loss1.numpy(), loss2.numpy())

    Outputs:

    .. testoutput::

        [0.6576154] [0.6576154]

    """
    raise NotImplementedError
    # n0 = pred.ndim
    # n1 = label.ndim
    # assert n0 == n1 + 1, (
    #     "target ndim must be one less than input ndim; input_ndim={} "
    #     "target_ndim={}".format(n0, n1)
    # )
    # mask = 1.0 - eq(label, ignore_index)
    # label = label * mask
    # loss = indexing_one_hot(pred, label, axis) * mask
    # return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)


def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Calculates the hinge loss, which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j \max(0, 1 - x_{ij} y_{ij})

    :param pred: The input tensor representing the predicted probability, shape is (N, C).
    :param label: The input tensor representing the binary classification label, shape is (N, C).
    :param norm: Specify the norm to calculate the loss, should be "L1" or "L2".

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
        loss = F.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [1.5]

    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # label is expected to already use -1/1 values.
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1).mean()
    else:
        return (loss ** 2).sum(axis=1).mean()


def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the smooth L1 loss proposed in the `Fast R-CNN paper by Ross Girshick`.

    The smooth L1 loss can be described as:

    .. math::

        \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i}

    where :math:`l_{i}` is given by:

    .. math::

        l_{i} =
        \begin{cases}
        0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
        |x_i - y_i| - 0.5, & \text{otherwise}
        \end{cases}

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])
        loss = F.smooth_l1_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.5608334]

    """
    raise NotImplementedError
    # diff = abs(pred - label)
    # l2_loss = 0.5 * (diff ** 2)
    # l1_loss = diff - 0.5
    # mask = diff < 1
    # loss = where(mask, l2_loss, l1_loss)
    # return loss.mean()

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no need to distinguish between CPU and GPU builds. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.