|
|
@@ -69,7 +69,7 @@ class StarTransformer(nn.Module): |
|
|
|
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) |
|
|
|
|
|
|
|
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 |
|
|
|
if self.pos_emb and False: |
|
|
|
if self.pos_emb: |
|
|
|
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ |
|
|
|
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 |
|
|
|
embs = embs + P |
|
|
|