|
|
@@ -248,3 +248,52 @@ class Tester(object): |
|
|
|
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) |
|
|
|
_str += '\n' |
|
|
|
return _str[:-1] |
|
|
|
|
|
|
|
def flp_topredict(self): |
|
|
|
r"""开始进行预测,并返回预测结果。 |
|
|
|
|
|
|
|
:return 本次的预测结果,为一个字典,其中只有{predict}一个key,而key的值类型为tensor。 |
|
|
|
""" |
|
|
|
# turn on the testing mode; clean up the history |
|
|
|
self._model_device = _get_model_device(self._model) |
|
|
|
network = self._model |
|
|
|
self._mode(network, is_test=True) |
|
|
|
data_iterator = self.data_iterator |
|
|
|
eval_results = [] |
|
|
|
try: |
|
|
|
with torch.no_grad(): |
|
|
|
if not self.use_tqdm: |
|
|
|
from .utils import _pseudo_tqdm as inner_tqdm |
|
|
|
else: |
|
|
|
inner_tqdm = tqdm |
|
|
|
with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar: |
|
|
|
pbar.set_description_str(desc="Pred") |
|
|
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
|
|
for batch_x, batch_y in data_iterator: |
|
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device, |
|
|
|
non_blocking=self.pin_memory) |
|
|
|
with self.auto_cast(): |
|
|
|
pred_dict = self._data_forward(self._predict_func, batch_x) |
|
|
|
|
|
|
|
eval_results.extend(pred_dict['predict'].detach().cpu().numpy()) |
|
|
|
|
|
|
|
if self.use_tqdm: |
|
|
|
pbar.update() |
|
|
|
|
|
|
|
pbar.close() |
|
|
|
end_time = time.time() |
|
|
|
test_str = f'Predict data in {round(end_time - start_time, 2)} seconds!' |
|
|
|
if self.verbose >= 0: |
|
|
|
self.logger.info(test_str) |
|
|
|
except _CheckError as e: |
|
|
|
prev_func_signature = _get_func_signature(self._predict_func) |
|
|
|
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, |
|
|
|
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, |
|
|
|
dataset=self.data, check_level=0) |
|
|
|
finally: |
|
|
|
self._mode(network, is_test=False) |
|
|
|
print(f'预测完成') |
|
|
|
|
|
|
|
return eval_results |