You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_model_io.py 782 B

12345678910111213141516171819202122232425
import os
import tempfile
import unittest

import torch

from fastNLP.io import ModelSaver, ModelLoader
from fastNLP.models import CNNText
  5. class TestModelIO(unittest.TestCase):
  6. def test_save_and_load(self):
  7. model = CNNText((10, 10), 2)
  8. saver = ModelSaver('tmp')
  9. loader = ModelLoader()
  10. saver.save_pytorch(model)
  11. new_cnn = CNNText((10, 10), 2)
  12. loader.load_pytorch(new_cnn, 'tmp')
  13. new_model = loader.load_pytorch_model('tmp')
  14. for i in range(10):
  15. for j in range(10):
  16. self.assertEqual(model.embed.embed.weight[i, j], new_cnn.embed.embed.weight[i, j])
  17. self.assertEqual(model.embed.embed.weight[i, j], new_model["embed.embed.weight"][i, j])
  18. os.system('rm tmp')