|
|
@@ -31,7 +31,7 @@ def _test(local_rank: int, world_size: int, device: "torch.device",
     my_result = metric.get_metric()
     for keys in ['f', 'pre', 'rec']:
-        np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
+        assert np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
 
 
 @pytest.mark.torch
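Note on the recurring pattern in these hunks: a bare `np.allclose(...)` call only returns a bool that pytest never inspects, so the old lines could not fail the test; the `assert` wrapper is what turns the comparison into a real check. A minimal illustration:

    import numpy as np

    np.allclose([1.0], [2.0], atol=1e-6)         # returns False; the test would still pass
    assert np.allclose([1.0], [2.0], atol=1e-6)  # raises AssertionError on mismatch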
|
|
@@ -69,7 +69,6 @@ class TestClassfiyFPreRecMetric:
                              [-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
                              [-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
                              [0.8692, 0.7718, -0.6734, 0.6515, 0.0641]])
-        arg_max_pred = torch.argmax(pred, dim=-1)
         target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
                                0, 3, 0, 0, 0, 1, 3, 1])
|
|
|
|
|
|
@@ -79,10 +78,9 @@ class TestClassfiyFPreRecMetric:
         f1_score = 0.1882051282051282
         recall = 0.1619047619047619
         pre = 0.23928571428571427
 
         ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
         for keys in ['f', 'pre', 'rec']:
-            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
+            assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
 
         metric = ClassifyFPreRecMetric(f_type='micro')
         metric.update(pred, target)
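The three constants above are the macro-averaged scores for this pred/target pair. A sketch of how they could be re-derived with scikit-learn (an assumption for illustration, not a dependency of this test; it reuses the test's `pred` and `target` and assumes the metric argmaxes raw class scores):

    # hypothetical re-derivation of the hard-coded macro numbers
    from sklearn.metrics import precision_recall_fscore_support

    y_pred = pred.argmax(dim=-1).tolist()
    y_true = target.tolist()
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
    # p ≈ 0.23928571428571427, r ≈ 0.1619047619047619, f ≈ 0.1882051282051282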
|
|
@@ -93,7 +91,7 @@ class TestClassfiyFPreRecMetric:
         ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
         for keys in ['f', 'pre', 'rec']:
-            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
+            assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
 
         metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
         metric.update(pred, target)
|
|
@@ -103,19 +101,35 @@ class TestClassfiyFPreRecMetric:
             '1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5},
             '2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2},
             '3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9},
-            '4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}}
+            '4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9},
+            'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427,
+                          'recall': 0.1619047619047619, 'support': 32},
+            'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32},
+            'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875,
+                             'support': 32}}
         for keys in result_dict.keys():
             if keys in ("f", "pre", "rec"):  # skip aggregate keys; per-class keys look like 'f-0'
                 continue
             gl = str(keys[-1])
             tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
             gk = tmp_d[keys[0]]
-            np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
+            assert np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
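The expanded `ground_truth` follows the dict layout of sklearn's `classification_report(..., output_dict=True)`. A hedged regeneration sketch (scikit-learn is assumed; recent versions emit an 'accuracy' entry instead of 'micro avg' for plain multiclass input, so the exact keys depend on the installed version):

    from sklearn.metrics import classification_report

    report = classification_report(target.tolist(), pred.argmax(dim=-1).tolist(),
                                   output_dict=True)
    # e.g. report['macro avg']['f1-score'] ≈ 0.1882051282051282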
 
     def test_seq_len(self):
         pred = torch.tensor([[[0.3, 0.7, 0.1], [0.4, 0.1, 0.1], [0.3, 0.1, 0.7]],
                              [[0.7, 0.1, 0.1], [0.5, 0.9, 0.1], [0.3, 0.1, 0.7]]])
         seq_len = torch.LongTensor([3, 2])
         target = torch.LongTensor([[1, 0, 2], [0, 1, 0]])
 
         # without considering seq_len
         metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
         metric.update(pred, target)
         result_dict = metric.get_metric()
         for keys in ['f', 'pre', 'rec']:
             assert result_dict[keys] != 1
 
         # considering seq_len
         metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
         metric.update(pred, target, seq_len=seq_len)
         result_dict = metric.get_metric()
         for keys in ['f', 'pre', 'rec']:
             assert result_dict[keys] == 1
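Why the two halves of `test_seq_len` disagree: the second sequence's third step predicts class 2 against target 0, so scoring every position cannot yield perfect scores, while `seq_len=[3, 2]` marks that step as padding. A sketch of the masking idea, assuming the metric derives a boolean mask from `seq_len` (the helper below is illustrative, not the library's internals):

    import torch

    def length_mask(seq_len: torch.Tensor, max_len: int) -> torch.Tensor:
        # position j of sample i is valid iff j < seq_len[i]
        return torch.arange(max_len).unsqueeze(0) < seq_len.unsqueeze(1)

    mask = length_mask(torch.LongTensor([3, 2]), max_len=3)
    # tensor([[ True,  True,  True],
    #         [ True,  True, False]])  -> the padded step is excluded from scoring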
 
     @pytest.mark.parametrize("f_type, f1_score, recall, pre",
                              [('macro', 0.1882051282051282, 0.1619047619047619, 0.23928571428571427),
@@ -180,3 +194,4 @@ class TestClassfiyFPreRecMetric:
                      [(rank, NUM_PROCESSES, torch.device(f'cuda:{rank}')) for rank in range(NUM_PROCESSES)])
         pool.close()
         pool.join()
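On the teardown above: `pool.starmap` blocks until `_test` has returned on every rank, and `close()` followed by `join()` waits for the workers instead of terminating them. A sketch of the assumed surrounding pattern (the Pool construction and `NUM_PROCESSES` live outside the visible lines; the value below is illustrative):

    import torch
    from torch.multiprocessing import Pool

    NUM_PROCESSES = 2  # assumed for illustration

    pool = Pool(processes=NUM_PROCESSES)
    pool.starmap(_test,
                 [(rank, NUM_PROCESSES, torch.device(f'cuda:{rank}'))
                  for rank in range(NUM_PROCESSES)])
    pool.close()   # stop accepting new tasks
    pool.join()    # block until every worker exits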
|
|
|
|