
Make DataLoader compatible with version 1.6

tags/v1.0.0alpha
yh · 3 years ago
commit 0a5a4fde51
5 changed files with 166 additions and 14 deletions:

  1. fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+12, -8)
  2. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+11, -6)
  3. fastNLP/core/dataloaders/utils.py (+71, -0)
  4. tests/core/dataloaders/test_utils.py (+39, -0)
  5. tests/core/dataloaders/torch_dataloader/test_fdl.py (+33, -0)

fastNLP/core/dataloaders/paddle_dataloader/fdl.py (+12, -8)

@@ -47,6 +47,7 @@ from fastNLP.core.collators.collator import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
 from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
+from ..utils import _match_param
 
 
 class _PaddleDataset(Dataset):
@@ -154,14 +155,17 @@ class PaddleDataLoader(DataLoader):
         else:
             raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
 
-        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
-                                               return_list=return_list, batch_sampler=batch_sampler,
-                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
-                                               collate_fn=collate_fn, num_workers=num_workers,
-                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
-                                               timeout=timeout, worker_init_fn=worker_init_fn,
-                                               persistent_workers=persistent_workers)
-
+        dl_kwargs = _match_param(PaddleDataLoader.__init__, DataLoader.__init__, DataLoader.__name__)
+        if dl_kwargs is None:
+            super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
+                                                   return_list=return_list, batch_sampler=batch_sampler,
+                                                   batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
+                                                   collate_fn=collate_fn, num_workers=num_workers,
+                                                   use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
+                                                   timeout=timeout, worker_init_fn=worker_init_fn,
+                                                   persistent_workers=persistent_workers)
+        else:
+            super().__init__(**dl_kwargs)
         # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
         # if collate_fn is not None:
         #     _collate_fn.add_collator(collate_fn)


fastNLP/core/dataloaders/torch_dataloader/fdl.py (+11, -6)

@@ -11,6 +11,7 @@ from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
+from ..utils import _match_param
 
 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader, Sampler
@@ -96,12 +97,16 @@ class TorchDataLoader(DataLoader):
         else:
             raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
 
-        super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
-                         batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
-                         pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                         multiprocessing_context=multiprocessing_context, generator=generator,
-                         prefetch_factor=prefetch_factor,
-                         persistent_workers=persistent_workers)
+        dl_kwargs = _match_param(TorchDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__)
+        if dl_kwargs is None:
+            super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
+                             batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+                             pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+                             multiprocessing_context=multiprocessing_context, generator=generator,
+                             prefetch_factor=prefetch_factor,
+                             persistent_workers=persistent_workers)
+        else:
+            super().__init__(**dl_kwargs)
 
         self.cur_batch_indices = None

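For context, the pattern both DataLoader subclasses now follow can be shown with a minimal, self-contained sketch. OldStyleLoader and NewStyleLoader below are hypothetical stand-ins (for, say, the DataLoader of torch 1.6 and fastNLP's TorchDataLoader); only the _match_param helper added in this commit is assumed.

from fastNLP.core.dataloaders.utils import _match_param

class OldStyleLoader:
    # stands in for a parent whose __init__ lacks newer options,
    # the way the DataLoader of torch 1.6 lacks prefetch_factor
    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size

class NewStyleLoader(OldStyleLoader):
    def __init__(self, dataset, batch_size=1, prefetch_factor=2):
        # ask _match_param which of our arguments the parent __init__ accepts
        dl_kwargs = _match_param(NewStyleLoader.__init__, OldStyleLoader.__init__,
                                 fn_name=OldStyleLoader.__name__)
        if dl_kwargs is None:
            # matching failed for some reason: fall back to the explicit call
            super().__init__(dataset=dataset, batch_size=batch_size)
        else:
            # prefetch_factor is dropped here (with a warning) because the parent does not accept it
            super().__init__(**dl_kwargs)

NewStyleLoader([1, 2, 3], batch_size=2, prefetch_factor=4)  # logs a warning about prefetch_factor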

fastNLP/core/dataloaders/utils.py (+71, -0)

@@ -1,4 +1,9 @@
 from typing import Callable
+import inspect
+import ast
+
+from ..log import logger
+from ..utils.cache_results import get_func_calls, truncate_start_blanks
 __all__ = [
     "indice_collate_wrapper"
 ]
@@ -25,6 +30,72 @@ def indice_collate_wrapper(func:Callable):
     return _indice_collate_wrapper
 
 
+def _match_param(fun, call_fn:Callable, fn_name:str=None):
+    """
+    The function that calls _match_param (i.e. ``fun``) is about to call ``call_fn``. ``fun`` may support more
+    parameters than ``call_fn`` does: for example, some parameters accepted by :class:`~.fastNLP.TorchDataLoader`
+    are not supported by the DataLoader of torch 1.6, although they are supported by newer torch versions.
+    This helper therefore works out which parameters can be passed to the current version's DataLoader for
+    initialization, and emits a warning for parameters that are not supported but were explicitly set.
+
+    :param fun: the calling function itself
+    :param call_fn: the function that is about to be called
+    :param fn_name: function name used to make warnings and error messages clearer
+    :return:
+    """
+    try:
+        if fn_name is None:
+            try:
+                fn_name = call_fn.__name__
+            except:
+                fn_name = str(call_fn)
+
+        last_frame = inspect.currentframe().f_back
+
+        # collect the default parameter values of the function that called _match_param
+        fun_default_params = {}
+        fun_parameters = inspect.signature(fun)
+        for name, fun_param in fun_parameters.parameters.items():
+            if fun_param.default is not fun_param.empty:
+                fun_default_params[name] = fun_param.default
+
+        # collect the argument values that were actually passed in
+        param_names, args_name, kwargs_name, values = inspect.getargvalues(last_frame)
+        if args_name is not None:
+            raise RuntimeError("Function does not support positional arguments, such as: fun(*args).")
+        kwargs = values.get(kwargs_name, {})
+        for param in param_names:
+            if param not in values:
+                value = fun_default_params.get(param)
+            else:
+                value = values[param]
+            kwargs[param] = value
+
+        # match against the parameters that call_fn actually needs
+        call_fn_parameters = inspect.signature(call_fn)
+        call_fn_kwargs = {}
+        has_kwargs = False
+        for name, param in call_fn_parameters.parameters.items():
+            if name == 'self':
+                continue
+            if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):  # leading positional/keyword parameters
+                call_fn_kwargs[name] = param.default
+            if param.kind == param.VAR_KEYWORD:
+                has_kwargs = True
+
+        # assemble the final set of arguments
+        call_kwargs = {}
+        for name, value in kwargs.items():
+            if name in call_fn_kwargs or has_kwargs:  # accepted by call_fn, or call_fn takes **kwargs, so pass it through
+                call_kwargs[name] = value
+            # not accepted by call_fn, and set to a non-default value
+            elif name not in call_fn_kwargs and name in fun_default_params and fun_default_params[name]!=value:
+                logger.rank_zero_warning(f"Parameter:{name} is not supported for {fn_name}.")
+
+        return call_kwargs
+    except BaseException as e:
+        logger.debug(f"Exception happens when match parameters for {fn_name}: {e}")
+        return None
+
 if __name__ == '__main__':
     def demo(*args, **kwargs):
         pass

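As a quick reference for reviewers, here is a minimal usage sketch of _match_param with plain functions; old_api and new_api are hypothetical names, and the expected warning text follows the format used in the helper above.

from fastNLP.core.dataloaders.utils import _match_param

def old_api(data, batch_size=1):
    return data, batch_size

def new_api(data, batch_size=1, extra_flag=False):
    # inspect new_api's own call frame and keep only the arguments old_api accepts
    kwargs = _match_param(new_api, old_api)
    if kwargs is None:  # matching failed, fall back to an explicit call
        return old_api(data, batch_size=batch_size)
    return old_api(**kwargs)

# extra_flag is set to a non-default value but old_api does not accept it,
# so "Parameter:extra_flag is not supported for old_api." is logged as a warning.
new_api([1, 2, 3], batch_size=4, extra_flag=True)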

tests/core/dataloaders/test_utils.py (+39, -0)

@@ -0,0 +1,39 @@
+import pytest
+from fastNLP.core.dataloaders.utils import _match_param
+from fastNLP import logger
+from tests.helpers.utils import recover_logger, Capturing
+
+
+def demo():
+    pass
+
+
+def test_no_args():
+    def f(*args, a, b, **kwarg):
+        c = 100
+        call_kwargs = _match_param(f, demo)
+    with pytest.raises(RuntimeError):
+        f(a=1, b=2)
+
+    def f(a, *args, b, **kwarg):
+        c = 100
+        call_kwargs = _match_param(f, demo)
+    with pytest.raises(RuntimeError):
+        f(a=1, b=2)
+
+
+@recover_logger
+def test_warning():
+    logger.set_stdout('raw')
+    def f1(a, b):
+        return 1
+
+    def f2(a, b, c=2):
+        kwargs = _match_param(f2, f1)
+        return f1(*kwargs)
+
+    with Capturing() as out:
+        f2(a=1, b=2, c=3)
+    assert 'Parameter:c' in out[0]  # c was passed in, so a warning is expected
+
+    assert f2(1, 2) == 1

tests/core/dataloaders/torch_dataloader/test_fdl.py (+33, -0)

@@ -5,6 +5,9 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core import Trainer
+from pkg_resources import parse_version
+from tests.helpers.utils import Capturing, recover_logger
+from fastNLP import logger
 
 if _NEED_IMPORT_TORCH:
     import torch
@@ -128,3 +131,33 @@ class TestFdl:
         dl = DataLoader(MyDatset(), collate_fn=collate_batch)
         for batch in dl:
             print(batch)
+
+    @recover_logger
+    def test_version_16(self):
+        if parse_version(torch.__version__) >= parse_version('1.7'):
+            pytest.skip("Torch version larger than 1.7")
+        logger.set_stdout()
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        with Capturing() as out:
+            dl = TorchDataLoader(ds, prefetch_factor=3, shuffle=False)
+            for idx, batch in enumerate(dl):
+                assert len(batch['x'])==1
+                assert batch['x'][0].tolist() == ds[idx]['x']
+
+        assert 'Parameter:prefetch_factor' in out[0]
+
+    @recover_logger
+    def test_version_111(self):
+        if parse_version(torch.__version__) <= parse_version('1.7'):
+            pytest.skip("Torch version smaller than 1.7")
+        logger.set_stdout()
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        with Capturing() as out:
+            dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False)
+            for idx, batch in enumerate(dl):
+                assert len(batch['x'])==1
+                assert batch['x'][0].tolist() == ds[idx]['x']
+
+        assert 'Parameter:prefetch_factor' not in out[0]
+


