Browse Source

[fix] star-transformer position embedding

tags/v0.5.5
yunfan 5 years ago
parent
commit
0b401e2c8a
1 changed file with 1 addition and 1 deletion
  1. +1
    -1
      fastNLP/modules/encoder/star_transformer.py

+ 1
- 1
fastNLP/modules/encoder/star_transformer.py View File

@@ -69,7 +69,7 @@ class StarTransformer(nn.Module):
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1
if self.pos_emb and False:
if self.pos_emb:
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1
embs = embs + P


Loading…
Cancel
Save