
test_loss.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine.functional as F
from megengine import tensor


def test_cross_entropy_with_logits():
    data = tensor([[0, 50], [0, -150]]).astype(np.float32)
    label = tensor([1, 0]).astype(np.int32)
    loss = F.nn.cross_entropy(data, label)
    np.testing.assert_allclose(loss.numpy(), 0.0)
    label = tensor([0, 1]).astype(np.int32)
    loss = F.nn.cross_entropy(data, label)
    np.testing.assert_allclose(loss.numpy(), 100)

    label = np.array([1, 0])
    loss = F.nn.cross_entropy(data, label)
    np.testing.assert_allclose(loss.numpy(), 0.0)
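
# Why the expected values hold: with logits [0, 50], softmax puts essentially
# all mass on class 1, so -log p(class 1) is ~0 (and likewise for [0, -150]
# with class 0); with the labels flipped, the per-sample losses are ~50 and
# ~150 (the logit gaps), whose mean is 100. The final case checks that a
# plain numpy array is accepted as the label.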


def test_cross_entropy():
    def softmax(x):
        x = np.exp(x)
        x /= x.sum(1, keepdims=True)
        return x

    def ref(x, y):
        return np.mean([-np.log(x[i, y[i]]) for i in range(len(y))])

    x = (np.random.rand(5, 10) - 0.5) * 4
    y = np.random.randint(10, size=(5,))
    for i in range(len(x)):
        x[i, y[i]] += np.random.rand() * 2
    x = softmax(x)
    l_ref = ref(x, y)
    l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
    np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)
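
# Note: with_logits=False tells cross_entropy that its input already holds
# normalized probabilities, so the loss reduces to the reference above: the
# mean of -log x[i, y[i]] over the batch.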


def test_cross_entropy_reduction():
    logits = np.random.randn(16, 10)
    label = np.random.randint(10, size=[16])
    logits = tensor(logits, dtype="float32")
    label = tensor(label, dtype="int32")

    perm = np.random.permutation(16)
    logits_perm = tensor(logits[perm], dtype="float32")
    label_perm = tensor(label[perm], dtype="int32")

    loss = F.nn.cross_entropy(logits, label, reduction="none")
    loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
    np.testing.assert_allclose(loss.numpy()[perm], loss_perm.numpy())

    loss_sum = F.nn.cross_entropy(logits, label, reduction="sum")
    np.testing.assert_allclose(loss.numpy().sum(), loss_sum.numpy(), rtol=2e-7)

    loss_mean = F.nn.cross_entropy(logits, label, reduction="mean")
    np.testing.assert_allclose(loss_mean.numpy(), loss_sum.numpy() / 16)

    loss_ls = F.nn.cross_entropy(logits, label, reduction="mean", label_smooth=0.1)
    loss_ls_none_reduce = F.nn.cross_entropy(
        logits, label, reduction="none", label_smooth=0.1
    )
    np.testing.assert_allclose(
        loss_ls.numpy(), loss_ls_none_reduce.numpy().mean(), rtol=2e-7
    )

    with pytest.raises(ValueError):
        F.nn.cross_entropy(logits, label, reduction="MEAN")

    with pytest.raises(ValueError):
        F.nn.cross_entropy(logits, label, reduction="max")
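
# Taken together, the assertions above pin down the reduction semantics:
# "none" returns one loss per sample (so permuting the batch permutes the
# loss vector), "sum" adds them up, "mean" is sum / batch_size, and the
# reduction string is matched case-sensitively, so "MEAN" and "max" raise
# ValueError.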


def ctc_nll_naive_npy(
    pred,
    pred_lengths,
    label,
    label_lengths,
    blank=0,
    reduction="mean",
    time_major=False,
):
    """Naive :func:`ctc_nll` using numpy arrays. Used for testing and for
    helping our users understand how CTC works. Only ``LABEL_COMPACT`` mode
    is supported."""
    pred = np.asarray(pred, dtype=np.float32)
    # int32: sequence lengths may exceed the int8 range
    pred_lengths = np.asarray(pred_lengths, dtype=np.int32)
    label = np.asarray(label, dtype=np.int32)
    label_lengths = np.asarray(label_lengths, dtype=np.int32)

    if time_major:
        pred = np.transpose(pred, (1, 0, 2))
    # pred in (N, T, P) format
    batch_size, time_len, nr_class = pred.shape
    assert pred_lengths.shape == (batch_size,) and pred_lengths.max() <= pred.shape[1]
    assert label_lengths.shape == (batch_size,)
    assert label.shape == (label_lengths.sum(),) and label.max() < nr_class

    ret = np.empty((batch_size,), dtype=np.float32)
    label_start = 0
    for i in range(batch_size):
        label_end = label_start + label_lengths[i]
        ret[i] = _ctc_npy_single_seq(
            pred[i][: pred_lengths[i]], label[label_start:label_end], blank
        )
        label_start = label_end

    if reduction == "mean":
        return (ret / label_lengths).mean()
    elif reduction == "sum":
        return ret.sum()
    elif reduction == "none":
        return ret
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))
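
# `_ctc_npy_single_seq` below is the standard CTC forward (alpha) recursion on
# the blank-extended label l' = (blank, y_1, blank, y_2, ..., y_L, blank):
#
#     alpha_t(s) = p_t(l'_s) * (alpha_{t-1}(s) + alpha_{t-1}(s - 1)
#                  + [l'_s != l'_{s-2}] * alpha_{t-1}(s - 2))
#
# carried out in log space for numerical stability. The negative log
# likelihood of the sequence is -log(alpha_T(2L + 1) + alpha_T(2L)): a valid
# path may end on the last label or on the trailing blank.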


def _ctc_npy_single_seq(pred, label, blank):
    def safelog(x):
        eps = np.finfo(x.dtype).tiny
        return np.log(np.maximum(x, eps))

    def log_sum_exp(x, y):
        x, y = np.maximum(x, y), np.minimum(x, y)
        return x + np.log1p(np.exp(y - x))

    assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3
    len_pred, alphabet_size = pred.shape
    (len_label,) = label.shape
    len_ex_label = len_label * 2 + 1
    ex_label = (np.zeros(len_ex_label)).astype(np.int32) + blank
    ex_label[1::2] = label
    prob = np.zeros(len_ex_label, dtype=np.float32)
    prob[0] = pred[0][ex_label[0]]
    prob[1] = pred[0][ex_label[1]]
    prob = safelog(prob)  # compute on log scale

    ex_label_pmask = ex_label[2:] != ex_label[:-2]
    for t in range(1, len_pred):
        # loop invariant at entry: prob[i] = log p(pred[:t], ex_label[:i + 1])
        new_prob = prob.copy()
        new_prob[1:] = log_sum_exp(new_prob[1:], prob[:-1])
        new_prob[2:] = (
            new_prob[2:] * (1 - ex_label_pmask)
            + log_sum_exp(new_prob[2:], prob[:-2]) * ex_label_pmask
        )
        new_prob += safelog(pred[t, ex_label])
        prob = new_prob

    return -log_sum_exp(prob[-1], prob[-2])
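

# A minimal usage sketch of the naive reference above (hypothetical numbers,
# not part of the test suite): one sequence of T=2 frames over P=2 classes,
# with class 0 as the blank, scored against the single-label sequence [1].
def _ctc_naive_usage_sketch():
    pred = np.array([[[0.6, 0.4], [0.3, 0.7]]], dtype=np.float32)  # (N, T, P)
    # LABEL_COMPACT mode: the labels of all sequences are concatenated.
    loss = ctc_nll_naive_npy(pred, [2], [1], [1], blank=0, reduction="none")
    # The alignments that collapse to [1] are (1, 1), (blank, 1) and (1, blank),
    # with total probability 0.4 * 0.7 + 0.6 * 0.7 + 0.4 * 0.3 = 0.82.
    np.testing.assert_allclose(loss, -np.log(0.82), rtol=1e-4)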


def test_ctc_loss():
    def test_func(T, C, N):
        input = np.random.randn(T, N, C)
        input = F.softmax(tensor(input), axis=-1).numpy()
        input_lengths = np.ones(N, dtype=np.int32) * T
        target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
        target = np.random.randint(
            low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
        )

        input_mge = tensor(input)
        input_lengths_mge = tensor(input_lengths)

        target_mge = tensor(target)
        target_lengths_mge = tensor(target_lengths)

        blank = np.random.randint(C)
        for method in ["mean", "sum", "none"]:
            np_out = ctc_nll_naive_npy(
                input,
                input_lengths,
                target,
                target_lengths,
                blank=blank,
                reduction=method,
                time_major=True,
            )
            mge_out = F.nn.ctc_loss(
                input_mge,
                input_lengths_mge,
                target_mge,
                target_lengths_mge,
                blank=blank,
                reduction=method,
            )
            np.testing.assert_allclose(mge_out.numpy(), np_out, rtol=2e-6)

    cases = [[1, 2, 1], [100, 50, 200], [100, 5, 1]]
    for case in cases:
        test_func(*case)