From dd0e02b244b12a65966d832053e49434c3789232 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Sun, 22 May 2022 16:07:58 +0000
Subject: [PATCH 1/6] 1. Change a logger.warn call in
 tokenization_utils_base.py to warning_once; 2. add .cpu() to the torch
 conversion in mix_modules/utils.py; 3. redirect output in dataset to DEVNULL
 instead of None
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataset/dataset.py                       | 3 ++-
 fastNLP/modules/mix_modules/utils.py                  | 4 ++--
 fastNLP/transformers/torch/tokenization_utils_base.py | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
index 025d33e5..cd3cae59 100644
--- a/fastNLP/core/dataset/dataset.py
+++ b/fastNLP/core/dataset/dataset.py
@@ -156,6 +156,7 @@ import _pickle as pickle
 from copy import deepcopy
 from typing import Optional, List, Callable, Union, Dict, Any, Mapping
 from types import LambdaType
+from subprocess import DEVNULL
 import sys
 import time
@@ -231,7 +232,7 @@ def _multi_proc(ds, _apply_field, func, counter, queue):
     """
     idx = -1
     import contextlib
-    with contextlib.redirect_stdout(None):  # avoid prints grabbing rich's lock
+    with contextlib.redirect_stdout(DEVNULL):  # avoid prints grabbing rich's lock
         logger.set_stdout(stdout='raw')
         results = []
         try:
diff --git a/fastNLP/modules/mix_modules/utils.py b/fastNLP/modules/mix_modules/utils.py
index 142644f9..e709b0ac 100644
--- a/fastNLP/modules/mix_modules/utils.py
+++ b/fastNLP/modules/mix_modules/utils.py
@@ -86,12 +86,12 @@ def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient:
     if not no_gradient:
         # keep the gradient so backpropagation still works;
         # paddle's stop_gradient is the inverse of torch's requires_grad
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=False)
         hook = paddle_tensor.register_hook(
             lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
         )
     else:
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=True)

     paddle_tensor = paddle_to(paddle_tensor, device)

diff --git a/fastNLP/transformers/torch/tokenization_utils_base.py b/fastNLP/transformers/torch/tokenization_utils_base.py
index 8ed5a2e2..3a033c96 100644
--- a/fastNLP/transformers/torch/tokenization_utils_base.py
+++ b/fastNLP/transformers/torch/tokenization_utils_base.py
@@ -2179,7 +2179,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         if padding is True:
             if verbose:
                 if max_length is not None and (truncation is False or truncation == "do_not_truncate"):
-                    logger.warn(
+                    logger.warning_once(
                         "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
                         "To pad to max length, use `padding='max_length'`."
                     )
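A note on the second fix above: `torch.Tensor.numpy()` only works for CPU tensors, so `_torch2paddle` crashed whenever the input lived on a GPU; inserting `.cpu()` copies the data to host memory first. A minimal sketch of the conversion pattern, assuming torch and paddle are both installed (the `torch_to_paddle` helper is hypothetical, for illustration only):

    # Sketch of a gradient-free torch -> paddle conversion; mirrors the
    # .detach().cpu().numpy() chain in the hunk above.
    import paddle
    import torch

    def torch_to_paddle(t):
        # .detach() drops the autograd graph; .cpu() copies a CUDA tensor to
        # host memory first, because .numpy() raises TypeError on CUDA tensors.
        return paddle.to_tensor(t.detach().cpu().numpy())

    x = torch.randn(2, 3, device="cuda" if torch.cuda.is_available() else "cpu")
    print(torch_to_paddle(x).shape)  # [2, 3]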
From f33309014934bf8af76ff6166f62ce1df14298a4 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Mon, 23 May 2022 13:06:35 +0000
Subject: [PATCH 2/6] 1. Remove the unnecessary temp marker; 2. add a torch
 marker to test_mixdataloader
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/core/dataloaders/{ => torch_dataloader}/test_mixdataloader.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/core/dataloaders/{ => torch_dataloader}/test_mixdataloader.py (100%)

diff --git a/tests/core/dataloaders/test_mixdataloader.py b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
similarity index 100%
rename from tests/core/dataloaders/test_mixdataloader.py
rename to tests/core/dataloaders/torch_dataloader/test_mixdataloader.py

From 756bd09d1adb6e87ba8c44973b9c49a7df4a3fbb Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Mon, 23 May 2022 13:17:45 +0000
Subject: [PATCH 3/6] 1. Remove an unnecessary test marker; 2. add the torch
 marker to test_mixdataloader
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/core/dataloaders/torch_dataloader/test_fdl.py           | 1 -
 tests/core/dataloaders/torch_dataloader/test_mixdataloader.py | 3 ++-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index d977e52f..b53790bb 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -147,7 +147,6 @@ class TestFdl:
         assert 'Parameter:prefetch_factor' in out[0]

     @recover_logger
-    @pytest.mark.temp
     def test_version_111(self):
         if parse_version(torch.__version__) <= parse_version('1.7'):
             pytest.skip("Torch version smaller than 1.7")
diff --git a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
index 35872b39..17c151ea 100644
--- a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
+++ b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH

 if _NEED_IMPORT_TORCH:
     import torch
-    from torch.utils.data import default_collate, SequentialSampler, RandomSampler
+    from torch.utils.data import SequentialSampler, RandomSampler

 d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})

@@ -28,6 +28,7 @@ def test_pad_val(tensor, val=0):
     return True


+@pytest.mark.torch
 class TestMixDataLoader:

     def test_sequential_init(self):
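Both marker patches above lean on pytest's custom-marker machinery: `@pytest.mark.torch` lets CI select or exclude torch-dependent tests, while the removed `@pytest.mark.temp` was presumably leftover scaffolding. A minimal sketch of how such a marker is registered, assuming a standard pytest setup rather than fastNLP's actual conftest:

    # conftest.py -- register the marker so `pytest --strict-markers` accepts it
    # (a sketch of standard pytest practice, not fastNLP's real configuration)
    def pytest_configure(config):
        config.addinivalue_line("markers", "torch: tests that require PyTorch")

Once registered, `pytest -m torch` runs only the marked tests, and `pytest -m "not torch"` skips them on machines without PyTorch.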
From da0b747b305b1fa496a94827fb97df156d1d3467 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Mon, 23 May 2022 14:38:49 +0000
Subject: [PATCH 4/6] Rename the helper function in test_mixdataloader so that
 pytest does not collect and run it
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../torch_dataloader/test_mixdataloader.py    | 56 +++++++++++-----------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
index 17c151ea..bf9e3d9e 100644
--- a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
+++ b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
@@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10]
 d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})


-def test_pad_val(tensor, val=0):
+def _test_pad_val(tensor, val=0):
     if isinstance(tensor, torch.Tensor):
         tensor = tensor.tolist()
     for item in tensor:
@@ -45,7 +45,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)

         # collate_fn = Callable
         def collate_batch(batch):
@@ -74,13 +74,13 @@
         dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
         for idx, batch in enumerate(dl2):
             if idx == 0:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'].shape == torch.Size([16, 4])
             if idx == 1:
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'].shape == torch.Size([16, 3])
             if idx > 1:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'].shape == torch.Size([16, 4])

         # sampler is a str
@@ -101,7 +101,7 @@
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)

         for idx, batch in enumerate(dl4):
             if idx == 0:
@@ -118,7 +118,7 @@
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)

         # sampler is a Dict
         samplers = {'d1': SequentialSampler(d1),
@@ -137,7 +137,7 @@
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)

         # ds_ratio is 'truncate_to_least'
         dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
@@ -154,7 +154,7 @@
                 # d3
                 assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx > 2:
                 raise ValueError(f"ds_ratio: 'truncate_to_least' error")

@@ -170,7 +170,7 @@
             if 36 <= idx < 54:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 54:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -187,7 +187,7 @@
             if 4 <= idx < 41:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 41:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -201,7 +201,7 @@
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])

-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 19:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -209,7 +209,7 @@
         datasets = {'d1': d1, 'd2': d2, 'd3': d3}
         dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
         for idx, batch in enumerate(dl):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")

@@ -224,7 +224,7 @@
         dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
         for idx, batch in enumerate(dl1):
             assert isinstance(batch['x'], list)
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")

@@ -237,12 +237,12 @@
         # sampler is a str
         dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
         for idx, batch in enumerate(dl3):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
         for idx, batch in enumerate(dl4):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # sampler is a Dict
@@ -251,7 +251,7 @@
                     'd3': RandomSampler(d3)}
         dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
         for idx, batch in enumerate(dl5):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # ds_ratio is 'truncate_to_least'
@@ -333,7 +333,7 @@
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)

         # collate_fn = Callable
         def collate_batch(batch):
@@ -361,16 +361,16 @@
         dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
         for idx, batch in enumerate(dl1):
             if idx == 0 or idx == 3:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
                 assert batch['x'].shape[1] == 4
             elif idx == 1 or idx == 4:
                 # d2
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
                 assert batch['x'].shape[1] == 3
             elif idx == 2 or 4 < idx <= 20:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape[1] == 4
             if idx > 20:
@@ -392,7 +392,7 @@
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         for idx, batch in enumerate(dl3):
             if idx == 0 or idx == 3:
                 assert batch['x'].shape[1] == 4
@@ -403,7 +403,7 @@
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         # sampler is a Dict
         samplers = {'d1': SequentialSampler(d1),
                     'd2': SequentialSampler(d2),
@@ -421,7 +421,7 @@
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)

         # ds_ratio is 'truncate_to_least'
         dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
@@ -438,7 +438,7 @@
                 assert batch['x'].shape[1] == 4
             if idx > 5:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)

         # ds_ratio is 'pad_to_most'
         dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
@@ -457,7 +457,7 @@
                 assert batch['x'].shape[1] == 4
             if idx >= 51:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)

         # ds_ratio is a Dict[str, float]
         ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
@@ -475,7 +475,7 @@
                 assert batch['x'].shape[1] == 4
             if idx > 39:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)

         ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
         dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
@@ -493,4 +493,4 @@
             if idx > 18:
                 raise ValueError(f"out of range")

-            test_pad_val(batch['x'], val=0)
\ No newline at end of file
+            _test_pad_val(batch['x'], val=0)
\ No newline at end of file
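The rename above works because pytest collects every module-level function whose name matches its default `test_*` pattern; the old `test_pad_val(tensor, val=0)` helper was therefore collected as a test and failed for lack of a `tensor` fixture. A leading underscore takes it out of the pattern. A tiny self-contained illustration (hypothetical file and names):

    # test_example.py -- sketch of pytest's default collection rule
    # (hypothetical file, not part of fastNLP)
    def _check_padding(values, val=0):
        # leading underscore: not collected, so it may take ordinary arguments
        return all(v == val for v in values)

    def test_padding():
        # collected: the name matches pytest's default `test_*` pattern
        assert _check_padding([0, 0, 0])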
ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) # ds_ratio 为 'pad_to_most' dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) @@ -457,7 +457,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx >= 51: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) # ds_ratio 为 Dict[str, float] ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} @@ -475,7 +475,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx > 39: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) @@ -493,4 +493,4 @@ class TestMixDataLoader: if idx > 18: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) \ No newline at end of file + _test_pad_val(batch['x'], val=0) \ No newline at end of file From 5425095cac8cc1c370182202d3eb00a13be1878a Mon Sep 17 00:00:00 2001 From: Letian Li <73881739+LetianLee@users.noreply.github.com> Date: Wed, 25 May 2022 07:07:37 +0100 Subject: [PATCH 5/6] =?UTF-8?q?[bugfix]=20=E4=BF=AE=E5=A4=8D=20tests/core/?= =?UTF-8?q?controllers/=5Ftest=5Ftrainer=5Fjittor.py=EF=BC=8C=E4=BD=BF?= =?UTF-8?q?=E5=85=B6=E5=8F=AF=E4=BB=A5=E6=AD=A3=E5=B8=B8=E8=BF=90=E8=A1=8C?= =?UTF-8?q?=20(#415)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 Trainer 并不接收 validate_dataloaders 参数,改为 evaluate_dataloaders 即可。 * jittor single driver 支持 cpu 和 gpu 的切换 --- fastNLP/core/drivers/jittor_driver/single_device.py | 9 ++++++--- tests/core/controllers/_test_trainer_jittor.py | 2 +- tests/core/controllers/test_trainer_jittor.py | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 7529aec9..19c4b4c2 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: - import jittor + import jittor as jt __all__ = [ "JittorSingleDriver", @@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver): def setup(self): """ - 使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作 + 支持 cpu 和 gpu 的切换 """ - pass + if self.model_device in ["cpu", None]: + jt.flags.use_cuda = 0 # 使用 cpu + else: + jt.flags.use_cuda = 1 # 使用 cuda diff --git a/tests/core/controllers/_test_trainer_jittor.py b/tests/core/controllers/_test_trainer_jittor.py index bc4b05f0..13ab2e8b 100644 --- a/tests/core/controllers/_test_trainer_jittor.py +++ b/tests/core/controllers/_test_trainer_jittor.py @@ -225,7 +225,7 @@ if __name__ == "__main__": device=[0,1,2,3,4], optimizers=optimizer, train_dataloader=train_dataloader, - validate_dataloaders=val_dataloader, + evaluate_dataloaders=val_dataloader, validate_every=-1, input_mapping=None, output_mapping=None, diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py index c01fd6e6..c84c24f1 100644 --- a/tests/core/controllers/test_trainer_jittor.py +++ b/tests/core/controllers/test_trainer_jittor.py @@ -69,7 +69,8 @@ class TrainJittorConfig: shuffle: bool = True -@pytest.mark.parametrize("driver,device", 
[("jittor", None)]) +@pytest.mark.parametrize("driver", ["jittor"]) +@pytest.mark.parametrize("device", ["cpu", 1]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) @pytest.mark.jittor def test_trainer_jittor( @@ -134,4 +135,5 @@ def test_trainer_jittor( if __name__ == "__main__": # test_trainer_jittor("jittor", None, [RichCallback(100)]) + # test_trainer_jittor("jittor", 1, [RichCallback(100)]) pytest.main(['test_trainer_jittor.py']) # 只运行此模块 From 005b0e055e184fe2a964410279f9e92750ea4ba0 Mon Sep 17 00:00:00 2001 From: Letian Li <73881739+LetianLee@users.noreply.github.com> Date: Wed, 25 May 2022 12:09:09 +0100 Subject: [PATCH 6/6] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20jittor=20driver=20?= =?UTF-8?q?=E5=A4=9A=E5=8D=A1=E8=AE=AD=E7=BB=83=20(#418)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../jittor_driver/initialize_jittor_driver.py | 9 ++- fastNLP/core/drivers/jittor_driver/mpi.py | 80 +++++++++++++++++----- tests/core/controllers/test_trainer_jittor.py | 8 ++- 3 files changed, 76 insertions(+), 21 deletions(-) diff --git a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py index 4b1fcba7..eff8fcfe 100644 --- a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py @@ -1,5 +1,6 @@ from typing import Union, List +from fastNLP.core.drivers.jittor_driver.mpi import JittorMPIDriver from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver from fastNLP.envs.imports import _NEED_IMPORT_JITTOR @@ -29,7 +30,11 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") # TODO 实现更详细的判断 - if driver == "jittor": + if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]: return JittorSingleDriver(model, device, **kwargs) + elif type(device) is int: + return JittorMPIDriver(model, device, **kwargs) + elif type(device) is list: + return JittorMPIDriver(model, device, **kwargs) else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError(f"Device={device}") diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index 4ade3fd1..ee2514e9 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -2,11 +2,14 @@ import os from typing import Optional, Union, Callable, Dict, Tuple from .jittor_driver import JittorDriver +from fastNLP.core.utils import auto_param_call +from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.envs.imports import _NEED_IMPORT_JITTOR -from fastNLP.core.samplers import ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: - import jittor + import jittor as jt __all__ = [ "JittorMPIDriver", @@ -42,7 +45,31 @@ class JittorMPIDriver(JittorDriver): self.outside_mpi = False def setup(self): - pass + self.__fork_with_mpi__() + + def __fork_with_mpi__(self): + import sys + if jt.in_mpi: + # you can mult other process output + if jt.rank != 0: + sys.stdout = open("/dev/null", "w") + return + else: + if self.parallel_device == -1: # device 为 -1,那么默认使用全部的显卡 + raise NotImplementedError(f"Device={self.parallel_device}") + elif type(self.parallel_device) is 
+                num_procs = 1
+                devices = self.parallel_device
+            elif type(self.parallel_device) is list:  # device is a list(int): more than one device should be set this way
+                num_procs = len(self.parallel_device)
+                devices = str(self.parallel_device)[1:-1]
+            else:
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            print(sys.argv)
+            cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
+            print("[RUN CMD]:", cmd)
+            os.system(cmd)
+            exit(0)

     def configure_mpi(self):
         pass
@@ -71,25 +98,46 @@ class JittorMPIDriver(JittorDriver):
     def data_device(self):
         if self.outside_mpi:
             return self._data_device
-        return self.model_device
+        return self.parallel_device
+
+    def step(self):
+        # for optimizer in self.optimizers:
+        #     self.grad_scaler.step(optimizer)
+        #     self.grad_scaler.update()
+        for optimizer in self.optimizers:
+            optimizer.step()
+
+    def backward(self, loss):
+        # self.grad_scaler.scale(loss).backward()
+        for optimizer in self.optimizers:
+            optimizer.backward(loss)
+
+    def zero_grad(self):
+        for optimizer in self.optimizers:
+            optimizer.zero_grad()

     def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
-        pass
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)

     def get_model_call_fn(self, fn: str) -> Tuple:
-        pass
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
-        pass
-
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
+        return dataloader

     def is_global_zero(self):
         return self.global_rank == 0
@@ -107,4 +155,4 @@
         pass

     def is_distributed(self):
-        return True
\ No newline at end of file
+        return True
diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py
index c84c24f1..b6cefdf3 100644
--- a/tests/core/controllers/test_trainer_jittor.py
+++ b/tests/core/controllers/test_trainer_jittor.py
@@ -70,7 +70,7 @@ class TrainJittorConfig:

 @pytest.mark.parametrize("driver", ["jittor"])
-@pytest.mark.parametrize("device", ["cpu", 1])
+@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
 @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
 @pytest.mark.jittor
 def test_trainer_jittor(
@@ -134,6 +134,8 @@ def test_trainer_jittor(

 if __name__ == "__main__":
-    # test_trainer_jittor("jittor", None, [RichCallback(100)])
-    # test_trainer_jittor("jittor", 1, [RichCallback(100)])
+    # test_trainer_jittor("jittor", "cpu", [RichCallback(100)])      # test on CPU
+    # test_trainer_jittor("jittor", "cuda:0", [RichCallback(100)])   # test on a single GPU
+    # test_trainer_jittor("jittor", 1, [RichCallback(100)])          # test on a specific GPU
+    # test_trainer_jittor("jittor", [0, 1], [RichCallback(100)])     # test on multiple GPUs
     pytest.main(['test_trainer_jittor.py'])  # run only this module
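The heart of this final patch is `__fork_with_mpi__`: when the process is not yet running under MPI, it re-executes the same script through `mpirun` with `CUDA_VISIBLE_DEVICES` pinned and then exits, and the re-launched ranks see `jt.in_mpi` as true and proceed with training, with every rank but 0 muted. A stripped-down sketch of that relaunch pattern, detecting MPI through OpenMPI's environment variables rather than jittor's flag (standalone illustration, not fastNLP code):

    # Sketch of the relaunch-under-mpirun pattern, assuming OpenMPI is installed.
    import os
    import sys

    def relaunch_under_mpi(gpu_ids):  # hypothetical helper
        if "OMPI_COMM_WORLD_SIZE" in os.environ:
            # already inside mpirun: carry on, one rank per GPU
            return int(os.environ["OMPI_COMM_WORLD_RANK"])
        devices = ",".join(str(g) for g in gpu_ids)
        cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun",
                        "-np", str(len(gpu_ids)), sys.executable] + sys.argv)
        os.system(cmd)  # re-execute this very script under mpirun
        sys.exit(0)     # the original process was only a launcher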