You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_loss.py 8.7 kB

6 years ago
6 years ago
6 years ago

  1. import math
  2. import unittest
  3. import torch
  4. import torch as tc
  5. import torch.nn.functional as F
  6. import fastNLP.core.losses as loss
  7. from fastNLP.core.losses import LossFunc
  8. class TestLoss(unittest.TestCase):
  9. def test_case_1(self):
  10. loss_func = loss.LossFunc(F.nll_loss)
  11. nll_loss = loss.NLLLoss()
  12. y = tc.Tensor(
  13. [
  14. [.3, .4, .3],
  15. [.5, .3, .2],
  16. [.3, .6, .1],
  17. ]
  18. )
  19. gy = tc.LongTensor(
  20. [
  21. 0,
  22. 1,
  23. 2,
  24. ]
  25. )
  26. y = tc.log(y)
  27. los = loss_func({'input': y}, {'target': gy})
  28. losses = nll_loss({'input': y}, {'target': gy})
  29. r = -math.log(.3) - math.log(.3) - math.log(.1)
  30. r /= 3
  31. print("loss = %f" % (los))
  32. print("r = %f" % (r))
  33. print("nll_loss = %f" % (losses))
  34. self.assertEqual(int(los * 1000), int(r * 1000))
  35. def test_case_2(self):
  36. # 验证squash()的正确性
  37. log = math.log
  38. loss_func = loss.Loss("nll")
  39. y = tc.Tensor(
  40. [
  41. [[.3, .4, .3], [.3, .4, .3], ],
  42. [[.5, .3, .2], [.1, .2, .7], ],
  43. [[.3, .6, .1], [.2, .1, .7], ],
  44. ]
  45. )
  46. gy = tc.LongTensor(
  47. [
  48. [0, 2],
  49. [1, 2],
  50. [2, 1],
  51. ]
  52. )
  53. y = tc.log(y)
  54. # los = loss_func({'input': y}, {'target': gy})
  55. los = loss_func(y, gy)
  56. r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
  57. r /= 6
  58. self.assertEqual(int(los * 1000), int(r * 1000))
  59. def test_case_3(self):
  60. # 验证pack_padded_sequence()的正确性
  61. log = math.log
  62. loss_func = loss.NLLLoss()
  63. y = tc.Tensor(
  64. [
  65. [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], ],
  66. [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], ],
  67. [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], ],
  68. ]
  69. )
  70. gy = tc.LongTensor(
  71. [
  72. [0, 2, 1, ],
  73. [1, 2, 0, ],
  74. [2, 0, 0, ],
  75. ]
  76. )
  77. lens = [3, 2, 1]
  78. # pdb.set_trace()
  79. y = tc.log(y)
  80. yy = tc.nn.utils.rnn.pack_padded_sequence(y, lens, batch_first=True).data
  81. gyy = tc.nn.utils.rnn.pack_padded_sequence(gy, lens, batch_first=True).data
  82. los = loss_func({'input': yy}, {'target': gyy})
  83. r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
  84. r /= 6
  85. self.assertEqual(int(los * 1000), int(r * 1000))
  86. def test_case_4(self):
  87. # 验证unpad()的正确性
  88. log = math.log
  89. y = tc.Tensor(
  90. [
  91. [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
  92. [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
  93. [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
  94. ]
  95. )
  96. gy = tc.LongTensor(
  97. [
  98. [0, 2, 1, 2, ],
  99. [1, 2, 0, 0, ],
  100. [2, 0, 0, 0, ],
  101. ]
  102. )
  103. lens = [4, 2, 1]
  104. y = tc.log(y)
  105. loss_func = loss.Loss("nll", pre_pro=["unpad"])
  106. los = loss_func(y, gy, lens=lens)
  107. r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
  108. r /= 7
  109. self.assertEqual(int(los * 1000), int(r * 1000))
  110. def test_case_5(self):
  111. # 验证mask()和make_mask()的正确性
  112. log = math.log
  113. y = tc.Tensor(
  114. [
  115. [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
  116. [[.5, .4, .1], [.3, .2, .5], [.4, .5, .1, ], [.6, .1, .3, ], ],
  117. [[.3, .6, .1], [.3, .2, .5], [.0, .0, .0, ], [.0, .0, .0, ], ],
  118. ]
  119. )
  120. gy = tc.LongTensor(
  121. [
  122. [1, 2, 0, 0, ],
  123. [0, 2, 1, 2, ],
  124. [2, 1, 0, 0, ],
  125. ]
  126. )
  127. mask = tc.ByteTensor(
  128. [
  129. [1, 1, 0, 0, ],
  130. [1, 1, 1, 1, ],
  131. [1, 1, 0, 0, ],
  132. ]
  133. )
  134. y = tc.log(y)
  135. lens = [2, 4, 2]
  136. loss_func = loss.Loss("nll", pre_pro=["mask"])
  137. los = loss_func(y, gy, mask=mask)
  138. los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))
  139. r = -log(.3) - log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2)
  140. r /= 8
  141. self.assertEqual(int(los * 1000), int(r * 1000))
  142. self.assertEqual(int(los2 * 1000), int(r * 1000))
  143. def test_case_6(self):
  144. # 验证unpad_mask()的正确性
  145. log = math.log
  146. y = tc.Tensor(
  147. [
  148. [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
  149. [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
  150. [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
  151. ]
  152. )
  153. gy = tc.LongTensor(
  154. [
  155. [0, 2, 1, 2, ],
  156. [1, 2, 0, 0, ],
  157. [2, 0, 0, 0, ],
  158. ]
  159. )
  160. lens = [4, 2, 1]
  161. # pdb.set_trace()
  162. y = tc.log(y)
  163. loss_func = loss.Loss("nll", pre_pro=["unpad_mask"])
  164. los = loss_func(y, gy, lens=lens)
  165. r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
  166. r /= 7
  167. self.assertEqual(int(los * 1000), int(r * 1000))
  168. def test_case_7(self):
  169. # 验证一些其他东西
  170. log = math.log
  171. y = tc.Tensor(
  172. [
  173. [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
  174. [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
  175. [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
  176. ]
  177. )
  178. gy = tc.LongTensor(
  179. [
  180. [0, 2, 1, 2, ],
  181. [1, 2, 0, 0, ],
  182. [2, 0, 0, 0, ],
  183. ]
  184. )
  185. lens = [4, 2, 1]
  186. y = tc.log(y)
  187. loss_func = loss.Loss("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
  188. loss_func.add_pre_pro("unpad_mask")
  189. los = loss_func(y, gy, lens=lens)
  190. r = - log(.3) - log(.5) - log(.3)
  191. r /= 3
  192. self.assertEqual(int(los * 1000), int(r * 1000))
  193. def test_case_8(self):
  194. def func(a, b):
  195. return F.cross_entropy(a, b)
  196. def func2(a, truth):
  197. return func(a, truth)
  198. def func3(predict, truth):
  199. return func(predict, truth)
  200. def func4(a, b, c=2):
  201. return (a + b) * c
  202. def func6(a, b, **kwargs):
  203. c = kwargs['c']
  204. return (a + b) * c
  205. get_loss = LossFunc(func, {'a': 'predict', 'b': 'truth'})
  206. predict = torch.randn(5, 3)
  207. truth = torch.LongTensor([1, 0, 1, 2, 1])
  208. loss1 = get_loss({'predict': predict}, {'truth': truth})
  209. get_loss_2 = LossFunc(func2, {'a': 'predict'})
  210. loss2 = get_loss_2({'predict': predict}, {'truth': truth})
  211. get_loss_3 = LossFunc(func3)
  212. loss3 = get_loss_3({'predict': predict}, {'truth': truth})
  213. assert loss1 == loss2 and loss1 == loss3
  214. """
  215. get_loss_4 = LossFunc(func4)
  216. loss4 = get_loss_4({'a': 1, 'b': 3}, {})
  217. print(loss4)
  218. assert loss4 == (1 + 3) * 2
  219. get_loss_5 = LossFunc(func4)
  220. loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
  221. print(loss5)
  222. assert loss5 == (1 + 3) * 4
  223. get_loss_6 = LossFunc(func6)
  224. loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
  225. print(loss6)
  226. assert loss6 == (1 + 3) * 4
  227. get_loss_7 = LossFunc(func6, c='cc')
  228. loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
  229. print(loss7)
  230. assert loss7 == (1 + 3) * 4
  231. """
  232. class TestLoss_v2(unittest.TestCase):
  233. def test_CrossEntropyLoss(self):
  234. ce = loss.CrossEntropyLoss(input="my_predict", target="my_truth")
  235. a = torch.randn(3, 5, requires_grad=False)
  236. b = torch.empty(3, dtype=torch.long).random_(5)
  237. ans = ce({"my_predict": a}, {"my_truth": b})
  238. self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))
  239. def test_BCELoss(self):
  240. bce = loss.BCELoss(input="my_predict", target="my_truth")
  241. a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
  242. b = torch.randn((3, 5), requires_grad=False)
  243. ans = bce({"my_predict": a}, {"my_truth": b})
  244. self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))