diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index cb05f82d..d7284dba 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -248,3 +248,52 @@ class Tester(object): _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) _str += '\n' return _str[:-1] + + def flp_topredict(self): + r"""开始进行预测,并返回预测结果。 + + :return 本次的预测结果,为一个字典,其中只有{predict}一个key,而key的值类型为tensor。 + """ + # turn on the testing mode; clean up the history + self._model_device = _get_model_device(self._model) + network = self._model + self._mode(network, is_test=True) + data_iterator = self.data_iterator + eval_results = [] + try: + with torch.no_grad(): + if not self.use_tqdm: + from .utils import _pseudo_tqdm as inner_tqdm + else: + inner_tqdm = tqdm + with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar: + pbar.set_description_str(desc="Pred") + + start_time = time.time() + + for batch_x, batch_y in data_iterator: + _move_dict_value_to_device(batch_x, batch_y, device=self._model_device, + non_blocking=self.pin_memory) + with self.auto_cast(): + pred_dict = self._data_forward(self._predict_func, batch_x) + + eval_results.extend(pred_dict['predict'].detach().cpu().numpy()) + + if self.use_tqdm: + pbar.update() + + pbar.close() + end_time = time.time() + test_str = f'Predict data in {round(end_time - start_time, 2)} seconds!' + if self.verbose >= 0: + self.logger.info(test_str) + except _CheckError as e: + prev_func_signature = _get_func_signature(self._predict_func) + _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, + check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, + dataset=self.data, check_level=0) + finally: + self._mode(network, is_test=False) + print(f'预测完成') + + return eval_results \ No newline at end of file