
loss.py 15 kB

# -*- coding: utf-8 -*-
import functools

import numpy as np

from ..core.tensor.array_method import _reduce
from ..tensor import Tensor
from .elemwise import abs, equal, log, logaddexp, maximum
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
from .tensor import broadcast_to, cumsum, linspace, ones, where, zeros

__all__ = [
    "l1_loss",
    "square_loss",
    "cross_entropy",
    "binary_cross_entropy",
    "hinge_loss",
    "ctc_loss",
]

def _reduce_output(loss_fn):
    r"""Wrapper to apply canonical reductions to loss outputs."""

    @functools.wraps(loss_fn)
    def reduced_loss_fn(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)
        if reduction == "none":
            return loss
        elif reduction in ("mean", "sum"):
            return _reduce(reduction)(loss)
        else:
            raise ValueError("{} is not a valid value for reduction".format(reduction))

    return reduced_loss_fn

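# Usage sketch (illustrative only, not part of the module): wrapping a
# per-element loss with ``_reduce_output`` adds a ``reduction`` keyword, so the
# same function can return a per-element tensor, its mean, or its sum.
# ``my_loss`` below is a hypothetical example, not an exported API.
#
#     @_reduce_output
#     def my_loss(pred, label):
#         return (pred - label) ** 2
#
#     my_loss(pred, label)                    # scalar mean over all elements
#     my_loss(pred, label, reduction="sum")   # scalar sum
#     my_loss(pred, label, reduction="none")  # unreduced, element-wise tensor
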
@_reduce_output
def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math::

        \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left| x_n - y_n \right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    Args:
        pred: predicted result from model.
        label: ground truth to compare.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Shape:
        * ``pred``: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        * ``label``: :math:`(N, *)`. Same shape as ``pred``.

    Examples:
        >>> pred = Tensor([3, 3, 3, 3])
        >>> label = Tensor([2, 8, 6, 1])
        >>> F.nn.l1_loss(pred, label)
        Tensor(2.75, device=xpux:0)
        >>> F.nn.l1_loss(pred, label, reduction="none")
        Tensor([1 5 3 2], dtype=int32, device=xpux:0)
        >>> F.nn.l1_loss(pred, label, reduction="sum")
        Tensor(11, dtype=int32, device=xpux:0)
    """
    diff = pred - label
    return abs(diff)

@_reduce_output
def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math::

        \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left( x_n - y_n \right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    Args:
        pred: predicted result from model.
        label: ground truth to compare.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Shape:
        * ``pred``: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        * ``label``: :math:`(N, *)`. Same shape as ``pred``.

    Examples:
        >>> pred = Tensor([3, 3, 3, 3])
        >>> label = Tensor([2, 8, 6, 1])
        >>> F.nn.square_loss(pred, label)
        Tensor(9.75, device=xpux:0)
        >>> F.nn.square_loss(pred, label, reduction="none")
        Tensor([ 1. 25. 9. 4.], device=xpux:0)
        >>> F.nn.square_loss(pred, label, reduction="sum")
        Tensor(39.0, device=xpux:0)
    """
    diff = pred - label
    return diff ** 2

@_reduce_output
def cross_entropy(
    pred: Tensor,
    label: Tensor,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""Computes the multi-class cross entropy loss (using logits by default).

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and the original label
    distributions respectively, :math:`k` is the index into the label
    distribution, :math:`\alpha` is ``label_smooth`` and :math:`K` is the
    number of classes.

    Args:
        pred: input tensor representing the predicted value.
        label: input tensor representing the classification label.
        axis: an axis along which softmax will be applied. Default: 1
        with_logits: whether to apply softmax first. Default: True
        label_smooth: a label smoothing parameter that re-distributes the target distribution. Default: 0
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Examples:
        By default (``with_logits`` is True), ``pred`` is assumed to be logits and
        class probabilities are given by softmax.
        This has better numerical stability than sequential calls to
        :func:`~.softmax` and :func:`~.cross_entropy`.

        >>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.cross_entropy(pred, label)  # doctest: +SKIP
        Tensor(0.57976407, device=xpux:0)
        >>> F.nn.cross_entropy(pred, label, reduction="none")
        Tensor([0.3133 0.513 0.913 ], device=xpux:0)

        If ``pred`` already contains probabilities, set ``with_logits`` to False:

        >>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.cross_entropy(pred, label, with_logits=False)  # doctest: +SKIP
        Tensor(0.5202159, device=xpux:0)
        >>> F.nn.cross_entropy(pred, label, with_logits=False, reduction="none")
        Tensor([0. 0.3567 1.204 ], device=xpux:0)
    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    ls = label_smooth

    if with_logits:
        logZ = logsumexp(pred, axis)
        primary_term = indexing_one_hot(pred, label, axis)
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis))
    if ls is None or type(ls) in (int, float) and ls == 0:
        return logZ - primary_term
    if not with_logits:
        pred = log(pred)
    return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term

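# Derivation sketch for the smoothed return value above (informal note): with
# smoothing ``ls`` the target distribution becomes (1 - ls) * one_hot(label)
# + ls / K, so the cross entropy -sum_k y^{LS}_k * (pred_k - logZ) splits into
#     logZ - (1 - ls) * pred[label] - ls * mean_k(pred_k),
# which is exactly the expression returned when ``with_logits`` is True
# (``primary_term`` being pred[label] and ``pred.mean(axis)`` the class mean).
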
@_reduce_output
def binary_cross_entropy(
    pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
) -> Tensor:
    r"""Computes the binary cross entropy loss (using logits by default).

    Args:
        pred: `(N, *)`, where `*` means any number of additional dimensions.
        label: `(N, *)`, same shape as the input.
        with_logits: bool, whether to apply sigmoid first. Default: True
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'.

    Returns:
        loss value.

    Examples:
        By default (``with_logits`` is True), ``pred`` is assumed to be logits and
        class probabilities are given by sigmoid.
        This has better numerical stability than sequential calls to
        :func:`~.sigmoid` and :func:`~.binary_cross_entropy`.

        >>> pred = Tensor([0.9, 0.7, 0.3])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.binary_cross_entropy(pred, label)
        Tensor(0.4328984, device=xpux:0)
        >>> F.nn.binary_cross_entropy(pred, label, reduction="none")
        Tensor([0.3412 0.4032 0.5544], device=xpux:0)

        If ``pred`` already contains probabilities, set ``with_logits`` to False:

        >>> pred = Tensor([0.9, 0.7, 0.3])
        >>> label = Tensor([1., 1., 1.])
        >>> F.nn.binary_cross_entropy(pred, label, with_logits=False)
        Tensor(0.5553361, device=xpux:0)
        >>> F.nn.binary_cross_entropy(pred, label, with_logits=False, reduction="none")
        Tensor([0.1054 0.3567 1.204 ], device=xpux:0)
    """
    if not with_logits:
        return -(label * log(pred) + (1 - label) * log(1 - pred))
    # logsigmoid(pred) and logsigmoid(-pred) share a common sub-expression;
    # hopefully the backend will optimize this.
    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))

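# Numerical-stability note (informal): since 1 - sigmoid(x) = sigmoid(-x), the
# logits branch rewrites -(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))) as
# -(y*logsigmoid(x) + (1-y)*logsigmoid(-x)), avoiding a separate sigmoid/log
# pair on extreme logits. As a sanity check, a zero logit with label 1 gives
# -logsigmoid(0) = log(2) ≈ 0.6931.
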
  195. @_reduce_output
  196. def hinge_loss(
  197. pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
  198. ) -> Tensor:
  199. r"""Caculates the hinge loss which is often used in SVM.
  200. The hinge loss can be described as:
  201. .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij}))
  202. Args:
  203. pred: input tensor representing the predicted probability, shape is `(N, C)`.
  204. label: input tensor representing the binary classification label, shape is `(N, C)`.
  205. norm: specify the norm to caculate the loss, should be "L1" or "L2".
  206. reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
  207. Returns:
  208. loss value.
  209. Examples:
  210. >>> pred = Tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
  211. >>> label = Tensor([[1, -1, -1], [-1, 1, 1]])
  212. >>> F.nn.hinge_loss(pred, label)
  213. Tensor(1.5, device=xpux:0)
  214. >>> F.nn.hinge_loss(pred, label, reduction="none")
  215. Tensor([2.1 0.9], device=xpux:0)
  216. >>> F.nn.hinge_loss(pred, label, reduction="sum")
  217. Tensor(3.0, device=xpux:0)
  218. """
  219. norm = norm.upper()
  220. assert norm in ["L1", "L2"], "norm must be L1 or L2"
  221. # Converts binary labels to -1/1 labels.
  222. loss = relu(1.0 - pred * label)
  223. if norm == "L1":
  224. return loss.sum(axis=1)
  225. else:
  226. return (loss ** 2).sum(axis=1)
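# For reference (worked by hand on the docstring example, not executed): with
# norm="L2" the per-sample losses become [0.5**2 + 0.5**2 + 1.1**2,
# 0.4**2 + 0.3**2 + 0.2**2] = [1.71, 0.29], so the mean-reduced L2 hinge loss
# would be about 1.0 for that input.
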
def _gen_repeat_idx(inp: Tensor):
    # Row index for each element of a concatenated label tensor:
    # e.g. lengths [2, 3] -> [0, 0, 1, 1, 1].
    idx = cumsum(inp, axis=0)
    ret = zeros(inp.sum(), dtype="int32")
    ret[idx[:-1]] = 1
    return cumsum(ret, axis=0)


def _gen_tile_idx(inp: Tensor):
    # Column index within each segment of a concatenated label tensor:
    # e.g. lengths [2, 3] -> [0, 1, 0, 1, 2].
    idx = cumsum(inp, axis=0)
    ret = ones(inp.sum(), dtype="int32")
    ret[idx[:-1]] = -(inp - 1)[:-1]
    return cumsum(ret, axis=0) - 1


def _expand_label(label: Tensor, label_lengths: Tensor, blank: int) -> Tensor:
    # Insert the blank symbol before, between and after the label symbols,
    # producing the length-(2S + 1) extended label sequence used by CTC.
    N = label_lengths.shape[0]
    if len(label.shape) == 1:
        L = label_lengths.max()
        unpack_label = zeros((N, L), dtype="int32") + blank
        idx_0 = _gen_repeat_idx(label_lengths)
        idx_1 = _gen_tile_idx(label_lengths)
        unpack_label[idx_0, idx_1] = label
        label = unpack_label

    L = label.shape[1]
    ex_label = zeros((N, L * 2 + 1), dtype="int32") + blank
    ex_label[:, 1::2] = label
    return ex_label

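# Worked example of the expansion above (illustrative only): with a
# concatenated label [1, 2], label_lengths [2] and blank=0, the label is first
# unpacked to [[1, 2]] and then expanded to [[0, 1, 0, 2, 0]], i.e. a blank is
# placed before, between and after every ground-truth symbol.
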
def _safelog(x: Tensor) -> Tensor:
    # Clamp to the smallest positive value of the dtype before taking the log,
    # so that log(0) never produces -inf.
    eps = np.finfo(x.dtype).tiny
    return log(maximum(x, eps))

def ctc_loss(
    pred: Tensor,
    pred_lengths: Tensor,
    label: Tensor,
    label_lengths: Tensor,
    blank: int = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""The Connectionist Temporal Classification loss.

    Args:
        pred: the probabilities of the output, shape is (T, N, C),
            where T=input length, N=batch size, and C=number of classes (including blank).
        pred_lengths: number of time steps for each sequence in ``pred``, shape is (N, ).
        label: groundtruth labels, containing the indices of groundtruth
            symbols for each sequence at each output time step, and the blank
            symbol should not be included. shape is (N, S) or (sum(label_lengths), ).
        label_lengths: number of time steps for each sequence in the groundtruth, shape is (N, ).
        blank: the blank symbol number, default 0.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:
        >>> pred = Tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]])
        >>> pred_lengths = Tensor([2, 2])
        >>> label = Tensor([1, 1])
        >>> label_lengths = Tensor([1, 1])
        >>> F.nn.ctc_loss(pred, pred_lengths, label, label_lengths)
        Tensor(0.1504417, device=xpux:0)
    """
    T, N, C = pred.shape

    assert (
        pred_lengths.size == N
    ), "pred_lengths must be equal to batch_size {}, but got {}".format(
        N, pred_lengths.size
    )
    assert (
        label_lengths.size == N
    ), "label_lengths must be equal to batch_size {}, but got {}".format(
        N, label_lengths.size
    )
    assert (
        blank >= 0 and blank < C
    ), "blank must be in label range [0, {}), but got {}".format(C, blank)
    assert (
        pred_lengths.min() > 0 and pred_lengths.max() <= T
    ), "pred_lengths must be in range ({}, {}], but got min {}, max {}".format(
        0, T, pred_lengths.min(), pred_lengths.max()
    )

    if label.ndim == 1:  # concatenated label
        assert label_lengths.min() > 0, "label lengths must be positive"
        assert (
            label.size == label_lengths.sum()
        ), "label size must be equal to sum(label_lengths)"
    else:
        N, S = label.shape
        assert (
            label_lengths.min() > 0 and label_lengths.max() <= S
        ), "label_lengths must be in range ({}, {}], but got min {}, max {}".format(
            0, S, label_lengths.min(), label_lengths.max()
        )

    label = _expand_label(label, label_lengths, blank)
    label_mask = label[:, 2:] != label[:, :-2]
    L = label.shape[1]

    pred = pred.transpose(1, 0, 2)  # (T, N, C) -> (N, T, C)
    batch_idx = linspace(0, N - 1, N).astype("int32").reshape(-1)
    batch_idx_NL = broadcast_to(batch_idx.reshape(N, 1), (N, L)).reshape(-1)
    # Gather, for every extended-label position, its probability at each frame.
    match_pred = pred[batch_idx_NL, :, label.reshape(-1)].reshape(
        N, L, -1
    )  # (N, T, C) -> (N, L, T)

    # Forward (alpha) recursion of CTC in log space: log_alpha[n, s] holds the
    # log-probability of all alignments of sequence n that end at extended-label
    # position s after the frames processed so far.
    log_alpha = zeros((N, L), dtype="float32")
    log_alpha[:, :2] = match_pred[:, :2, 0]
    log_alpha = _safelog(log_alpha)

    ret = -logaddexp(
        log_alpha[batch_idx, label_lengths * 2],
        log_alpha[batch_idx, label_lengths * 2 - 1],
    ) * equal(pred_lengths - 1, 0)
    for t in range(1, T):
        la2 = log_alpha[:, :-2]
        # Each position can be reached from itself or from the previous position...
        log_alpha[:, 1:] = logaddexp(log_alpha[:, 1:], log_alpha[:, :-1])
        # ...and also from two positions back, but only where the symbol differs
        # from the one two steps earlier (label_mask excludes blanks and repeats).
        log_alpha[:, 2:] = (
            log_alpha[:, 2:] * (1 - label_mask)
            + logaddexp(log_alpha[:, 2:], la2) * label_mask
        )
        log_alpha += _safelog(match_pred[:, :, t])

        # The loss for sequences ending at frame t is the negative log of the
        # probability mass at the last symbol or the trailing blank.
        ret_t = -logaddexp(
            log_alpha[batch_idx, label_lengths * 2],
            log_alpha[batch_idx, label_lengths * 2 - 1],
        )
        ret += ret_t * equal(pred_lengths - 1, t)

    if reduction == "mean":
        return (ret / label_lengths).mean()
    elif reduction == "sum":
        return ret.sum()
    elif reduction == "none":
        return ret
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))
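

# Usage sketch (illustrative only): ``pred`` is expected to hold probabilities,
# so logits produced by a network are typically normalized with a softmax over
# the class axis first (assuming megengine.functional.softmax is available):
#
#     import megengine.functional as F
#     probs = F.softmax(logits, axis=2)  # logits: (T, N, C)
#     loss = F.nn.ctc_loss(probs, pred_lengths, label, label_lengths, blank=0)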