
loss.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools

import numpy as np

from ..core.tensor.array_method import _reduce
from ..tensor import Tensor
from .elemwise import abs, equal, log, logaddexp, maximum
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
from .tensor import broadcast_to, cumsum, linspace, ones, where, zeros

__all__ = [
    "l1_loss",
    "square_loss",
    "cross_entropy",
    "binary_cross_entropy",
    "hinge_loss",
    "ctc_loss",
]


def _reduce_output(loss_fn):
    r"""Wrapper to apply canonical reductions to loss outputs."""

    @functools.wraps(loss_fn)
    def reduced_loss_fn(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)
        if reduction == "none":
            return loss
        elif reduction in ("mean", "sum"):
            return _reduce(reduction)(loss)
        else:
            raise ValueError("{} is not a valid value for reduction".format(reduction))

    return reduced_loss_fn
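
# Note: every loss decorated with ``_reduce_output`` gains a ``reduction``
# keyword argument; "none" returns the unreduced elementwise loss, while
# "mean" and "sum" are dispatched through ``_reduce``.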


@_reduce_output
def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math::

        \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left|x_n - y_n\right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    Args:
        pred: predicted result from model.
        label: ground truth to compare.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Shape:
        * ``pred``: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        * ``label``: :math:`(N, *)`. Same shape as ``pred``.

    Examples:
        >>> pred = Tensor([3, 3, 3, 3])
        >>> label = Tensor([2, 8, 6, 1])
        >>> F.nn.l1_loss(pred, label)
        Tensor(2.75, device=xpux:0)
        >>> F.nn.l1_loss(pred, label, reduction="none")
        Tensor([1 5 3 2], dtype=int32, device=xpux:0)
        >>> F.nn.l1_loss(pred, label, reduction="sum")
        Tensor(11, dtype=int32, device=xpux:0)
    """
    diff = pred - label
    return abs(diff)


@_reduce_output
def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math::

        \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left(x_n - y_n\right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    Args:
        pred: predicted result from model.
        label: ground truth to compare.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Shape:
        * ``pred``: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        * ``label``: :math:`(N, *)`. Same shape as ``pred``.

    Examples:
        >>> pred = Tensor([3, 3, 3, 3])
        >>> label = Tensor([2, 8, 6, 1])
        >>> F.nn.square_loss(pred, label)
        Tensor(9.75, device=xpux:0)
        >>> F.nn.square_loss(pred, label, reduction="none")
        Tensor([ 1. 25.  9.  4.], device=xpux:0)
        >>> F.nn.square_loss(pred, label, reduction="sum")
        Tensor(39.0, device=xpux:0)
    """
    diff = pred - label
    return diff ** 2


@_reduce_output
def cross_entropy(
    pred: Tensor,
    label: Tensor,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""Computes the multi-class cross entropy loss (using logits by default).

    When using label smoothing, the label distribution is as follows:

    .. math::

        y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new label distribution and the original
    label distribution respectively. :math:`k` is the index of the label distribution.
    :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes.

    Args:
        pred: input tensor representing the predicted value.
        label: input tensor representing the classification label.
        axis: an axis along which softmax will be applied. Default: 1
        with_logits: whether to apply softmax first. Default: True
        label_smooth: a label smoothing parameter that re-distributes the target distribution. Default: 0
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Examples:
        By default (``with_logits`` is True), ``pred`` is assumed to be logits and
        class probabilities are given by softmax.
        This has better numerical stability compared with sequential calls to
        :func:`~.softmax` and :func:`~.cross_entropy`.

        >>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.cross_entropy(pred, label)  # doctest: +SKIP
        Tensor(0.57976407, device=xpux:0)
        >>> F.nn.cross_entropy(pred, label, reduction="none")
        Tensor([0.3133 0.513  0.913 ], device=xpux:0)

        If ``pred`` already contains probabilities, set ``with_logits`` to False:

        >>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.cross_entropy(pred, label, with_logits=False)  # doctest: +SKIP
        Tensor(0.5202159, device=xpux:0)
        >>> F.nn.cross_entropy(pred, label, with_logits=False, reduction="none")
        Tensor([0.     0.3567 1.204 ], device=xpux:0)
    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )
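
    # When ``with_logits`` is True the loss is computed as
    # logsumexp(pred) - pred[label], the numerically stable form of
    # -log(softmax(pred)[label]).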
    ls = label_smooth

    if with_logits:
        logZ = logsumexp(pred, axis)
        primary_term = indexing_one_hot(pred, label, axis)
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis))
    if ls is None or type(ls) in (int, float) and ls == 0:
        return logZ - primary_term
    if not with_logits:
        pred = log(pred)
    return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term


@_reduce_output
def binary_cross_entropy(
    pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
) -> Tensor:
    r"""Computes the binary cross entropy loss (using logits by default).

    Args:
        pred: `(N, *)`, where `*` means any number of additional dimensions.
        label: `(N, *)`, same shape as the input.
        with_logits: bool, whether to apply sigmoid first. Default: True
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Examples:
        By default (``with_logits`` is True), ``pred`` is assumed to be logits and
        class probabilities are given by sigmoid.
        This has better numerical stability compared with sequential calls to
        :func:`~.sigmoid` and :func:`~.binary_cross_entropy`.

        >>> pred = Tensor([0.9, 0.7, 0.3])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.binary_cross_entropy(pred, label)
        Tensor(0.4328984, device=xpux:0)
        >>> F.nn.binary_cross_entropy(pred, label, reduction="none")
        Tensor([0.3412 0.4032 0.5544], device=xpux:0)

        If ``pred`` already contains probabilities, set ``with_logits`` to False:

        >>> pred = Tensor([0.9, 0.7, 0.3])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.binary_cross_entropy(pred, label, with_logits=False)
        Tensor(0.5553361, device=xpux:0)
        >>> F.nn.binary_cross_entropy(pred, label, with_logits=False, reduction="none")
        Tensor([0.1054 0.3567 1.204 ], device=xpux:0)
    """
    if not with_logits:
        return -(label * log(pred) + (1 - label) * log(1 - pred))
    # logsigmoid(pred) and logsigmoid(-pred) share a common sub-expression;
    # hopefully the backend will optimize this.
    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))


@_reduce_output
def hinge_loss(
    pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
) -> Tensor:
    r"""Calculates the hinge loss which is often used in SVM.

    The hinge loss can be described as:

    .. math::

        loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij} * y_{ij}))

    Args:
        pred: input tensor representing the predicted probability, shape is `(N, C)`.
        label: input tensor representing the binary classification label, shape is `(N, C)`.
        norm: specify the norm to calculate the loss, should be "L1" or "L2".
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:
        >>> pred = Tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        >>> label = Tensor([[1, -1, -1], [-1, 1, 1]])
        >>> F.nn.hinge_loss(pred, label)
        Tensor(1.5, device=xpux:0)
        >>> F.nn.hinge_loss(pred, label, reduction="none")
        Tensor([2.1 0.9], device=xpux:0)
        >>> F.nn.hinge_loss(pred, label, reduction="sum")
        Tensor(3.0, device=xpux:0)
    """
    norm = norm.upper()
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # ``label`` is expected to contain -1/1 binary labels.
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1)
    else:
        return (loss ** 2).sum(axis=1)


def _gen_repeat_idx(inp: Tensor):
    idx = cumsum(inp, axis=0)
    ret = zeros(inp.sum(), dtype="int32")
    ret[idx[:-1]] = 1
    return cumsum(ret, axis=0)


def _gen_tile_idx(inp: Tensor):
    idx = cumsum(inp, axis=0)
    ret = ones(inp.sum(), dtype="int32")
    ret[idx[:-1]] = -(inp - 1)[:-1]
    return cumsum(ret, axis=0) - 1
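
# A worked illustration (not part of the original code): with
# label_lengths = [2, 3], _gen_repeat_idx yields [0, 0, 1, 1, 1] (the batch index
# of every concatenated label symbol) and _gen_tile_idx yields [0, 1, 0, 1, 2]
# (each symbol's position within its own sequence); together they scatter a
# concatenated 1-D label into an (N, max_length) matrix.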


def _expand_label(label: Tensor, label_lengths: Tensor, blank: int) -> Tensor:
    N = label_lengths.shape[0]
    if len(label.shape) == 1:
        L = label_lengths.max()
        unpack_label = zeros((N, L), dtype="int32") + blank
        idx_0 = _gen_repeat_idx(label_lengths)
        idx_1 = _gen_tile_idx(label_lengths)
        unpack_label[idx_0, idx_1] = label
        label = unpack_label
    L = label.shape[1]
    ex_label = zeros((N, L * 2 + 1), dtype="int32") + blank
    ex_label[:, 1::2] = label
    return ex_label
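
# A small sketch of what _expand_label produces: with label = [1, 2],
# label_lengths = [2] and blank = 0, the expanded label is [[0, 1, 0, 2, 0]],
# i.e. the CTC label sequence with a blank inserted before, between and after
# the groundtruth symbols.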


def _safelog(x: Tensor) -> Tensor:
    eps = np.finfo(x.dtype).tiny
    return log(maximum(x, eps))


def ctc_loss(
    pred: Tensor,
    pred_lengths: Tensor,
    label: Tensor,
    label_lengths: Tensor,
    blank: int = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""The Connectionist Temporal Classification loss.

    Args:
        pred: the probabilities of the output, shape is (T, N, C),
            where T=input length, N=batch size, and C=number of classes (including blank).
        pred_lengths: number of time steps for each sequence in ``pred``, shape is (N, )
        label: groundtruth labels, containing the indices of groundtruth
            symbols for each sequence at each output time step, and the blank
            symbol should not be included. shape is (N, S) or (sum(label_lengths)).
        label_lengths: number of time steps for each sequence in the groundtruth, shape is (N, )
        blank: the blank symbol number, default 0
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:
        >>> pred = Tensor([[[0.0614, 0.9386], [0.8812, 0.1188]], [[0.699, 0.301], [0.2572, 0.7428]]])
        >>> pred_lengths = Tensor([2, 2])
        >>> label = Tensor([1, 1])
        >>> label_lengths = Tensor([1, 1])
        >>> F.nn.ctc_loss(pred, pred_lengths, label, label_lengths)
        Tensor(0.1504417, device=xpux:0)
    """
    T, N, C = pred.shape

    assert (
        pred_lengths.size == N
    ), "pred_lengths must be equal to batch_size {}, but got {}".format(
        N, pred_lengths.size
    )
    assert (
        label_lengths.size == N
    ), "label_lengths must be equal to batch_size {}, but got {}".format(
        N, label_lengths.size
    )
    assert (
        blank >= 0 and blank < C
    ), "blank must be in label range [0, {}), but got {}".format(C, blank)
    assert (
        pred_lengths.min() > 0 and pred_lengths.max() <= T
    ), "pred_lengths must be in range ({}, {}], but got min {}, max {}".format(
        0, T, pred_lengths.min(), pred_lengths.max()
    )

    if label.ndim == 1:  # concatenated label
        assert label_lengths.min() > 0, "label lengths must be positive"
        assert (
            label.size == label_lengths.sum()
        ), "label size must be equal to sum(label_lengths)"
    else:
        N, S = label.shape
        assert (
            label_lengths.min() > 0 and label_lengths.max() <= S
        ), "label_lengths must be in range ({}, {}], but got min {}, max {}".format(
            0, S, label_lengths.min(), label_lengths.max()
        )

    label = _expand_label(label, label_lengths, blank)
    label_mask = label[:, 2:] != label[:, :-2]
    L = label.shape[1]

    pred = pred.transpose(1, 0, 2)  # (T, N, C) -> (N, T, C)
    batch_idx = linspace(0, N - 1, N).astype("int32").reshape(-1)
    batch_idx_NL = broadcast_to(batch_idx.reshape(N, 1), (N, L)).reshape(-1)
    match_pred = pred[batch_idx_NL, :, label.reshape(-1)].reshape(
        N, L, -1
    )  # (N, T, C) -> (N, L, T)
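
    # match_pred[n, l, t] is the probability that sample n emits the l-th symbol
    # of its expanded (blank-interleaved) label at time step t.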
    log_alpha = zeros((N, L), dtype="float32")
    log_alpha[:, :2] = match_pred[:, :2, 0]
    log_alpha = _safelog(log_alpha)

    ret = -logaddexp(
        log_alpha[batch_idx, label_lengths * 2],
        log_alpha[batch_idx, label_lengths * 2 - 1],
    ) * equal(pred_lengths - 1, 0)
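
    # Standard CTC forward (alpha) recursion in log space: at each time step a
    # path may stay on the same expanded-label position, advance by one, or skip
    # over a blank (advance by two) when the neighbouring non-blank symbols
    # differ (label_mask). ``ret`` accumulates -log p(label) for each sample at
    # the time step where its prediction ends.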
    for t in range(1, T):
        la2 = log_alpha[:, :-2]
        log_alpha[:, 1:] = logaddexp(log_alpha[:, 1:], log_alpha[:, :-1])
        log_alpha[:, 2:] = (
            log_alpha[:, 2:] * (1 - label_mask)
            + logaddexp(log_alpha[:, 2:], la2) * label_mask
        )
        log_alpha += _safelog(match_pred[:, :, t])

        ret_t = -logaddexp(
            log_alpha[batch_idx, label_lengths * 2],
            log_alpha[batch_idx, label_lengths * 2 - 1],
        )
        ret += ret_t * equal(pred_lengths - 1, t)

    if reduction == "mean":
        return (ret / label_lengths).mean()
    elif reduction == "sum":
        return ret.sum()
    elif reduction == "none":
        return ret
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))