
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh committed 3 years ago · parent commit 7a7d75d100
10 changed files with 118 additions and 57 deletions

  1. +1  -0   fastNLP/core/dataset/dataset.py
  2. +7  -2   fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py
  3. +64 -16  fastNLP/core/drivers/jittor_driver/mpi.py
  4. +6  -3   fastNLP/core/drivers/jittor_driver/single_device.py
  5. +2  -2   fastNLP/modules/mix_modules/utils.py
  6. +1  -1   fastNLP/transformers/torch/tokenization_utils_base.py
  7. +1  -1   tests/core/controllers/_test_trainer_jittor.py
  8. +6  -2   tests/core/controllers/test_trainer_jittor.py
  9. +0  -1   tests/core/dataloaders/torch_dataloader/test_fdl.py
 10. +30 -29  tests/core/dataloaders/torch_dataloader/test_mixdataloader.py

+1 -0   fastNLP/core/dataset/dataset.py

@@ -156,6 +156,7 @@ import _pickle as pickle
 from copy import deepcopy
 from typing import Optional, List, Callable, Union, Dict, Any, Mapping
 from types import LambdaType
+from subprocess import DEVNULL
 import sys
 import time




+7 -2   fastNLP/core/drivers/jittor_driver/initialize_jittor_driver.py

@@ -1,5 +1,6 @@
 from typing import Union, List
 
+from fastNLP.core.drivers.jittor_driver.mpi import JittorMPIDriver
 from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver
 from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -29,7 +30,11 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo
         raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
 
     # TODO: implement more fine-grained checks
-    if driver == "jittor":
+    if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]:
         return JittorSingleDriver(model, device, **kwargs)
+    elif type(device) is int:
+        return JittorMPIDriver(model, device, **kwargs)
+    elif type(device) is list:
+        return JittorMPIDriver(model, device, **kwargs)
     else:
-        raise NotImplementedError
+        raise NotImplementedError(f"Device={device}")
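
After this change, the device value alone selects the driver class: the cpu/gpu strings, gpu 0, and None stay on JittorSingleDriver, while a bare int or a list of ints routes to the new JittorMPIDriver. A minimal usage sketch, assuming the truncated third parameter in the signature above is the model (illustrative calls, not part of the commit):

    # single-process driver: generic cpu/gpu strings, gpu 0, or None
    driver = initialize_jittor_driver("jittor", "cpu", model)     # -> JittorSingleDriver
    driver = initialize_jittor_driver("jittor", 0, model)         # -> JittorSingleDriver

    # MPI driver: a specific non-zero gpu id, or several gpu ids
    driver = initialize_jittor_driver("jittor", 1, model)         # -> JittorMPIDriver
    driver = initialize_jittor_driver("jittor", [0, 1], model)    # -> JittorMPIDriver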

+64 -16  fastNLP/core/drivers/jittor_driver/mpi.py

@@ -2,11 +2,14 @@ import os
 from typing import Optional, Union, Callable, Dict, Tuple
 
 from .jittor_driver import JittorDriver
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
-    import jittor
+    import jittor as jt
 
 __all__ = [
     "JittorMPIDriver",
@@ -42,7 +45,31 @@ class JittorMPIDriver(JittorDriver):
         self.outside_mpi = False
 
     def setup(self):
-        pass
+        self.__fork_with_mpi__()
+
+    def __fork_with_mpi__(self):
+        import sys
+        if jt.in_mpi:
+            # mute the output of every process except rank 0
+            if jt.rank != 0:
+                sys.stdout = open("/dev/null", "w")
+            return
+        else:
+            if self.parallel_device == -1:  # device is -1: use all GPUs by default
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            elif type(self.parallel_device) is int:  # device is an *int*: train on the gpu with that device_id
+                num_procs = 1
+                devices = self.parallel_device
+            elif type(self.parallel_device) is list:  # device is a *list(int)*: more than one device should be set this way
+                num_procs = len(self.parallel_device)
+                devices = str(self.parallel_device)[1:-1]
+            else:
+                raise NotImplementedError(f"Device={self.parallel_device}")
+            print(sys.argv)
+            cmd = " ".join(["CUDA_VISIBLE_DEVICES='%s'" % devices, "mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
+            print("[RUN CMD]:", cmd)
+            os.system(cmd)
+            exit(0)
 
     def configure_mpi(self):
         pass
@@ -71,25 +98,46 @@ class JittorMPIDriver(JittorDriver):
     def data_device(self):
         if self.outside_mpi:
             return self._data_device
-        return self.model_device
+        return self.parallel_device
+
+    def step(self):
+        # for optimizer in self.optimizers:
+        #     self.grad_scaler.step(optimizer)
+        #     self.grad_scaler.update()
+        for optimizer in self.optimizers:
+            optimizer.step()
+
+    def backward(self, loss):
+        # self.grad_scaler.scale(loss).backward()
+        for optimizer in self.optimizers:
+            optimizer.backward(loss)
+
+    def zero_grad(self):
+        for optimizer in self.optimizers:
+            optimizer.zero_grad()
 
     def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
-        pass
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
+        else:
+            return fn(batch)
 
     def get_model_call_fn(self, fn: str) -> Tuple:
-        pass
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
+        else:
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
-        pass
-
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
+        return dataloader
 
     def is_global_zero(self):
         return self.global_rank == 0
@@ -107,4 +155,4 @@ class JittorMPIDriver(JittorDriver):
         pass
 
     def is_distributed(self):
-        return True
+        return True
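
setup() now re-launches the current script under mpirun when it is not already inside an MPI job, then exits the parent process; once inside MPI, every rank except 0 has its stdout silenced. A standalone sketch of the same relaunch pattern (the helper name relaunch_under_mpi and the use of subprocess instead of os.system are illustrative, not part of the commit):

    import os
    import subprocess
    import sys

    def relaunch_under_mpi(gpu_ids):
        """Re-exec the current script under mpirun, one process per GPU."""
        env = dict(os.environ)
        # str([0, 1])[1:-1] in the commit yields "0, 1"; build the list explicitly here
        env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids)
        cmd = ["mpirun", "-np", str(len(gpu_ids)), sys.executable] + sys.argv
        print("[RUN CMD]:", " ".join(cmd))
        ret = subprocess.call(cmd, env=env)   # block until the MPI children finish
        sys.exit(ret)

Passing CUDA_VISIBLE_DEVICES through env avoids the shell quoting that the os.system command string in the commit has to handle itself.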

+6 -3   fastNLP/core/drivers/jittor_driver/single_device.py

@@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
-    import jittor
+    import jittor as jt
 
 __all__ = [
     "JittorSingleDriver",
@@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver):
 
     def setup(self):
         """
-        On a single GPU, jittor schedules the device automatically at the backend; no extra steps are needed
+        Supports switching between cpu and gpu
         """
-        pass
+        if self.model_device in ["cpu", None]:
+            jt.flags.use_cuda = 0   # use cpu
+        else:
+            jt.flags.use_cuda = 1   # use cuda
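
jt.flags.use_cuda is Jittor's global cpu/cuda switch, so the driver only needs to flip one flag in setup(). A quick standalone sketch (assuming a CUDA-enabled Jittor build):

    import jittor as jt

    jt.flags.use_cuda = 0    # run on cpu
    x = jt.randn(2, 3)       # computed on cpu

    jt.flags.use_cuda = 1    # switch to cuda
    y = jt.randn(2, 3)       # subsequent ops run on the gpu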

+2 -2   fastNLP/modules/mix_modules/utils.py

@@ -86,12 +86,12 @@ def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient:
     if not no_gradient:
         # keep the gradient and keep backpropagation working
        # paddle's stop_gradient behaves opposite to torch's requires_grad
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=False)
         hook = paddle_tensor.register_hook(
             lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
         )
     else:
-        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True)
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().cpu().numpy(), stop_gradient=True)
 
     paddle_tensor = paddle_to(paddle_tensor, device)
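
The added .cpu() matters because torch cannot convert a CUDA tensor straight to numpy; it must be copied to host memory first, so the old code failed whenever torch_tensor lived on a GPU. A minimal reproduction of the bug this fixes (assuming torch with CUDA available):

    import torch

    t = torch.randn(2, 2, device="cuda", requires_grad=True)
    # t.detach().numpy()             # TypeError: can't convert cuda:0 device type tensor to numpy
    arr = t.detach().cpu().numpy()   # detach from the graph, copy to host, then convert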




+1 -1   fastNLP/transformers/torch/tokenization_utils_base.py

@@ -2179,7 +2179,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         if padding is True:
             if verbose:
                 if max_length is not None and (truncation is False or truncation == "do_not_truncate"):
-                    logger.warn(
+                    logger.warning_once(
                         "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
                         "To pad to max length, use `padding='max_length'`."
                     )
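
logger.warn is deprecated in favor of logger.warning, and the warning_once variant keeps this message, which can otherwise fire on every tokenizer call, from flooding the log. A sketch of how such a helper is commonly built (illustrative; fastNLP's actual logger implementation may differ):

    import functools
    import logging

    logger = logging.getLogger("fastNLP")

    @functools.lru_cache(maxsize=None)
    def warning_once(msg: str) -> None:
        """Emit each distinct warning message only once, cached on the message text."""
        logger.warning(msg)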


+1 -1   tests/core/controllers/_test_trainer_jittor.py

@@ -225,7 +225,7 @@ if __name__ == "__main__":
         device=[0,1,2,3,4],
         optimizers=optimizer,
         train_dataloader=train_dataloader,
-        validate_dataloaders=val_dataloader,
+        evaluate_dataloaders=val_dataloader,
         validate_every=-1,
         input_mapping=None,
         output_mapping=None,


+6 -2   tests/core/controllers/test_trainer_jittor.py

@@ -69,7 +69,8 @@ class TrainJittorConfig:
     shuffle: bool = True
 
 
-@pytest.mark.parametrize("driver,device", [("jittor", None)])
+@pytest.mark.parametrize("driver", ["jittor"])
+@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
 @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
 @pytest.mark.jittor
 def test_trainer_jittor(
@@ -133,5 +134,8 @@ def test_trainer_jittor(
 
 
 if __name__ == "__main__":
-    # test_trainer_jittor("jittor", None, [RichCallback(100)])
+    # test_trainer_jittor("jittor", "cpu", [RichCallback(100)])       # test CPU
+    # test_trainer_jittor("jittor", "cuda:0", [RichCallback(100)])    # test a single GPU
+    # test_trainer_jittor("jittor", 1, [RichCallback(100)])           # test a specific GPU
+    # test_trainer_jittor("jittor", [0, 1], [RichCallback(100)])      # test multiple GPUs
     pytest.main(['test_trainer_jittor.py'])  # run only this module
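
Stacked parametrize decorators make pytest take the cross product, so the single (jittor, None) case becomes three driver/device combinations. A minimal illustration of the mechanism:

    import pytest

    @pytest.mark.parametrize("driver", ["jittor"])
    @pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
    def test_combinations(driver, device):
        # pytest generates 1 x 3 = 3 cases:
        # (jittor, cpu), (jittor, gpu), (jittor, cuda:0)
        assert driver == "jittor"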

+0 -1   tests/core/dataloaders/torch_dataloader/test_fdl.py

@@ -147,7 +147,6 @@ class TestFdl:
         assert 'Parameter:prefetch_factor' in out[0]
 
     @recover_logger
-    @pytest.mark.temp
     def test_version_111(self):
         if parse_version(torch.__version__) <= parse_version('1.7'):
             pytest.skip("Torch version smaller than 1.7")


+30 -29  tests/core/dataloaders/test_mixdataloader.py → tests/core/dataloaders/torch_dataloader/test_mixdataloader.py (renamed)

@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 
 if _NEED_IMPORT_TORCH:
     import torch
-    from torch.utils.data import default_collate, SequentialSampler, RandomSampler
+    from torch.utils.data import SequentialSampler, RandomSampler
 
 d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
@@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10]
 d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})
 
 
-def test_pad_val(tensor, val=0):
+def _test_pad_val(tensor, val=0):
     if isinstance(tensor, torch.Tensor):
         tensor = tensor.tolist()
     for item in tensor:
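
The rename to _test_pad_val is not cosmetic: pytest collects any module-level test_* function as a test, so the helper was being picked up on its own and erroring because its tensor parameter looks like a missing fixture. The leading underscore takes it out of collection. Illustrative sketch of the failure mode (the error text is pytest's usual message, not from the commit):

    # collected by pytest (name matches test_*); fails with
    #   ERROR ...::test_pad_val - fixture 'tensor' not found
    def test_pad_val(tensor, val=0): ...

    # not collected: the leading underscore breaks the test_* pattern
    def _test_pad_val(tensor, val=0): ...
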
@@ -28,6 +28,7 @@ def test_pad_val(tensor, val=0):
     return True
 
 
+@pytest.mark.torch
 class TestMixDataLoader:
 
     def test_sequential_init(self):
@@ -44,7 +45,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-                assert test_pad_val(batch['x'], val=0)
+                assert _test_pad_val(batch['x'], val=0)
 
         # collate_fn = Callable
         def collate_batch(batch):
@@ -73,13 +74,13 @@ class TestMixDataLoader:
         dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
         for idx, batch in enumerate(dl2):
             if idx == 0:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'].shape == torch.Size([16, 4])
             if idx == 1:
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'].shape == torch.Size([16, 3])
             if idx > 1:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'].shape == torch.Size([16, 4])
 
         # sampler as a str
@@ -100,7 +101,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-                assert test_pad_val(batch['x'], val=0)
+                assert _test_pad_val(batch['x'], val=0)
 
         for idx, batch in enumerate(dl4):
             if idx == 0:
@@ -117,7 +118,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-                assert test_pad_val(batch['x'], val=0)
+                assert _test_pad_val(batch['x'], val=0)
 
         # sampler as a Dict
         samplers = {'d1': SequentialSampler(d1),
@@ -136,7 +137,7 @@ class TestMixDataLoader:
             if idx > 1:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-                assert test_pad_val(batch['x'], val=0)
+                assert _test_pad_val(batch['x'], val=0)
 
         # ds_ratio as 'truncate_to_least'
         dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
@@ -153,7 +154,7 @@ class TestMixDataLoader:
                 # d3
                 assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape == torch.Size([16, 4])
-                assert test_pad_val(batch['x'], val=0)
+                assert _test_pad_val(batch['x'], val=0)
             if idx > 2:
                 raise ValueError(f"ds_ratio: 'truncate_to_least' error")
@@ -169,7 +170,7 @@ class TestMixDataLoader:
             if 36 <= idx < 54:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-                assert test_pad_val(batch['x'], val=0)
+                assert _test_pad_val(batch['x'], val=0)
             if idx >= 54:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")
@@ -186,7 +187,7 @@ class TestMixDataLoader:
             if 4 <= idx < 41:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
-                assert test_pad_val(batch['x'], val=0)
+                assert _test_pad_val(batch['x'], val=0)
             if idx >= 41:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")
@@ -200,7 +201,7 @@ class TestMixDataLoader:
                 # d3
                 assert batch['x'].shape == torch.Size([16, 4])
 
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 19:
                 raise ValueError(f"ds_ratio: 'pad_to_most' error")
@@ -208,7 +209,7 @@ class TestMixDataLoader:
         datasets = {'d1': d1, 'd2': d2, 'd3': d3}
         dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
         for idx, batch in enumerate(dl):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
@@ -223,7 +224,7 @@ class TestMixDataLoader:
         dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
         for idx, batch in enumerate(dl1):
             assert isinstance(batch['x'], list)
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
@@ -236,12 +237,12 @@ class TestMixDataLoader:
         # sampler as a str
         dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
         for idx, batch in enumerate(dl3):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
         for idx, batch in enumerate(dl4):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # sampler as a Dict
@@ -250,7 +251,7 @@ class TestMixDataLoader:
                     'd3': RandomSampler(d3)}
         dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
         for idx, batch in enumerate(dl5):
-            assert test_pad_val(batch['x'], val=0)
+            assert _test_pad_val(batch['x'], val=0)
             if idx >= 22:
                 raise ValueError(f"out of range")
         # ds_ratio as 'truncate_to_least'
@@ -332,7 +333,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # collate_fn = Callable
         def collate_batch(batch):
@@ -360,16 +361,16 @@ class TestMixDataLoader:
         dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
         for idx, batch in enumerate(dl1):
             if idx == 0 or idx == 3:
-                assert test_pad_val(batch['x'], val=-1)
+                assert _test_pad_val(batch['x'], val=-1)
                 assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
                 assert batch['x'].shape[1] == 4
             elif idx == 1 or idx == 4:
                 # d2
-                assert test_pad_val(batch['x'], val=-2)
+                assert _test_pad_val(batch['x'], val=-2)
                 assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
                 assert batch['x'].shape[1] == 3
             elif idx == 2 or 4 < idx <= 20:
-                assert test_pad_val(batch['x'], val=-3)
+                assert _test_pad_val(batch['x'], val=-3)
                 assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
                 assert batch['x'].shape[1] == 4
             if idx > 20:
@@ -391,7 +392,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         for idx, batch in enumerate(dl3):
             if idx == 0 or idx == 3:
                 assert batch['x'].shape[1] == 4
@@ -402,7 +403,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
         # sampler as a Dict
         samplers = {'d1': SequentialSampler(d1),
                     'd2': SequentialSampler(d2),
@@ -420,7 +421,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 20:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio as 'truncate_to_least'
         dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
@@ -437,7 +438,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 5:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio as 'pad_to_most'
         dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
@@ -456,7 +457,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx >= 51:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         # ds_ratio as a Dict[str, float]
         ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
@@ -474,7 +475,7 @@ class TestMixDataLoader:
                 assert batch['x'].shape[1] == 4
             if idx > 39:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
 
         ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
         dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
@@ -492,4 +493,4 @@ class TestMixDataLoader:
 
             if idx > 18:
                 raise ValueError(f"out of range")
-            test_pad_val(batch['x'], val=0)
+            _test_pad_val(batch['x'], val=0)
