From ea4dcde9882b5d018bc77cb5565abdbf319d2da7 Mon Sep 17 00:00:00 2001 From: lhenry15 Date: Thu, 6 May 2021 17:03:15 -0500 Subject: [PATCH] add more metrics --- tods/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tods/utils.py b/tods/utils.py index d0f2ccb..a829a74 100644 --- a/tods/utils.py +++ b/tods/utils.py @@ -46,6 +46,12 @@ def generate_problem(dataset, metric): performance_metrics = [{'metric': PerformanceMetric.F1, 'params': {'pos_label': '1'}}] elif metric == 'F1_MACRO': performance_metrics = [{'metric': PerformanceMetric.F1_MACRO, 'params': {}}] + elif metric == 'RECALL': + performance_metrics = [{'metric': PerformanceMetric.RECALL, 'params': {'pos_label': '1'}}] + elif metric == 'PRECISION': + performance_metrics = [{'metric': PerformanceMetric.PRECISION, 'params': {'pos_label': '1'}}] + elif metric == 'ALL': + performance_metrics = [{'metric': PerformanceMetric.PRECISION, 'params': {'pos_label': '1'}}, {'metric': PerformanceMetric.RECALL, 'params': {'pos_label': '1'}}, {'metric': PerformanceMetric.F1_MACRO, 'params': {}}, {'metric': PerformanceMetric.F1, 'params': {'pos_label': '1'}}] else: raise ValueError('The metric {} not supported.'.format(metric))