From 5425095cac8cc1c370182202d3eb00a13be1878a Mon Sep 17 00:00:00 2001 From: Letian Li <73881739+LetianLee@users.noreply.github.com> Date: Wed, 25 May 2022 07:07:37 +0100 Subject: [PATCH] =?UTF-8?q?[bugfix]=20=E4=BF=AE=E5=A4=8D=20tests/core/cont?= =?UTF-8?q?rollers/=5Ftest=5Ftrainer=5Fjittor.py=EF=BC=8C=E4=BD=BF?= =?UTF-8?q?=E5=85=B6=E5=8F=AF=E4=BB=A5=E6=AD=A3=E5=B8=B8=E8=BF=90=E8=A1=8C?= =?UTF-8?q?=20(#415)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 Trainer 并不接收 validate_dataloaders 参数,改为 evaluate_dataloaders 即可。 * jittor single driver 支持 cpu 和 gpu 的切换 --- fastNLP/core/drivers/jittor_driver/single_device.py | 9 ++++++--- tests/core/controllers/_test_trainer_jittor.py | 2 +- tests/core/controllers/test_trainer_jittor.py | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 7529aec9..19c4b4c2 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: - import jittor + import jittor as jt __all__ = [ "JittorSingleDriver", @@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver): def setup(self): """ - 使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作 + 支持 cpu 和 gpu 的切换 """ - pass + if self.model_device in ["cpu", None]: + jt.flags.use_cuda = 0 # 使用 cpu + else: + jt.flags.use_cuda = 1 # 使用 cuda diff --git a/tests/core/controllers/_test_trainer_jittor.py b/tests/core/controllers/_test_trainer_jittor.py index bc4b05f0..13ab2e8b 100644 --- a/tests/core/controllers/_test_trainer_jittor.py +++ b/tests/core/controllers/_test_trainer_jittor.py @@ -225,7 +225,7 @@ if __name__ == "__main__": device=[0,1,2,3,4], optimizers=optimizer, train_dataloader=train_dataloader, - validate_dataloaders=val_dataloader, + evaluate_dataloaders=val_dataloader, validate_every=-1, input_mapping=None, output_mapping=None, diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py index c01fd6e6..c84c24f1 100644 --- a/tests/core/controllers/test_trainer_jittor.py +++ b/tests/core/controllers/test_trainer_jittor.py @@ -69,7 +69,8 @@ class TrainJittorConfig: shuffle: bool = True -@pytest.mark.parametrize("driver,device", [("jittor", None)]) +@pytest.mark.parametrize("driver", ["jittor"]) +@pytest.mark.parametrize("device", ["cpu", 1]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) @pytest.mark.jittor def test_trainer_jittor( @@ -134,4 +135,5 @@ def test_trainer_jittor( if __name__ == "__main__": # test_trainer_jittor("jittor", None, [RichCallback(100)]) + # test_trainer_jittor("jittor", 1, [RichCallback(100)]) pytest.main(['test_trainer_jittor.py']) # 只运行此模块