You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_bert.py 2.7 kB

(line numbers 1–69)
  import unittest

  import torch

  from fastNLP.models.bert import *
  # Explicit imports follow the wildcard so it can never shadow them.
  from fastNLP.core.const import Const
  from fastNLP.modules.encoder._bert import BertConfig
  class TestBert(unittest.TestCase):
      """Forward-pass smoke tests for the BERT task heads in ``fastNLP.models.bert``.

      Every test builds a freshly constructed head on ``BertConfig(32000)``,
      runs it once on the shared toy batch built by ``_forward`` and checks
      that the prediction is a dict whose output tensors have the expected
      shapes.
      """

      @staticmethod
      def _forward(model):
          """Run ``model`` on the shared toy batch and return its prediction.

          The batch is two rows of three token ids. The second row ends in
          id 0 with mask value 0 (presumably a padding position).
          """
          input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
          input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
          token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
          # Positional order (ids, token types, mask) matches the original calls.
          return model(input_ids, token_type_ids, input_mask)

      def _assert_outputs(self, pred, expected_shapes):
          """Assert ``pred`` is a dict containing each key with the given shape.

          Args:
              pred: value returned by the model's forward pass.
              expected_shapes: mapping of output key -> expected shape tuple.
          """
          # assertIsInstance / assertIn report the offending values on
          # failure, unlike assertTrue(...) which only reports "False".
          self.assertIsInstance(pred, dict)
          for key, shape in expected_shapes.items():
              self.assertIn(key, pred)
              self.assertEqual(tuple(pred[key].shape), shape)

      def test_bert_1(self):
          """BertForSequenceClassification(2, ...) yields a (2, 2) output."""
          model = BertForSequenceClassification(2, BertConfig(32000))
          pred = self._forward(model)
          self._assert_outputs(pred, {Const.OUTPUT: (2, 2)})

      def test_bert_2(self):
          """BertForMultipleChoice(2, ...) yields a (1, 2) output.

          Presumably the two input rows are read as the two choices of a
          single example -- TODO confirm against the model implementation.
          """
          model = BertForMultipleChoice(2, BertConfig(32000))
          pred = self._forward(model)
          self._assert_outputs(pred, {Const.OUTPUT: (1, 2)})

      def test_bert_3(self):
          """BertForTokenClassification(7, ...) yields a (2, 3, 7) output,
          i.e. one 7-wide row for each of the 2 x 3 input positions."""
          model = BertForTokenClassification(7, BertConfig(32000))
          pred = self._forward(model)
          self._assert_outputs(pred, {Const.OUTPUT: (2, 3, 7)})

      def test_bert_4(self):
          """BertForQuestionAnswering yields two (2, 3) outputs.

          The two outputs are presumably start/end position scores -- verify
          against the model implementation.
          """
          model = BertForQuestionAnswering(BertConfig(32000))
          pred = self._forward(model)
          self._assert_outputs(
              pred,
              {Const.OUTPUTS(0): (2, 3), Const.OUTPUTS(1): (2, 3)},
          )