|
|
@@ -46,6 +46,12 @@ def generate_problem(dataset, metric): |
|
|
|
performance_metrics = [{'metric': PerformanceMetric.F1, 'params': {'pos_label': '1'}}] |
|
|
|
elif metric == 'F1_MACRO': |
|
|
|
performance_metrics = [{'metric': PerformanceMetric.F1_MACRO, 'params': {}}] |
|
|
|
elif metric == 'RECALL': |
|
|
|
performance_metrics = [{'metric': PerformanceMetric.RECALL, 'params': {'pos_label': '1'}}] |
|
|
|
elif metric == 'PRECISION': |
|
|
|
performance_metrics = [{'metric': PerformanceMetric.PRECISION, 'params': {'pos_label': '1'}}] |
|
|
|
elif metric == 'ALL': |
|
|
|
performance_metrics = [{'metric': PerformanceMetric.PRECISION, 'params': {'pos_label': '1'}}, {'metric': PerformanceMetric.RECALL, 'params': {'pos_label': '1'}}, {'metric': PerformanceMetric.F1_MACRO, 'params': {}}, {'metric': PerformanceMetric.F1, 'params': {'pos_label': '1'}}] |
|
|
|
else: |
|
|
|
raise ValueError('The metric {} not supported.'.format(metric)) |
|
|
|
|
|
|
|