From acdebfccbc48769ca1dacc6314817385a5e3c2fb Mon Sep 17 00:00:00 2001 From: liuxiaoxiong <17966083+Xiaoxiong-Liu@users.noreply.github.com> Date: Sat, 29 Aug 2020 16:16:57 +0800 Subject: [PATCH] =?UTF-8?q?[bugfix]=E4=BF=AE=E5=A4=8D=20coreference=20reso?= =?UTF-8?q?lution=E5=A4=8D=E7=8E=B0=E4=BB=A3=E7=A0=81=E4=B8=AD=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E5=90=8D=E5=AD=97=E4=B8=8D=E5=AF=B9=E5=BA=94=E7=9A=84?= =?UTF-8?q?bug=20(#323)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * pipeline * 修复找不到对应参数的bug * 增加requirement文件 --- .../coreference_resolution/model/metric.py | 3 ++- .../coreference_resolution/model/model_re.py | 29 +++++++++++++++++----- .../coreference_resolution/requirements.txt | 5 ++++ reproduction/coreference_resolution/train.py | 2 ++ 4 files changed, 32 insertions(+), 7 deletions(-) create mode 100644 reproduction/coreference_resolution/requirements.txt diff --git a/reproduction/coreference_resolution/model/metric.py b/reproduction/coreference_resolution/model/metric.py index 2c924660..7687e685 100644 --- a/reproduction/coreference_resolution/model/metric.py +++ b/reproduction/coreference_resolution/model/metric.py @@ -17,7 +17,8 @@ class CRMetric(MetricBase): self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] # TODO 改名为evaluate,输入也 - def evaluate(self, predicted, mention_to_predicted,clusters): + def evaluate(self, predicted, mention_to_predicted,target): + clusters = target for e in self.evaluators: e.update(predicted,mention_to_predicted, clusters) diff --git a/reproduction/coreference_resolution/model/model_re.py b/reproduction/coreference_resolution/model/model_re.py index eaa2941b..92f9bc03 100644 --- a/reproduction/coreference_resolution/model/model_re.py +++ b/reproduction/coreference_resolution/model/model_re.py @@ -17,7 +17,6 @@ torch.cuda.manual_seed(0) # gpu np.random.seed(0) # numpy random.seed(0) - class ffnn(nn.Module): def __init__(self, input_size, hidden_size, output_size): 
super(ffnn, self).__init__() @@ -565,19 +564,37 @@ class Model(BaseModel): return ans - def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): + def predict(self, words1 , words2, words3, words4, chars, seq_len): + """ + 实际输入都是tensor + :param sentences: 句子,被fastNLP转化成了numpy, + :param doc_np: 被fastNLP转化成了Tensor + :param speaker_ids_np: 被fastNLP转化成了Tensor + :param genre: 被fastNLP转化成了Tensor + :param char_index: 被fastNLP转化成了Tensor + :param seq_len: 被fastNLP转化成了Tensor + :return: + """ + + sentences = words1 + doc_np = words2 + speaker_ids_np = words3 + genre = words4 + char_index = chars + + # def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len): ans = self(sentences, doc_np, speaker_ids_np, genre, char_index, seq_len) - - predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"]) - predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"], - ans["mention_end_tensor"], + predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"].cpu()) + predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"].cpu(), + ans["mention_end_tensor"].cpu(), predicted_antecedents) + return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted} diff --git a/reproduction/coreference_resolution/requirements.txt b/reproduction/coreference_resolution/requirements.txt new file mode 100644 index 00000000..a8f04f04 --- /dev/null +++ b/reproduction/coreference_resolution/requirements.txt @@ -0,0 +1,5 @@ +prettytable==0.7.2 +allennlp==0.8.2 +scikit-learn==0.22.2 +pyhocon==0.3.50 +torch==1.1 diff --git a/reproduction/coreference_resolution/train.py b/reproduction/coreference_resolution/train.py index d5445cd5..bf3ea624 100644 --- a/reproduction/coreference_resolution/train.py +++ b/reproduction/coreference_resolution/train.py @@ -1,3 +1,5 @@ +import sys 
+sys.path.append('../..')
 import torch
 from torch.optim import Adam