diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 7529aec9..19c4b4c2 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: - import jittor + import jittor as jt __all__ = [ "JittorSingleDriver", @@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver): def setup(self): """ - 使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作 + 支持 cpu 和 gpu 的切换 """ - pass + if self.model_device in ["cpu", None]: + jt.flags.use_cuda = 0 # 使用 cpu + else: + jt.flags.use_cuda = 1 # 使用 cuda diff --git a/tests/core/controllers/_test_trainer_jittor.py b/tests/core/controllers/_test_trainer_jittor.py index bc4b05f0..13ab2e8b 100644 --- a/tests/core/controllers/_test_trainer_jittor.py +++ b/tests/core/controllers/_test_trainer_jittor.py @@ -225,7 +225,7 @@ if __name__ == "__main__": device=[0,1,2,3,4], optimizers=optimizer, train_dataloader=train_dataloader, - validate_dataloaders=val_dataloader, + evaluate_dataloaders=val_dataloader, validate_every=-1, input_mapping=None, output_mapping=None, diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py index c01fd6e6..c84c24f1 100644 --- a/tests/core/controllers/test_trainer_jittor.py +++ b/tests/core/controllers/test_trainer_jittor.py @@ -69,7 +69,8 @@ class TrainJittorConfig: shuffle: bool = True -@pytest.mark.parametrize("driver,device", [("jittor", None)]) +@pytest.mark.parametrize("driver", ["jittor"]) +@pytest.mark.parametrize("device", ["cpu", 1]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) @pytest.mark.jittor def test_trainer_jittor( @@ -134,4 +135,5 @@ def test_trainer_jittor( if __name__ == "__main__": # test_trainer_jittor("jittor", None, [RichCallback(100)]) + # test_trainer_jittor("jittor", 1, [RichCallback(100)]) pytest.main(['test_trainer_jittor.py']) # 只运行此模块