
loss.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools

import numpy as np

from ..core.tensor.array_method import _reduce
from ..tensor import Tensor
from .elemwise import abs, equal, log, logaddexp, maximum
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
from .tensor import broadcast_to, cumsum, linspace, ones, where, zeros

__all__ = [
    "l1_loss",
    "square_loss",
    "cross_entropy",
    "binary_cross_entropy",
    "hinge_loss",
    "ctc_loss",
]


def _reduce_output(loss_fn):
    r"""Wrapper to apply canonical reductions to loss outputs."""

    @functools.wraps(loss_fn)
    def reduced_loss_fn(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)
        if reduction == "none":
            return loss
        elif reduction in ("mean", "sum"):
            return _reduce(reduction)(loss)
        else:
            raise ValueError("{} is not a valid value for reduction".format(reduction))

    return reduced_loss_fn


@_reduce_output
def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculates the mean absolute error (MAE) between
    each element of the pred :math:`x` and label :math:`y`.

    The mean absolute error can be described as:

    .. math::

        \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left| x_n - y_n \right|,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    Args:
        pred: predicted result from model.
        label: ground truth to compare.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.functional as F

            ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
            tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
            loss = F.nn.l1_loss(ipt, tgt)
            print(loss.numpy())

        Outputs:

        .. testoutput::

            2.75
    """
    diff = pred - label
    return abs(diff)


@_reduce_output
def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculates the mean squared error (squared L2 norm) between
    each element of the pred :math:`x` and label :math:`y`.

    The mean squared error can be described as:

    .. math::

        \ell(x, y) = mean\left(L\right)

    where

    .. math::

        L = \{l_1, \dots, l_N\}, \quad
        l_n = \left( x_n - y_n \right)^2,

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    Args:
        pred: predicted result from model.
        label: ground truth to compare.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Shape:
        * pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions.
        * label: :math:`(N, *)`. Same shape as ``pred``.

    Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.functional as F

            ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
            tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
            loss = F.nn.square_loss(ipt, tgt)
            print(loss.numpy())

        Outputs:

        .. testoutput::

            9.75
    """
    diff = pred - label
    return diff ** 2


@_reduce_output
def cross_entropy(
    pred: Tensor,
    label: Tensor,
    axis: int = 1,
    with_logits: bool = True,
    label_smooth: float = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""Computes the multi-class cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits and
    class probabilities are given by softmax. This has better numerical stability
    than a sequential call to :func:`~.softmax` followed by :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math::

        y^{LS}_{k} = y_{k}\left(1 - \alpha\right) + \alpha / K

    where :math:`y^{LS}` and :math:`y` are the new and the original label distribution
    respectively, :math:`k` is the index into the label distribution, :math:`\alpha`
    is ``label_smooth`` and :math:`K` is the number of classes.

    Args:
        pred: input tensor representing the predicted logits (or probabilities when ``with_logits`` is False).
        label: input tensor representing the classification label.
        axis: an axis along which softmax will be applied. Default: 1
        with_logits: whether to apply softmax first. Default: True
        label_smooth: a label smoothing parameter that re-distributes the target distribution. Default: 0
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            data_shape = (1, 2)
            label_shape = (1, )
            pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
            label = tensor(np.ones(label_shape, dtype=np.int32))
            loss = F.nn.cross_entropy(pred, label)
            print(loss.numpy().round(decimals=4))

        Outputs:

        .. testoutput::

            0.6931
    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )
    ls = label_smooth
    if with_logits:
        logZ = logsumexp(pred, axis)
        primary_term = indexing_one_hot(pred, label, axis)
    else:
        logZ = 0
        primary_term = log(indexing_one_hot(pred, label, axis))
    if ls is None or type(ls) in (int, float) and ls == 0:
        return logZ - primary_term
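
    # With label smoothing, the cross entropy expands as
    #   sum_k y_LS_k * (logZ - pred_k)
    #     = logZ - (1 - alpha) * pred_label - alpha * mean_k(pred_k),
    # which is the expression returned below (pred holds log-probabilities
    # when with_logits is False, in which case logZ is 0).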
    if not with_logits:
        pred = log(pred)
    return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term


@_reduce_output
def binary_cross_entropy(
    pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
) -> Tensor:
    r"""Computes the binary cross entropy loss (using logits by default).

    By default (``with_logits`` is True), ``pred`` is assumed to be logits and
    class probabilities are given by sigmoid.

    Args:
        pred: `(N, *)`, where `*` means any number of additional dimensions.
        label: `(N, *)`, same shape as the input.
        with_logits: whether to apply sigmoid first. Default: True
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
            label = tensor(np.ones((1, 2), dtype=np.float32))
            loss = F.nn.binary_cross_entropy(pred, label)
            print(loss.numpy().round(decimals=4))

        Outputs:

        .. testoutput::

            0.6931
    """
    if not with_logits:
        return -(label * log(pred) + (1 - label) * log(1 - pred))
    # logsigmoid(pred) and logsigmoid(-pred) share a common sub-expression;
    # hopefully the backend will optimize this.
    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))


@_reduce_output
def hinge_loss(
    pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
) -> Tensor:
    r"""Calculates the hinge loss which is often used in SVMs.

    The hinge loss can be described as:

    .. math::

        loss(x, y) = \frac{1}{N} \sum_{i} \sum_{j} \max(0, 1 - x_{ij} y_{ij})

    Args:
        pred: input tensor representing the predicted probability, shape is `(N, C)`.
        label: input tensor representing the binary classification label, shape is `(N, C)`.
        norm: specify the norm to calculate the loss, should be "L1" or "L2".
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:

        .. testcode::

            from megengine import tensor
            import megengine.functional as F

            pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
            label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
            loss = F.nn.hinge_loss(pred, label)
            print(loss.numpy())

        Outputs:

        .. testoutput::

            1.5
    """
    norm = norm.upper()
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # Labels are expected to be -1/1; the margin term is max(0, 1 - pred * label).
    loss = relu(1.0 - pred * label)
    if norm == "L1":
        return loss.sum(axis=1)
    else:
        return (loss ** 2).sum(axis=1)
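

# Index helpers used by ctc_loss to scatter a concatenated 1-D label tensor
# into a dense (N, L) layout. _gen_repeat_idx(inp) behaves like
# np.repeat(np.arange(len(inp)), inp), and _gen_tile_idx(inp) enumerates the
# positions within each sequence. For example, for inp = [2, 3] they yield
# [0, 0, 1, 1, 1] and [0, 1, 0, 1, 2] respectively.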
def _gen_repeat_idx(inp: Tensor):
    idx = cumsum(inp, axis=0)
    ret = zeros(inp.sum(), dtype="int32")
    ret[idx[:-1]] = 1
    return cumsum(ret, axis=0)


def _gen_tile_idx(inp: Tensor):
    idx = cumsum(inp, axis=0)
    ret = ones(inp.sum(), dtype="int32")
    ret[idx[:-1]] = -(inp - 1)[:-1]
    return cumsum(ret, axis=0) - 1
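

# _expand_label converts labels to the CTC "extended" form: a concatenated 1-D
# label tensor is first unpacked into a dense (N, L) matrix (padded with the
# blank symbol), then blanks are interleaved so every sequence becomes
# [blank, l_1, blank, l_2, ..., l_L, blank] of length 2 * L + 1.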
def _expand_label(label: Tensor, label_lengths: Tensor, blank: int) -> Tensor:
    N = label_lengths.shape[0]
    if len(label.shape) == 1:
        L = label_lengths.max()
        unpack_label = zeros((N, L), dtype="int32") + blank
        idx_0 = _gen_repeat_idx(label_lengths)
        idx_1 = _gen_tile_idx(label_lengths)
        unpack_label[idx_0, idx_1] = label
        label = unpack_label
    L = label.shape[1]
    ex_label = zeros((N, L * 2 + 1), dtype="int32") + blank
    ex_label[:, 1::2] = label
    return ex_label
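

# Clamping to the smallest positive float of the input dtype keeps log() from
# producing -inf for zero probabilities.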
def _safelog(x: Tensor) -> Tensor:
    eps = np.finfo(x.dtype).tiny
    return log(maximum(x, eps))


def ctc_loss(
    pred: Tensor,
    pred_lengths: Tensor,
    label: Tensor,
    label_lengths: Tensor,
    blank: int = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""The Connectionist Temporal Classification loss.

    Args:
        pred: the probabilities of the output, shape is (T, N, C),
            where T = input length, N = batch size and C = number of classes (including blank).
        pred_lengths: number of time steps for each sequence in ``pred``, shape is (N, ).
        label: groundtruth labels, containing the indices of groundtruth
            symbols for each sequence at each output time step, and the blank
            symbol should not be included. Shape is (N, S) or (sum(label_lengths), ).
        label_lengths: number of time steps for each sequence in the groundtruth, shape is (N, ).
        blank: the blank symbol number, default 0.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'

    Returns:
        loss value.

    Examples:

        .. testcode::

            from megengine import tensor
            import megengine.functional as F

            pred = tensor([[[0.0614, 0.9386], [0.8812, 0.1188]], [[0.699, 0.301], [0.2572, 0.7428]]])
            pred_length = tensor([2, 2])
            label = tensor([1, 1])
            label_lengths = tensor([1, 1])
            loss = F.nn.ctc_loss(pred, pred_length, label, label_lengths)
            print(loss.numpy())

        Outputs:

        .. testoutput::

            0.1504417
    """
    T, N, C = pred.shape
    assert (
        pred_lengths.size == N
    ), "pred_lengths must be equal to batch_size {}, but got {}".format(
        N, pred_lengths.size
    )
    assert (
        label_lengths.size == N
    ), "label_lengths must be equal to batch_size {}, but got {}".format(
        N, label_lengths.size
    )
    assert (
        blank >= 0 and blank < C
    ), "blank must be in label range [0, {}), but got {}".format(C, blank)
    assert (
        pred_lengths.min() > 0 and pred_lengths.max() <= T
    ), "pred_lengths must be in range ({}, {}], but got min {}, max {}".format(
        0, T, pred_lengths.min(), pred_lengths.max()
    )
    if label.ndim == 1:  # concatenated label
        assert label_lengths.min() > 0, "label lengths must be positive"
        assert (
            label.size == label_lengths.sum()
        ), "label size must be equal to sum(label_lengths)"
    else:
        N, S = label.shape
        assert (
            label_lengths.min() > 0 and label_lengths.max() <= S
        ), "label_lengths must be in range ({}, {}], but got min {}, max {}".format(
            0, S, label_lengths.min(), label_lengths.max()
        )
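
    # Expand targets into the CTC form [blank, l_1, blank, ..., l_S, blank];
    # label_mask marks extended positions whose symbol differs from the one two
    # steps earlier (non-blank and not a repeated symbol), i.e. the positions
    # that may take a "skip" transition in the forward recursion.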
    label = _expand_label(label, label_lengths, blank)
    label_mask = label[:, 2:] != label[:, :-2]
    L = label.shape[1]
    pred = pred.transpose(1, 0, 2)  # (T, N, C) -> (N, T, C)
    batch_idx = linspace(0, N - 1, N).astype("int32").reshape(-1)
    batch_idx_NL = broadcast_to(batch_idx.reshape(N, 1), (N, L)).reshape(-1)
    match_pred = pred[batch_idx_NL, :, label.reshape(-1)].reshape(
        N, L, -1
    )  # (N, T, C) -> (N, L, T)
    log_alpha = zeros((N, L), dtype="float32")
    log_alpha[:, :2] = match_pred[:, :2, 0]
    log_alpha = _safelog(log_alpha)
    ret = -logaddexp(
        log_alpha[batch_idx, label_lengths * 2],
        log_alpha[batch_idx, label_lengths * 2 - 1],
    ) * equal(pred_lengths - 1, 0)
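
    # Forward (alpha) recursion of CTC, carried out in log space. At t = 0 a
    # valid path can only start at the leading blank or the first symbol, which
    # is why only log_alpha[:, :2] was initialized above. At each later step,
    # position s accumulates probability from s and s - 1, and additionally from
    # s - 2 where label_mask allows the skip transition, then the emission
    # probability match_pred[:, :, t] is added. After every step, the losses of
    # sequences that end exactly at frame t are gathered from the last blank and
    # the last symbol (positions 2 * label_lengths and 2 * label_lengths - 1).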
    for t in range(1, T):
        la2 = log_alpha[:, :-2]
        log_alpha[:, 1:] = logaddexp(log_alpha[:, 1:], log_alpha[:, :-1])
        log_alpha[:, 2:] = (
            log_alpha[:, 2:] * (1 - label_mask)
            + logaddexp(log_alpha[:, 2:], la2) * label_mask
        )
        log_alpha += _safelog(match_pred[:, :, t])

        ret_t = -logaddexp(
            log_alpha[batch_idx, label_lengths * 2],
            log_alpha[batch_idx, label_lengths * 2 - 1],
        )
        ret += ret_t * equal(pred_lengths - 1, t)
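
    # Reduction is handled here rather than via _reduce_output: for "mean",
    # each sample's loss is first normalized by its label length before
    # averaging over the batch.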
    if reduction == "mean":
        return (ret / label_lengths).mean()
    elif reduction == "sum":
        return ret.sum()
    elif reduction == "none":
        return ret
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))