Browse Source

update fastNLP/core/tester.py.

开发者好,我在在0.7.0使用过程中发现了一些自己写的网络,在框架中是没有找到预测函数的,只单独返回预测结果。因此在tester分支中加入了一个只返回预测结果的函数,flp_topredict,并考虑到了不同divice转换。
pull/13/head
jackeyGG Gitee 2 years ago
parent
commit
f048946f2d
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 49 additions and 0 deletions
  1. +49
    -0
      fastNLP/core/tester.py

+ 49
- 0
fastNLP/core/tester.py View File

@@ -248,3 +248,52 @@ class Tester(object):
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
_str += '\n'
return _str[:-1]

def flp_topredict(self):
r"""开始进行预测,并返回预测结果。

:return 本次的预测结果,为一个字典,其中只有{predict}一个key,而key的值类型为tensor。
"""
# turn on the testing mode; clean up the history
self._model_device = _get_model_device(self._model)
network = self._model
self._mode(network, is_test=True)
data_iterator = self.data_iterator
eval_results = []
try:
with torch.no_grad():
if not self.use_tqdm:
from .utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
pbar.set_description_str(desc="Pred")

start_time = time.time()

for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
non_blocking=self.pin_memory)
with self.auto_cast():
pred_dict = self._data_forward(self._predict_func, batch_x)

eval_results.extend(pred_dict['predict'].detach().cpu().numpy())

if self.use_tqdm:
pbar.update()

pbar.close()
end_time = time.time()
test_str = f'Predict data in {round(end_time - start_time, 2)} seconds!'
if self.verbose >= 0:
self.logger.info(test_str)
except _CheckError as e:
prev_func_signature = _get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0)
finally:
self._mode(network, is_test=False)
print(f'预测完成')

return eval_results

Loading…
Cancel
Save