From 0b401e2c8a695ce8e66bdc4eaf3d3f03bbad3d38 Mon Sep 17 00:00:00 2001 From: yunfan Date: Fri, 11 Oct 2019 15:48:53 +0800 Subject: [PATCH] [fix] star-transformer position embedding --- fastNLP/modules/encoder/star_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py index d4cc66f7..85b1ac4d 100644 --- a/fastNLP/modules/encoder/star_transformer.py +++ b/fastNLP/modules/encoder/star_transformer.py @@ -69,7 +69,7 @@ class StarTransformer(nn.Module): smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 - if self.pos_emb and False: + if self.pos_emb: P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 embs = embs + P