You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_metrics.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import unittest
  2. import numpy as np
  3. import torch
  4. from fastNLP.core.metrics import AccuracyMetric
  5. from fastNLP.core.metrics import pred_topk, accuracy_topk
  6. class TestAccuracyMetric(unittest.TestCase):
  7. def test_AccuracyMetric1(self):
  8. # (1) only input, targets passed
  9. pred_dict = {"pred": torch.zeros(4, 3)}
  10. target_dict = {'target': torch.zeros(4)}
  11. metric = AccuracyMetric()
  12. metric(pred_dict=pred_dict, target_dict=target_dict, )
  13. print(metric.get_metric())
  14. def test_AccuracyMetric2(self):
  15. # (2) with corrupted size
  16. try:
  17. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  18. target_dict = {'target': torch.zeros(4)}
  19. metric = AccuracyMetric()
  20. metric(pred_dict=pred_dict, target_dict=target_dict, )
  21. print(metric.get_metric())
  22. except Exception as e:
  23. print(e)
  24. return
  25. self.assertTrue(True, False), "No exception catches."
  26. def test_AccuracyMetric3(self):
  27. # (3) the second batch is corrupted size
  28. try:
  29. metric = AccuracyMetric()
  30. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  31. target_dict = {'target': torch.zeros(4, 3)}
  32. metric(pred_dict=pred_dict, target_dict=target_dict)
  33. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  34. target_dict = {'target': torch.zeros(4)}
  35. metric(pred_dict=pred_dict, target_dict=target_dict)
  36. print(metric.get_metric())
  37. except Exception as e:
  38. print(e)
  39. return
  40. self.assertTrue(True, False), "No exception catches."
  41. def test_AccuaryMetric4(self):
  42. # (5) check reset
  43. metric = AccuracyMetric()
  44. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  45. target_dict = {'target': torch.zeros(4, 3)}
  46. metric(pred_dict=pred_dict, target_dict=target_dict)
  47. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  48. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  49. target_dict = {'target': torch.zeros(4, 3) + 1}
  50. metric(pred_dict=pred_dict, target_dict=target_dict)
  51. self.assertDictEqual(metric.get_metric(), {'acc': 0})
  52. def test_AccuaryMetric5(self):
  53. # (5) check reset
  54. metric = AccuracyMetric()
  55. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  56. target_dict = {'target': torch.zeros(4, 3)}
  57. metric(pred_dict=pred_dict, target_dict=target_dict)
  58. self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})
  59. pred_dict = {"pred": torch.zeros(4, 3, 2)}
  60. target_dict = {'target': torch.zeros(4, 3) + 1}
  61. metric(pred_dict=pred_dict, target_dict=target_dict)
  62. self.assertDictEqual(metric.get_metric(), {'acc': 0.5})
  63. def test_AccuaryMetric6(self):
  64. # (6) check numpy array is not acceptable
  65. try:
  66. metric = AccuracyMetric()
  67. pred_dict = {"pred": np.zeros((4, 3, 2))}
  68. target_dict = {'target': np.zeros((4, 3))}
  69. metric(pred_dict=pred_dict, target_dict=target_dict)
  70. except Exception as e:
  71. print(e)
  72. return
  73. self.assertTrue(True, False), "No exception catches."
  74. def test_AccuaryMetric7(self):
  75. # (7) check map, match
  76. metric = AccuracyMetric(pred='predictions', target='targets')
  77. pred_dict = {"predictions": torch.zeros(4, 3, 2)}
  78. target_dict = {'targets': torch.zeros(4, 3)}
  79. metric(pred_dict=pred_dict, target_dict=target_dict)
  80. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  81. def test_AccuaryMetric8(self):
  82. # (8) check map, does not match. use stop_fast_param to stop fast param map
  83. try:
  84. metric = AccuracyMetric(pred='predictions', target='targets')
  85. pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
  86. target_dict = {'targets': torch.zeros(4, 3)}
  87. metric(pred_dict=pred_dict, target_dict=target_dict, )
  88. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  89. except Exception as e:
  90. print(e)
  91. return
  92. self.assertTrue(True, False), "No exception catches."
  93. def test_AccuaryMetric9(self):
  94. # (9) check map, include unused
  95. try:
  96. metric = AccuracyMetric(pred='prediction', target='targets')
  97. pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
  98. target_dict = {'targets': torch.zeros(4, 3)}
  99. metric(pred_dict=pred_dict, target_dict=target_dict)
  100. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  101. except Exception as e:
  102. print(e)
  103. return
  104. self.assertTrue(True, False), "No exception catches."
  105. def test_AccuaryMetric10(self):
  106. # (10) check _fast_metric
  107. try:
  108. metric = AccuracyMetric()
  109. pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)}
  110. target_dict = {'targets': torch.zeros(4, 3)}
  111. metric(pred_dict=pred_dict, target_dict=target_dict)
  112. self.assertDictEqual(metric.get_metric(), {'acc': 1})
  113. except Exception as e:
  114. print(e)
  115. return
  116. self.assertTrue(True, False), "No exception catches."
  117. class TestUsefulFunctions(unittest.TestCase):
  118. # 测试metrics.py中一些看上去挺有用的函数
  119. def test_case_1(self):
  120. # multi-class
  121. _ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
  122. _ = pred_topk(np.random.randint(0, 3, size=(10, 1)))
  123. # 跑通即可