
loss.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..tensor import Tensor
from .elemwise import abs, equal, exp, log, maximum, power, relu
from .nn import assert_equal, indexing_one_hot
from .tensor import where
from .utils import zero_grad


def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math:: \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left| x_n - y_n \right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.l1_loss(ipt, tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.75]

    """
    diff = pred - label
    return abs(diff).mean()


def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math:: \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left( x_n - y_n \right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``
    """
    diff = pred - label
    return (diff ** 2).mean()
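
# A minimal usage sketch for ``square_loss`` (illustrative, not part of the
# original file; the tensor values below are assumptions):
#
#     import numpy as np
#     import megengine as mge
#     import megengine.functional as F
#
#     pred = mge.tensor(np.array([3, 3, 3, 3], dtype=np.float32))
#     label = mge.tensor(np.array([2, 8, 6, 1], dtype=np.float32))
#     loss = F.square_loss(pred, label)
#     # squared errors are (1, 25, 9, 4), so the mean is 9.75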


def cross_entropy(
    inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    Returns the cross entropy loss in a classification problem.

    .. math:: \textrm{CrossEntropy}(x, y) = - \sum_{i} y_i \log(x_i)

    :param inp: The input tensor representing the predicted probability.
    :param target: The input tensor representing the classification label.
    :param axis: An axis along which cross_entropy will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )
        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.69]

    """
    raise NotImplementedError
    # n0 = inp.ndim
    # n1 = target.ndim
    # assert n0 == n1 + 1, (
    #     "target ndim must be one less than input ndim; input_ndim={} "
    #     "target_ndim={}".format(n0, n1)
    # )
    # if ignore_index != -1:
    #     mask = 1 - equal(target, ignore_index)
    #     target = target * mask
    #     loss = -log(indexing_one_hot(inp, target, axis)) * mask
    #     return loss.sum() / maximum(mask.sum(), 1.0)
    # else:
    #     return -log(indexing_one_hot(inp, target, axis)).mean()


def cross_entropy_with_softmax(
    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
    r"""
    Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.

    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k} = y_{k} \left( 1 - \alpha \right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and the original label
    distribution respectively. :math:`k` is the index of the label
    distribution, :math:`\alpha` is ``label_smooth``, and :math:`K` is the
    number of classes.

    :param pred: The input tensor representing the predicted probability.
    :param label: The input tensor representing the classification label.
    :param axis: An axis along which softmax will be applied. Default: 1
    :param label_smooth: The label smoothing parameter used to re-distribute the target distribution. Default: 0
    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    num_classes = pred.shape[axis]

    # Subtract the per-sample maximum before exponentiating for numerical
    # stability; keepdims=True keeps the offset broadcastable against pred.
    offset = pred.max(axis=axis, keepdims=True).detach()
    pred = pred - offset
    # Denominator of the softmax.
    down = exp(pred).sum(axis=axis)

    # Shifted logit of the target class for each sample (assumes axis == 1).
    up = pred[np.arange(pred.shape[0]), label]

    if label_smooth != 0:
        factor = label_smooth / num_classes
        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor

    return (log(down) - up).mean()
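
# A minimal usage sketch for ``cross_entropy_with_softmax`` (illustrative,
# not part of the original file). It reuses the numbers from the nll_loss
# docstring below, which checks this function against log-softmax + nll:
#
#     import numpy as np
#     from megengine import tensor
#     import megengine.functional as F
#
#     pred = tensor(np.array([[1.0, 0.5], [0.3, 1.2]], dtype=np.float32))
#     label = tensor(np.ones((2,), dtype=np.int32))
#     loss = F.cross_entropy_with_softmax(pred, label)
#     print(loss.numpy())  # ~[0.6576154]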


def triplet_margin_loss(
    anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
) -> Tensor:
    r"""
    Creates a criterion that measures the triplet loss given input tensors.

    .. math::

        L(a, p, n) = max\left\{d\left(a_{i}, p_{i}\right) - d\left(a_{i}, n_{i}\right) + margin, 0\right\}, \quad
        d\left(x_{i}, y_{i}\right) = \left\| x_{i} - y_{i} \right\|_{p}

    :param anchor: The input tensor representing the anchor samples.
    :param positive: The input tensor representing the positive samples.
    :param negative: The input tensor representing the negative samples.
    :param margin: The margin between the positive and negative distances. Default: 1.0
    :param p: The norm degree for pairwise distance. Default: 2
    """
    s0 = anchor.shapeof()
    s1 = positive.shapeof()
    s2 = negative.shapeof()
    assert_equal(s0, s1)
    assert_equal(s1, s2)

    n0 = anchor.ndim
    n1 = positive.ndim
    n2 = negative.ndim
    assert n0 == 2 and n1 == 2 and n2 == 2, (
        "anchor ndim, positive ndim, and negative ndim must be 2; "
        "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
    )
    assert p > 0, "norm degree p must be greater than 0; p={}".format(p)

    diff0 = abs(anchor - positive)
    diff1 = abs(anchor - negative)

    d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p)
    d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p)

    loss = maximum(d1 - d2 + margin, 0)

    return loss.mean()
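
# A minimal usage sketch for ``triplet_margin_loss`` (illustrative, not part
# of the original file; the sample values are assumptions):
#
#     import numpy as np
#     import megengine as mge
#     import megengine.functional as F
#
#     anchor = mge.tensor(np.array([[0.0, 0.0], [1.0, 1.0]], dtype=np.float32))
#     positive = mge.tensor(np.array([[0.0, 1.0], [1.0, 2.0]], dtype=np.float32))
#     negative = mge.tensor(np.array([[0.5, 0.5], [1.5, 1.5]], dtype=np.float32))
#     loss = F.triplet_margin_loss(anchor, positive, negative)
#     # per row, d(a, p) = 1 and d(a, n) ~ 0.7071, so each term is
#     # max(1 - 0.7071 + 1, 0) ~ 1.2929, and so is their mean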


def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Function that measures the Binary Cross Entropy between the target and the prediction.

    :param pred: :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
    :param label: :math:`(N, *)`, same shape as the input.
    """
    assert pred.shape == label.shape

    return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
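
# A minimal usage sketch for ``binary_cross_entropy`` (illustrative, not part
# of the original file; the probabilities below are assumptions):
#
#     import numpy as np
#     import megengine as mge
#     import megengine.functional as F
#
#     pred = mge.tensor(np.array([0.9, 0.2], dtype=np.float32))
#     label = mge.tensor(np.array([1.0, 0.0], dtype=np.float32))
#     loss = F.binary_cross_entropy(pred, label)
#     # -(log(0.9) + log(0.8)) / 2 ~ 0.1643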


def nll_loss(
    pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    The negative log likelihood loss.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.
    :param axis: An axis along which nll_loss will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (2, 2)
        label_shape = (2, )
        data = tensor(
            np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
        )
        label = tensor(
            np.ones(label_shape, dtype=np.int32)
        )
        pred = F.log(F.softmax(data))
        loss1 = F.nll_loss(pred, label)
        loss2 = F.cross_entropy_with_softmax(data, label)
        print(loss1.numpy(), loss2.numpy())

    Outputs:

    .. testoutput::

        [0.6576154] [0.6576154]

    """
    raise NotImplementedError
    # n0 = pred.ndim
    # n1 = label.ndim
    # assert n0 == n1 + 1, (
    #     "target ndim must be one less than input ndim; input_ndim={} "
    #     "target_ndim={}".format(n0, n1)
    # )
    # mask = 1.0 - equal(label, ignore_index)
    # label = label * mask
    # loss = indexing_one_hot(pred, label, axis) * mask
    # return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)


def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Calculates the hinge loss which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N} \sum_{i} \sum_{j} max(0, 1 - x_{ij} y_{ij})

    :param pred: The input tensor representing the predicted probability, shape is (N, C).
    :param label: The input tensor representing the binary classification label, shape is (N, C).
    :param norm: Specifies the norm used to calculate the loss; should be "L1" or "L2".

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
        loss = F.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [1.5]

    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # label is expected to hold -1/1 values; samples on the correct side of
    # the margin (pred * label >= 1) contribute zero loss.
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1).mean()
    else:
        return (loss ** 2).sum(axis=1).mean()
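
# Worked arithmetic for the docstring example above (not part of the original
# file): pred * label is [[0.5, 0.5, -0.1], [0.6, 0.7, 0.8]], so
# relu(1 - pred * label) is [[0.5, 0.5, 1.1], [0.4, 0.3, 0.2]]; the per-row
# L1 sums are 2.1 and 0.9, whose mean is the reported 1.5.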


def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculates the smooth l1 loss proposed in `Fast R-CNN paper by Ross Girshick`.

    The smooth l1 loss can be described as:

    .. math::

        \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i}

    where :math:`l_{i}` is given by:

    .. math::

        l_{i} =
        \begin{cases}
            0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
            |x_i - y_i| - 0.5, & \text{otherwise }
        \end{cases}

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])
        loss = F.smooth_l1_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.5608334]

    """
    raise NotImplementedError
    # diff = abs(pred - label)
    # l2_loss = 0.5 * (diff ** 2)
    # l1_loss = diff - 0.5
    # mask = diff < 1
    # loss = where(mask, l2_loss, l1_loss)
    # return loss.mean()
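
# A NumPy cross-check of the commented implementation above (illustrative,
# not part of the original file); it reproduces the docstring output:
#
#     import numpy as np
#
#     pred = np.array([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype=np.float32)
#     label = np.array([[0.4, 1.5, 1.2], [0.0, 0.1, 2.2]], dtype=np.float32)
#     diff = np.abs(pred - label)
#     loss = np.where(diff < 1, 0.5 * diff ** 2, diff - 0.5).mean()
#     print(loss)  # ~0.5608334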

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.