@@ -70,6 +70,7 @@ class BertForSequenceClassification(BaseModel):

    def forward(self, words):
        r"""
        Input is [[w1, w2, w3, ...], [...]]; BertEmbedding will additionally prepend [CLS] and append [SEP] to each sequence.

        :param torch.LongTensor words: [batch_size, seq_len]
        :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels]
        """
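
For reference, a minimal usage sketch of this forward pass. Hedged: the `en-base-uncased` checkpoint name, the example sentence, and the exact `BertEmbedding`/`BertForSequenceClassification` constructor arguments are assumptions and may differ across fastNLP versions.

    import torch
    from fastNLP import Const, Vocabulary
    from fastNLP.embeddings import BertEmbedding
    from fastNLP.models import BertForSequenceClassification

    sent = "this movie was great".split()  # hypothetical example sentence
    vocab = Vocabulary()
    vocab.add_word_lst(sent)

    # include_cls_sep=True so the model can read the [CLS] vector for classification
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', include_cls_sep=True)
    model = BertForSequenceClassification(embed, num_labels=2)

    words = torch.LongTensor([[vocab.to_index(w) for w in sent]])  # [batch_size, seq_len]
    logits = model(words)[Const.OUTPUT]                            # [batch_size, num_labels]
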
@@ -115,6 +116,8 @@ class BertForSentenceMatching(BaseModel):

    def forward(self, words):
        r"""
        The input words take the form [sent1] + [SEP] + [sent2] (BertEmbedding will prepend [CLS] and append [SEP]); the output has shape batch_size x num_labels.

        :param torch.LongTensor words: [batch_size, seq_len]
        :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels]
        """
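
A matching sketch under the same assumptions as above; per the docstring, the caller joins the two sentences with a literal [SEP] token inside `words`, so that token is added to the vocabulary here. The sentence pair is hypothetical.

    import torch
    from fastNLP import Const, Vocabulary
    from fastNLP.embeddings import BertEmbedding
    from fastNLP.models import BertForSentenceMatching

    sent1 = "a man is playing a guitar".split()
    sent2 = "a person plays an instrument".split()
    vocab = Vocabulary()
    vocab.add_word_lst(sent1 + sent2 + ['[SEP]'])

    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', include_cls_sep=True)
    model = BertForSentenceMatching(embed, num_labels=2)

    # words = [sent1] + [SEP] + [sent2]; the outer [CLS]/[SEP] come from BertEmbedding
    words = torch.LongTensor([[vocab.to_index(w) for w in sent1 + ['[SEP]'] + sent2]])
    logits = model(words)[Const.OUTPUT]  # [batch_size, num_labels]
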
@@ -247,6 +250,10 @@ class BertForQuestionAnswering(BaseModel):

    def forward(self, words):
        r"""
        The input words are question + [SEP] + [paragraph]; BertEmbedding will additionally add [CLS] at the
        beginning and [SEP] at the end. Note: if include_cls_sep=True in BertEmbedding, the returned start and
        end indices are shifted by one relative to the input words; if include_cls_sep=False, the returned start
        and end indices align exactly with the positions in the input words.

        :param torch.LongTensor words: [batch_size, seq_len]
        :return: a dict containing num_labels logits, each of shape [batch_size, seq_len + 2]
        """
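
An extractive-QA sketch on the same assumptions; the question/paragraph text is hypothetical, and the constructor here relies on the model's default number of span labels, which may vary by fastNLP version.

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding
    from fastNLP.models import BertForQuestionAnswering

    question = "who wrote the book".split()
    paragraph = "the book was written by alice".split()
    vocab = Vocabulary()
    vocab.add_word_lst(question + paragraph + ['[SEP]'])

    # With include_cls_sep=True the predicted start/end indices are shifted by one
    # relative to `words`, which is why each logit tensor is [batch_size, seq_len + 2]
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', include_cls_sep=True)
    model = BertForQuestionAnswering(embed)

    words = torch.LongTensor([[vocab.to_index(w) for w in question + ['[SEP]'] + paragraph]])
    outputs = model(words)  # dict with num_labels logit tensors, each [batch_size, seq_len + 2]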