
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh 3 years ago
commit 7a7d75d100
10 changed files with 118 additions and 57 deletions
  1. fastNLP/core/dataset/dataset.py (+1, -0)
  2. fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py (+7, -2)
  3. fastNLP/core/drivers/jittor_driver/mpi.py (+64, -16)
  4. fastNLP/core/drivers/jittor_driver/single_device.py (+6, -3)
  5. fastNLP/modules/mix_modules/utils.py (+2, -2)
  6. fastNLP/transformers/torch/tokenization_utils_base.py (+1, -1)
  7. tests/core/controllers/_test_trainer_jittor.py (+1, -1)
  8. tests/core/controllers/test_trainer_jittor.py (+6, -2)
  9. tests/core/dataloaders/torch_dataloader/test_fdl.py (+0, -1)
  10. tests/core/dataloaders/torch_dataloader/test_mixdataloader.py (+30, -29)

fastNLP/core/dataset/dataset.py (+1, -0)

@@ -156,6 +156,7 @@ import _pickle as pickle
 from copy import deepcopy
 from typing import Optional, List, Callable, Union, Dict, Any, Mapping
 from types import LambdaType
+from subprocess import DEVNULL
 import sys
 import time



fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py (+7, -2)

@@ -1,5 +1,6 @@
 from typing import Union, List
 
+from fastNLP.core.drivers.jittor_driver.mpi import JittorMPIDriver
 from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver
 from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -29,7 +30,11 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo
         raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
 
     # TODO: implement a more detailed check
     if driver == "jittor":
-        return JittorSingleDriver(model, device, **kwargs)
-    raise NotImplementedError
+        if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]:
+            return JittorSingleDriver(model, device, **kwargs)
+        elif type(device) is int:
+            return JittorMPIDriver(model, device, **kwargs)
+        elif type(device) is list:
+            return JittorMPIDriver(model, device, **kwargs)
+        else:
+            raise NotImplementedError(f"Device={device}")
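To make the dispatch concrete, here is a usage sketch (`MyModel` is a hypothetical jittor.Module; the selection rules are exactly the ones in the hunk above):

    model = MyModel()
    initialize_jittor_driver("jittor", "cpu", model)    # -> JittorSingleDriver
    initialize_jittor_driver("jittor", 0, model)        # -> JittorSingleDriver (0 is in the single-device list)
    initialize_jittor_driver("jittor", 1, model)        # -> JittorMPIDriver, one process on GPU 1
    initialize_jittor_driver("jittor", [0, 1], model)   # -> JittorMPIDriver, one process per GPU
    initialize_jittor_driver("jittor", "tpu", model)    # -> NotImplementedError("Device=tpu")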

fastNLP/core/drivers/jittor_driver/mpi.py (+64, -16)

@@ -2,11 +2,14 @@ import os
 from typing import Optional, Union, Callable, Dict, Tuple
 
 from .jittor_driver import JittorDriver
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
-    import jittor
+    import jittor as jt
 
 __all__ = [
     "JittorMPIDriver",
@@ -42,7 +45,31 @@ class JittorMPIDriver(JittorDriver):
         self.outside_mpi = False
 
     def setup(self):
-        pass
+        self.__fork_with_mpi__()
+
+    def __fork_with_mpi__(self):
+        import sys
+        if jt.in_mpi:
+            # mute the output of every process except rank 0
+            if jt.rank != 0:
+                sys.stdout = open("/dev/null", "w")
+            return
+        else:
+            if self.parallel_device == -1:  # device is -1: use all available GPUs by default
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            elif type(self.parallel_device) is int:  # device is an *int*: train on the GPU whose ``device_id`` is that value
+                num_procs = 1
+                devices = self.parallel_device
+            elif type(self.parallel_device) is list:  # device is a *list(int)*: more than one device must be given this way
+                num_procs = len(self.parallel_device)
+                devices = str(self.parallel_device)[1:-1]
+            else:
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            print(sys.argv)
+            cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
+            print("[RUN CMD]:", cmd)
+            os.system(cmd)
+            exit(0)
 
     def configure_mpi(self):
         pass
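In plain terms: when the process is not yet inside an MPI group, `__fork_with_mpi__` rebuilds the original command line, re-launches it under `mpirun`, and exits; each spawned rank re-enters `setup()` with `jt.in_mpi` set, takes the early-return branch, and proceeds without forking again. A sketch of the command it builds, assuming `parallel_device=[0, 1]` and a script started as `python train.py` (the script name is illustrative):

    devices = str([0, 1])[1:-1]   # -> "0, 1"
    # cmd -> "CUDA_VISIBLE_DEVICES='0, 1' mpirun -np 2 /usr/bin/python train.py"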
@@ -71,25 +98,46 @@ class JittorMPIDriver(JittorDriver):
     def data_device(self):
         if self.outside_mpi:
             return self._data_device
-        return self.model_device
+        return self.parallel_device
+
+    def step(self):
+        # for optimizer in self.optimizers:
+        #     self.grad_scaler.step(optimizer)
+        #     self.grad_scaler.update()
+        for optimizer in self.optimizers:
+            optimizer.step()
+
+    def backward(self, loss):
+        # self.grad_scaler.scale(loss).backward()
+        for optimizer in self.optimizers:
+            optimizer.backward(loss)
+
+    def zero_grad(self):
+        for optimizer in self.optimizers:
+            optimizer.zero_grad()
 
     def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
-        pass
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
 
     def get_model_call_fn(self, fn: str) -> Tuple:
-        pass
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
-        pass
-
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
+        return dataloader
 
     def is_global_zero(self):
         return self.global_rank == 0
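The effect of `model_call` and `get_model_call_fn`, sketched with a hypothetical model (`auto_param_call` matches dict keys against the parameters of the target function; `execute` is jittor's name for the forward method):

    class MyModel(jt.Module):
        def train_step(self, x, y):   # found by get_model_call_fn("train_step")
            ...

    # batch is a dict {"x": ..., "y": ...} and auto param call is enabled:
    #   -> auto_param_call(model.train_step, batch)   # arguments unpacked by name
    # batch is a list/tuple, or auto param call is disabled:
    #   -> model.train_step(batch)
    # model defines neither train_step nor evaluate_step:
    #   -> falls back to calling the model itself via model.execute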
@@ -107,4 +155,4 @@ class JittorMPIDriver(JittorDriver):
         pass
 
     def is_distributed(self):
-        return True
\ No newline at end of file
+        return True

fastNLP/core/drivers/jittor_driver/single_device.py (+6, -3)

@@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
-    import jittor
+    import jittor as jt
 
 __all__ = [
     "JittorSingleDriver",
@@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver):
 
     def setup(self):
         """
-        With a single GPU, jittor's backend handles device placement automatically, so nothing extra is needed
+        Supports switching between cpu and gpu
         """
-        pass
+        if self.model_device in ["cpu", None]:
+            jt.flags.use_cuda = 0   # use cpu
+        else:
+            jt.flags.use_cuda = 1   # use cuda
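`jt.flags.use_cuda` is a process-global switch in jittor, which is why the single-device driver only needs to flip one flag rather than move the model or the data. A minimal sketch:

    import jittor as jt
    jt.flags.use_cuda = 1    # subsequent ops run on CUDA
    x = jt.ones(2, 2)        # allocated on the GPU
    jt.flags.use_cuda = 0    # back to CPU execution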

fastNLP/modules/mix_modules/utils.py (+2, -2)

@@ -86,12 +86,12 @@ def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient:
     if not no_gradient:
         # keep the gradient and keep backpropagation working
         # paddle's stop_gradient behaves as the inverse of torch's requires_grad
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=False)
         hook = paddle_tensor.register_hook(
             lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
         )
     else:
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=True)
 
     paddle_tensor = paddle_to(paddle_tensor, device)
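The added `.cpu()` matters because `.numpy()` only works on CPU tensors; on a CUDA tensor it raises. A minimal sketch:

    t = torch.ones(2, 2, device="cuda")
    # t.detach().numpy()             # TypeError: can't convert cuda:0 device type tensor to numpy
    arr = t.detach().cpu().numpy()   # OK: copy to host memory first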



fastNLP/transformers/torch/tokenization_utils_base.py (+1, -1)

@@ -2179,7 +2179,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         if padding is True:
             if verbose:
                 if max_length is not None and (truncation is False or truncation == "do_not_truncate"):
-                    logger.warn(
+                    logger.warning_once(
                         "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
                         "To pad to max length, use `padding='max_length'`."
                     )
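`logger.warn` is the deprecated alias of `logger.warning` in the standard logging module, and this message can fire on every tokenization call. `warning_once` deduplicates it; one common way to implement such a method (a sketch, not necessarily fastNLP's implementation):

    _seen_warnings = set()

    def warning_once(msg):
        # emit each distinct message at most once per process
        if msg not in _seen_warnings:
            _seen_warnings.add(msg)
            logger.warning(msg)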


tests/core/controllers/_test_trainer_jittor.py (+1, -1)

@@ -225,7 +225,7 @@ if __name__ == "__main__":
         device=[0,1,2,3,4],
         optimizers=optimizer,
         train_dataloader=train_dataloader,
-        validate_dataloaders=val_dataloader,
+        evaluate_dataloaders=val_dataloader,
         validate_every=-1,
         input_mapping=None,
         output_mapping=None,


tests/core/controllers/test_trainer_jittor.py (+6, -2)

@@ -69,7 +69,8 @@ class TrainJittorConfig:
     shuffle: bool = True
 
 
-@pytest.mark.parametrize("driver,device", [("jittor", None)])
+@pytest.mark.parametrize("driver", ["jittor"])
+@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
 @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
 @pytest.mark.jittor
 def test_trainer_jittor(
@@ -133,5 +134,8 @@ def test_trainer_jittor(
 
 
 if __name__ == "__main__":
-    # test_trainer_jittor("jittor", None, [RichCallback(100)])
+    # test_trainer_jittor("jittor", "cpu", [RichCallback(100)])      # test on CPU
+    # test_trainer_jittor("jittor", "cuda:0", [RichCallback(100)])   # test on a single GPU
+    # test_trainer_jittor("jittor", 1, [RichCallback(100)])          # test on a specific GPU
+    # test_trainer_jittor("jittor", [0, 1], [RichCallback(100)])     # test on multiple GPUs
     pytest.main(['test_trainer_jittor.py'])  # run only this module
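Stacked `parametrize` decorators multiply out, so the rewritten test now collects one case per (driver, device, callbacks) combination, three in total (ids shown schematically):

    # test_trainer_jittor[jittor-cpu-...]
    # test_trainer_jittor[jittor-gpu-...]
    # test_trainer_jittor[jittor-cuda:0-...]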

tests/core/dataloaders/torch_dataloader/test_fdl.py (+0, -1)

@@ -147,7 +147,6 @@ class TestFdl:
         assert 'Parameter:prefetch_factor' in out[0]
 
     @recover_logger
-    @pytest.mark.temp
     def test_version_111(self):
         if parse_version(torch.__version__) <= parse_version('1.7'):
             pytest.skip("Torch version smaller than 1.7")


tests/core/dataloaders/test_mixdataloader.py → tests/core/dataloaders/torch_dataloader/test_mixdataloader.py (+30, -29)

@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 
 if _NEED_IMPORT_TORCH:
     import torch
-    from torch.utils.data import default_collate, SequentialSampler, RandomSampler
+    from torch.utils.data import SequentialSampler, RandomSampler
 
 d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})

@@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10]
 d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})
 
 
-def test_pad_val(tensor, val=0):
+def _test_pad_val(tensor, val=0):
     if isinstance(tensor, torch.Tensor):
         tensor = tensor.tolist()
     for item in tensor:
@@ -28,6 +28,7 @@ def _test_pad_val(tensor, val=0):
     return True
 
 
+@pytest.mark.torch
 class TestMixDataLoader:
 
     def test_sequential_init(self):
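The leading underscore is what keeps pytest from collecting the helper: any module-level `test_*` function is treated as a test, and its `tensor` parameter would be looked up as a fixture. Sketch of the failure mode the rename avoids:

    def test_pad_val(tensor, val=0):   # collected as a test by pytest
        ...
    # E  fixture 'tensor' not found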
@@ -44,7 +45,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         # collate_fn = Callable
         def collate_batch(batch):
@@ -73,13 +74,13 @@ class TestMixDataLoader:
         dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
         for idx, batch in enumerate(dl2):
             if idx == 0:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'].shape == torch.Size([16, 4])
             if idx == 1:
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'].shape == torch.Size([16, 3])
             if idx > 1:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'].shape == torch.Size([16, 4])
 
         # sampler is a str
@@ -100,7 +101,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         for idx, batch in enumerate(dl4):
             if idx == 0:
@@ -117,7 +118,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         # sampler is a Dict
         samplers = {'d1': SequentialSampler(d1),
@@ -136,7 +137,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
 
         # ds_ratio = 'truncate_to_least'
         dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
@@ -153,7 +154,7 @@ class TestMixDataLoader:
                 # d3
                 assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx > 2:
                 raise ValueError(f"ds_ratio: 'truncate_to_least' error")

@@ -169,7 +170,7 @@ class TestMixDataLoader:
             if 36 <= idx < 54:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 54:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -186,7 +187,7 @@ class TestMixDataLoader:
             if 4 <= idx < 41:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 41:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -200,7 +201,7 @@ class TestMixDataLoader:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
 
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 19:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -208,7 +209,7 @@ class TestMixDataLoader:
         datasets = {'d1': d1, 'd2': d2, 'd3': d3}
         dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
         for idx, batch in enumerate(dl):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")

@@ -223,7 +224,7 @@ class TestMixDataLoader:
         dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
         for idx, batch in enumerate(dl1):
             assert isinstance(batch['x'], list)
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")

@@ -236,12 +237,12 @@ class TestMixDataLoader:
         # sampler is a str
         dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
         for idx, batch in enumerate(dl3):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
         for idx, batch in enumerate(dl4):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # sampler is a Dict
@@ -250,7 +251,7 @@ class TestMixDataLoader:
                     'd3': RandomSampler(d3)}
         dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
         for idx, batch in enumerate(dl5):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # ds_ratio = 'truncate_to_least'
@@ -332,7 +333,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # collate_fn = Callable
         def collate_batch(batch):
@@ -360,16 +361,16 @@ class TestMixDataLoader:
         dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
         for idx, batch in enumerate(dl1):
             if idx == 0 or idx == 3:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
                 assert batch['x'].shape[1] == 4
             elif idx == 1 or idx == 4:
                 # d2
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
                 assert batch['x'].shape[1] == 3
             elif idx == 2 or 4 < idx <= 20:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape[1] == 4
             if idx > 20:
@@ -391,7 +392,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         for idx, batch in enumerate(dl3):
             if idx == 0 or idx == 3:
                 assert batch['x'].shape[1] == 4
@@ -402,7 +403,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         # sampler is a Dict
         samplers = {'d1': SequentialSampler(d1),
                     'd2': SequentialSampler(d2),
@@ -420,7 +421,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio = 'truncate_to_least'
         dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
@@ -437,7 +438,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 5:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio = 'pad_to_most'
         dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
@@ -456,7 +457,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx >= 51:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio is a Dict[str, float]
         ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
@@ -474,7 +475,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 39:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
         dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
@@ -492,4 +493,4 @@ class TestMixDataLoader:
 
             if idx > 18:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
