You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_bert.py 718 B

12345678910111213141516171819202122
  1. import unittest
  2. import torch
  3. from fastNLP.models.bert import BertModel
  4. class TestBert(unittest.TestCase):
  5. def test_bert_1(self):
  6. from fastNLP.modules.encoder._bert import BertConfig
  7. config = BertConfig(32000)
  8. model = BertModel(config)
  9. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  10. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  11. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  12. all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
  13. for layer in all_encoder_layers:
  14. self.assertEqual(tuple(layer.shape), (2, 3, 768))
  15. self.assertEqual(tuple(pooled_output.shape), (2, 768))