import unittest

import numpy as np
import torch

from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.metrics import pred_topk, accuracy_topk


class TestAccuracyMetric(unittest.TestCase):
    """Tests for ``AccuracyMetric``: input validation, accumulation/reset
    behaviour of ``get_metric`` and mapping of custom input key names.

    Tests that expect the metric to reject its input use
    ``assertRaises(Exception)``: the exact exception type raised by fastNLP
    is not pinned by these tests, only that the call does not succeed.
    """

    def test_AccuracyMetric1(self):
        # (1) only input, targets passed with the default key names.
        # All-zero predictions argmax to class 0, matching all-zero targets.
        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = AccuracyMetric()

        metric(pred_dict=pred_dict, target_dict=target_dict, )
        self.assertDictEqual(metric.get_metric(), {'acc': 1})

    def test_AccuracyMetric2(self):
        # (2) corrupted size: pred of shape (4, 3, 2) is incompatible with a
        # target of shape (4,), so the metric must raise.
        with self.assertRaises(Exception) as ctx:
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4)}
            metric = AccuracyMetric()

            metric(pred_dict=pred_dict, target_dict=target_dict, )
            print(metric.get_metric())
        print(ctx.exception)

    def test_AccuracyMetric3(self):
        # (3) the first batch is well-formed, but the second batch has a
        # corrupted target size, so the metric must raise.
        with self.assertRaises(Exception) as ctx:
            metric = AccuracyMetric()
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)

            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4)}
            metric(pred_dict=pred_dict, target_dict=target_dict)

            print(metric.get_metric())
        print(ctx.exception)

    def test_AccuaryMetric4(self):
        # (4) get_metric() resets by default: the second batch is entirely
        # wrong, so acc is 0 (only that batch counts), not 0.5.
        metric = AccuracyMetric()

        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 1})

        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3) + 1}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 0})

    def test_AccuaryMetric5(self):
        # (5) get_metric(reset=False) keeps the accumulated counts, so after
        # one fully-correct and one fully-wrong batch acc is 0.5.
        metric = AccuracyMetric()

        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})

        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3) + 1}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 0.5})

    def test_AccuaryMetric6(self):
        # (6) numpy arrays are not acceptable inputs, so the metric must raise.
        with self.assertRaises(Exception) as ctx:
            metric = AccuracyMetric()
            pred_dict = {"pred": np.zeros((4, 3, 2))}
            target_dict = {'target': np.zeros((4, 3))}
            metric(pred_dict=pred_dict, target_dict=target_dict)
        print(ctx.exception)

    def test_AccuaryMetric7(self):
        # (7) key mapping that matches: pred/target are read from the
        # 'predictions'/'targets' entries.
        metric = AccuracyMetric(pred='predictions', target='targets')
        pred_dict = {"predictions": torch.zeros(4, 3, 2)}
        target_dict = {'targets': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 1})

    def test_AccuaryMetric8(self):
        # (8) key mapping that does not match: pred is mapped to
        # 'predictions', but pred_dict only provides 'prediction', so the
        # metric must raise. The 'stop_fast_param' entry is there to stop
        # fastNLP's fast parameter mapping, which presumably would otherwise
        # pick up the single tensor and bypass the mismatch.
        with self.assertRaises(Exception) as ctx:
            metric = AccuracyMetric(pred='predictions', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict, )
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
        print(ctx.exception)

    def test_AccuaryMetric9(self):
        # (9) key mapping where pred_dict also carries an unused entry.
        # NOTE(review): this test accepts either outcome. An exception is
        # printed and ignored, and the broad except also swallows a failing
        # assertDictEqual. TODO: decide whether unused keys should be
        # accepted (drop the try and assert acc == 1) or rejected
        # (assertRaises), then tighten the test.
        try:
            metric = AccuracyMetric(pred='prediction', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
        except Exception as e:
            print(e)

    def test_AccuaryMetric10(self):
        # (10) exercise _fast_metric with non-default key names and no explicit
        # mapping.
        # NOTE(review): like test 9, this accepts either outcome; an exception
        # (including a failing assertDictEqual) is only printed. TODO: pin the
        # expected behaviour of the fast path for this input.
        try:
            metric = AccuracyMetric()
            pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
        except Exception as e:
            print(e)


class TestUsefulFunctions(unittest.TestCase):
    # Tests some helper functions in metrics.py that look useful.
    def test_case_1(self):
        # multi-class
        _ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
        _ = pred_topk(np.random.randint(0, 3, size=(10, 1)))

        # Only needs to run without raising.