diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
index edcde4ac..9da65112 100644
--- a/fastNLP/core/dataset/dataset.py
+++ b/fastNLP/core/dataset/dataset.py
@@ -156,6 +156,7 @@ import _pickle as pickle
 from copy import deepcopy
 from typing import Optional, List, Callable, Union, Dict, Any, Mapping
 from types import LambdaType
+from subprocess import DEVNULL
 import sys
 import time
diff --git a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py
index 4b1fcba7..eff8fcfe 100644
--- a/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py
+++ b/fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py
@@ -1,5 +1,6 @@
 from typing import Union, List
 
+from fastNLP.core.drivers.jittor_driver.mpi import JittorMPIDriver
 from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver
 from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -29,7 +30,11 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo
         raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
 
     # TODO: implement more fine-grained dispatch
-    if driver == "jittor":
+    if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]:
         return JittorSingleDriver(model, device, **kwargs)
+    elif type(device) is int:
+        return JittorMPIDriver(model, device, **kwargs)
+    elif type(device) is list:
+        return JittorMPIDriver(model, device, **kwargs)
     else:
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError(f"Device={device}")
diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py
index 4ade3fd1..ee2514e9 100644
--- a/fastNLP/core/drivers/jittor_driver/mpi.py
+++ b/fastNLP/core/drivers/jittor_driver/mpi.py
@@ -2,11 +2,14 @@ import os
 from typing import Optional, Union, Callable, Dict, Tuple
 
 from .jittor_driver import JittorDriver
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
-    import jittor
+    import jittor as jt
 
 __all__ = [
     "JittorMPIDriver",
@@ -42,7 +45,31 @@ class JittorMPIDriver(JittorDriver):
         self.outside_mpi = False
 
     def setup(self):
-        pass
+        self.__fork_with_mpi__()
+
+    def __fork_with_mpi__(self):
+        import sys
+        if jt.in_mpi:
+            # mute the stdout of all non-zero-rank processes
+            if jt.rank != 0:
+                sys.stdout = open("/dev/null", "w")
+            return
+        else:
+            if self.parallel_device == -1:  # device is -1: use all available GPUs by default
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            elif type(self.parallel_device) is int:  # device is an *int*: train on the gpu whose device_id equals this value
+                num_procs = 1
+                devices = self.parallel_device
+            elif type(self.parallel_device) is list:  # device is a *list(int)*: more than one device should be specified this way
+                num_procs = len(self.parallel_device)
+                devices = str(self.parallel_device)[1:-1]
+            else:
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            print(sys.argv)
+            cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
+            print("[RUN CMD]:", cmd)
+            os.system(cmd)
+            exit(0)
 
     def configure_mpi(self):
         pass
@@ -71,25 +98,46 @@ class JittorMPIDriver(JittorDriver):
     def data_device(self):
         if self.outside_mpi:
             return self._data_device
-        return self.model_device
+        return self.parallel_device
+
+    def step(self):
+        # for optimizer in self.optimizers:
+        #     self.grad_scaler.step(optimizer)
+        #     self.grad_scaler.update()
+        for optimizer in self.optimizers:
+            optimizer.step()
+
+    def backward(self, loss):
+        # self.grad_scaler.scale(loss).backward()
+        for optimizer in self.optimizers:
+            optimizer.backward(loss)
+
+    def zero_grad(self):
+        for optimizer in self.optimizers:
+            optimizer.zero_grad()
 
     def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
-        pass
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
 
     def get_model_call_fn(self, fn: str) -> Tuple:
-        pass
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
-        pass
-
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
+        return dataloader
 
     def is_global_zero(self):
         return self.global_rank == 0
@@ -107,4 +155,4 @@ class JittorMPIDriver(JittorDriver):
         pass
 
     def is_distributed(self):
-        return True
\ No newline at end of file
+        return True
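A note on the `setup()` change above: when the process is not yet running under MPI, the driver builds an `mpirun` command line and re-executes the current script, so a plain `python train.py` transparently becomes a multi-process launch, and the original parent process exits once the MPI job finishes. Below is a minimal standalone sketch of this re-exec pattern. It is not part of the patch: it detects an existing MPI context via the Open MPI `OMPI_COMM_WORLD_SIZE` environment variable instead of `jt.in_mpi`, and joins device ids with commas rather than the `str(list)[1:-1]` slice used in the diff.

```python
import os
import sys


def fork_with_mpi(parallel_device):
    """Re-launch the current script under mpirun if not already inside MPI.

    Sketch of the pattern in JittorMPIDriver.__fork_with_mpi__;
    `parallel_device` is an int GPU id or a list of GPU ids.
    """
    if "OMPI_COMM_WORLD_SIZE" in os.environ:  # already an MPI worker: keep running
        return
    if isinstance(parallel_device, int):
        num_procs, devices = 1, str(parallel_device)
    elif isinstance(parallel_device, list):
        num_procs, devices = len(parallel_device), ",".join(map(str, parallel_device))
    else:
        raise NotImplementedError(f"Device={parallel_device}")
    # Restrict the visible GPUs, then run the same script once per process.
    cmd = " ".join([f"CUDA_VISIBLE_DEVICES='{devices}'", "mpirun", "-np", str(num_procs),
                    sys.executable] + sys.argv)
    print("[RUN CMD]:", cmd)
    os.system(cmd)  # the parent blocks here until the MPI job finishes...
    sys.exit(0)     # ...then exits so the training body never runs twice
```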
diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
index 7529aec9..19c4b4c2 100644
--- a/fastNLP/core/drivers/jittor_driver/single_device.py
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
-    import jittor
+    import jittor as jt
 
 __all__ = [
     "JittorSingleDriver",
@@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver):
 
     def setup(self):
         """
-        On a single GPU, jittor handles device placement automatically; no extra setup is needed
+        Supports switching between cpu and gpu
         """
-        pass
+        if self.model_device in ["cpu", None]:
+            jt.flags.use_cuda = 0  # use cpu
+        else:
+            jt.flags.use_cuda = 1  # use cuda
diff --git a/fastNLP/modules/mix_modules/utils.py b/fastNLP/modules/mix_modules/utils.py
index 142644f9..e709b0ac 100644
--- a/fastNLP/modules/mix_modules/utils.py
+++ b/fastNLP/modules/mix_modules/utils.py
@@ -86,12 +86,12 @@ def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient:
     if not no_gradient:
         # keep the gradient and allow backpropagation
         # paddle's stop_gradient is the opposite of torch's requires_grad
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=False)
         hook = paddle_tensor.register_hook(
             lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
         )
     else:
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=True)
 
     paddle_tensor = paddle_to(paddle_tensor, device)
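The added `.cpu()` calls fix a real failure mode: `torch.Tensor.numpy()` only works on CPU tensors and raises a TypeError for CUDA tensors, so converting a GPU-resident tensor to paddle crashed before this change. A small illustrative sketch, not part of the patch (assumes both torch and paddle are importable):

```python
import torch
import paddle

# .numpy() requires host memory; .cpu() is a no-op for CPU tensors and copies
# device memory back to the host for CUDA tensors, so this works in both cases.
t = torch.randn(2, 3, device="cuda" if torch.cuda.is_available() else "cpu")
arr = t.detach().cpu().numpy()

# paddle's stop_gradient is inverted relative to torch's requires_grad:
# stop_gradient=False means the tensor participates in autograd.
p = paddle.to_tensor(arr, stop_gradient=False)
```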
diff --git a/fastNLP/transformers/torch/tokenization_utils_base.py b/fastNLP/transformers/torch/tokenization_utils_base.py
index 8ed5a2e2..3a033c96 100644
--- a/fastNLP/transformers/torch/tokenization_utils_base.py
+++ b/fastNLP/transformers/torch/tokenization_utils_base.py
@@ -2179,7 +2179,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         if padding is True:
             if verbose:
                 if max_length is not None and (truncation is False or truncation == "do_not_truncate"):
-                    logger.warn(
+                    logger.warning_once(
                         "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
                         "To pad to max length, use `padding='max_length'`."
                     )
diff --git a/tests/core/controllers/_test_trainer_jittor.py b/tests/core/controllers/_test_trainer_jittor.py
index bc4b05f0..13ab2e8b 100644
--- a/tests/core/controllers/_test_trainer_jittor.py
+++ b/tests/core/controllers/_test_trainer_jittor.py
@@ -225,7 +225,7 @@ if __name__ == "__main__":
         device=[0,1,2,3,4],
         optimizers=optimizer,
         train_dataloader=train_dataloader,
-        validate_dataloaders=val_dataloader,
+        evaluate_dataloaders=val_dataloader,
         validate_every=-1,
         input_mapping=None,
         output_mapping=None,
diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py
index c01fd6e6..b6cefdf3 100644
--- a/tests/core/controllers/test_trainer_jittor.py
+++ b/tests/core/controllers/test_trainer_jittor.py
@@ -69,7 +69,8 @@ class TrainJittorConfig:
     shuffle: bool = True
 
 
-@pytest.mark.parametrize("driver,device", [("jittor", None)])
+@pytest.mark.parametrize("driver", ["jittor"])
+@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
 @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
 @pytest.mark.jittor
 def test_trainer_jittor(
@@ -133,5 +134,8 @@ def test_trainer_jittor(
 
 
 if __name__ == "__main__":
-    # test_trainer_jittor("jittor", None, [RichCallback(100)])
+    # test_trainer_jittor("jittor", "cpu", [RichCallback(100)])      # test CPU
+    # test_trainer_jittor("jittor", "cuda:0", [RichCallback(100)])   # test a single GPU
+    # test_trainer_jittor("jittor", 1, [RichCallback(100)])          # test a specific GPU
+    # test_trainer_jittor("jittor", [0, 1], [RichCallback(100)])     # test multiple GPUs
     pytest.main(['test_trainer_jittor.py'])  # run only this module
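Splitting the single `"driver,device"` parametrization into two stacked decorators changes the test matrix: stacked `@pytest.mark.parametrize` decorators combine as a cross product, so the one jittor driver is now exercised on each of `"cpu"`, `"gpu"`, and `"cuda:0"` (three test items instead of one). A self-contained illustration of the mechanics (the test name here is hypothetical):

```python
import pytest

# Stacked parametrize decorators multiply: 1 driver x 3 devices = 3 test items.
@pytest.mark.parametrize("driver", ["jittor"])
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
def test_matrix(driver, device):
    assert driver == "jittor"
    assert device in {"cpu", "gpu", "cuda:0"}
```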
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index d977e52f..b53790bb 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -147,7 +147,6 @@ class TestFdl:
         assert 'Parameter:prefetch_factor' in out[0]
 
     @recover_logger
-    @pytest.mark.temp
     def test_version_111(self):
         if parse_version(torch.__version__) <= parse_version('1.7'):
             pytest.skip("Torch version smaller than 1.7")
diff --git a/tests/core/dataloaders/test_mixdataloader.py b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
similarity index 93%
rename from tests/core/dataloaders/test_mixdataloader.py
rename to tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
index 35872b39..bf9e3d9e 100644
--- a/tests/core/dataloaders/test_mixdataloader.py
+++ b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py
@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 
 if _NEED_IMPORT_TORCH:
     import torch
-    from torch.utils.data import default_collate, SequentialSampler, RandomSampler
+    from torch.utils.data import SequentialSampler, RandomSampler
 
 
 d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
@@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10]
 d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})
 
 
-def test_pad_val(tensor, val=0):
+def _test_pad_val(tensor, val=0):
     if isinstance(tensor, torch.Tensor):
         tensor = tensor.tolist()
     for item in tensor:
@@ -28,6 +28,7 @@ def test_pad_val(tensor, val=0):
     return True
 
 
+@pytest.mark.torch
 class TestMixDataLoader:
 
     def test_sequential_init(self):
@@ -44,7 +45,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         # collate_fn = Callable
         def collate_batch(batch):
@@ -73,13 +74,13 @@ class TestMixDataLoader:
         dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
         for idx, batch in enumerate(dl2):
             if idx == 0:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'].shape == torch.Size([16, 4])
             if idx == 1:
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'].shape == torch.Size([16, 3])
             if idx > 1:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'].shape == torch.Size([16, 4])
 
         # sampler as a str
@@ -100,7 +101,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         for idx, batch in enumerate(dl4):
             if idx == 0:
@@ -117,7 +118,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         # sampler as a Dict
         samplers = {'d1': SequentialSampler(d1),
@@ -136,7 +137,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         # ds_ratio = 'truncate_to_least'
         dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
@@ -153,7 +154,7 @@ class TestMixDataLoader:
                 # d3
                 assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx > 2:
                 raise ValueError(f"ds_ratio: 'truncate_to_least' error")
@@ -169,7 +170,7 @@ class TestMixDataLoader:
             if 36 <= idx < 54:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 54:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")
@@ -186,7 +187,7 @@ class TestMixDataLoader:
             if 4 <= idx < 41:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 41:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")
@@ -200,7 +201,7 @@ class TestMixDataLoader:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
 
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 19:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")
@@ -208,7 +209,7 @@ class TestMixDataLoader:
         datasets = {'d1': d1, 'd2': d2, 'd3': d3}
         dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
         for idx, batch in enumerate(dl):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
@@ -223,7 +224,7 @@ class TestMixDataLoader:
         dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
         for idx, batch in enumerate(dl1):
             assert isinstance(batch['x'], list)
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
@@ -236,12 +237,12 @@ class TestMixDataLoader:
         # sampler as a str
         dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
         for idx, batch in enumerate(dl3):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
         for idx, batch in enumerate(dl4):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # sampler as a Dict
@@ -250,7 +251,7 @@ class TestMixDataLoader:
             'd3': RandomSampler(d3)}
         dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
         for idx, batch in enumerate(dl5):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # ds_ratio = 'truncate_to_least'
@@ -332,7 +333,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # collate_fn = Callable
         def collate_batch(batch):
@@ -360,16 +361,16 @@ class TestMixDataLoader:
         dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
         for idx, batch in enumerate(dl1):
             if idx == 0 or idx == 3:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
                 assert batch['x'].shape[1] == 4
             elif idx == 1 or idx == 4:
                 # d2
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
                 assert batch['x'].shape[1] == 3
             elif idx == 2 or 4 < idx <= 20:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape[1] == 4
             if idx > 20:
@@ -391,7 +392,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         for idx, batch in enumerate(dl3):
             if idx == 0 or idx == 3:
                 assert batch['x'].shape[1] == 4
@@ -402,7 +403,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         # sampler as a Dict
         samplers = {'d1': SequentialSampler(d1),
                     'd2': SequentialSampler(d2),
@@ -420,7 +421,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio = 'truncate_to_least'
         dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
@@ -437,7 +438,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 5:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio = 'pad_to_most'
         dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
@@ -456,7 +457,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx >= 51:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio as a Dict[str, float]
         ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
@@ -474,7 +475,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 39:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
         dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
@@ -492,4 +493,4 @@ class TestMixDataLoader:
             if idx > 18:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
\ No newline at end of file
+            _test_pad_val(batch['x'], val=0)
\ No newline at end of file
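The `test_pad_val` → `_test_pad_val` rename above is not cosmetic: pytest's default collection rule picks up any module-level function named `test_*` as a test, so the helper was being collected and its `tensor` parameter interpreted as a fixture request, which errors out with "fixture 'tensor' not found". The leading underscore takes it out of collection. A self-contained sketch of the rule (helper and test names here are hypothetical):

```python
# pytest collects module-level functions matching `test_*` by default and
# treats their parameters as fixture requests.

def _test_pad_val(tensor, val=0):  # underscore prefix: ignored by collection
    return all(v == val for row in tensor for v in row)

def test_padding():                # collected and run as a test
    assert _test_pad_val([[0, 0, 0]], val=0)
```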