|
|
@@ -122,7 +122,7 @@ class Evaluator: |
|
|
|
_evaluate_batch_loop: Loop |
|
|
|
|
|
|
|
def __init__(self, model, dataloaders, metrics: Optional[Dict] = None, |
|
|
|
driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, |
|
|
|
driver: Union[str, Driver] = 'auto', device: Optional[Union[int, List[int], str]] = None, |
|
|
|
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, |
|
|
|
input_mapping: Optional[Union[Callable, Dict]] = None, |
|
|
|
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, |
|
|
@@ -279,8 +279,9 @@ class Evaluator: |
|
|
|
raise e |
|
|
|
finally: |
|
|
|
self.finally_progress_bar() |
|
|
|
metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) |
|
|
|
if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 |
|
|
|
metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) |
|
|
|
# metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) |
|
|
|
if self.verbose: |
|
|
|
if self.progress_bar == 'rich': |
|
|
|
f_rich_progress.print(metric_results) |
|
|
|