You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_star_transformer.py 563 B

Lines 1–18
  1. import pytest
  2. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  3. if _NEED_IMPORT_TORCH:
  4. import torch
  5. from fastNLP.modules.torch.encoder.star_transformer import StarTransformer
  6. @pytest.mark.torch
  7. class TestStarTransformer:
  8. def test_1(self):
  9. model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100)
  10. x = torch.rand(16, 45, 100)
  11. mask = torch.ones(16, 45).byte()
  12. y, yn = model(x, mask)
  13. assert (tuple(y.size()) == (16, 45, 100))
  14. assert (tuple(yn.size()) == (16, 100))