import pytest

from . import model_runner
from .model_runner import *
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from fastNLP.models.torch.cnn_text_classification import CNNText


@pytest.mark.torch
class TestCNNText:
    """Tests for the torch ``CNNText`` text-classification model."""

    def init_model(self, kernel_sizes, kernel_nums=(1, 3, 5)):
        """Build a ``CNNText`` with a ``(VOCAB_SIZE, 30)`` embedding and ``NUM_CLS`` classes.

        :param kernel_sizes: convolution kernel sizes, forwarded to ``CNNText``.
        :param kernel_nums: kernel counts, forwarded to ``CNNText``.
        :return: the constructed ``CNNText`` model.
        """
        model = CNNText(
            (VOCAB_SIZE, 30),
            NUM_CLS,
            kernel_nums=kernel_nums,
            kernel_sizes=kernel_sizes,
        )
        return model

    def test_case1(self):
        # Check that the CNN model runs end to end on the text-classification task.
        model = self.init_model((1, 3, 5))
        RUNNER.run_model_with_task(TEXT_CLS, model)

    def test_init_model(self):
        # Construction is expected to fail for plain-int kernel_sizes / kernel_nums.
        with pytest.raises(Exception):
            self.init_model(2, 4)
        # Construction is expected to fail here too; presumably because one
        # kernel size does not match the three default kernel_nums entries.
        with pytest.raises(Exception):
            self.init_model((2,))

    def test_output(self, monkeypatch):
        # Run with input sequences (max length 2) shorter than the single
        # kernel of size 3.
        model = self.init_model((3,), (1,))
        # `global MAX_LEN; MAX_LEN = 2` would only rebind this module's copy
        # (created by the star import) and never reach RUNNER, which lives in
        # model_runner. Patch the attribute on model_runner itself instead;
        # monkeypatch restores the original value after the test, and raises
        # if model_runner has no MAX_LEN.
        # NOTE(review): assumes model_runner reads its module-level MAX_LEN
        # when generating data — confirm against model_runner.py.
        monkeypatch.setattr(model_runner, "MAX_LEN", 2)
        RUNNER.run_model_with_task(TEXT_CLS, model)