From 0a5a4fde51fc69128fb8bd6ca58605ff7f882d01 Mon Sep 17 00:00:00 2001
From: yh
Date: Wed, 18 May 2022 17:13:45 +0800
Subject: [PATCH] Make DataLoader compatible with torch 1.6
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataloaders/paddle_dataloader/fdl.py | 20 +++---
 fastNLP/core/dataloaders/torch_dataloader/fdl.py  | 17 ++++-
 fastNLP/core/dataloaders/utils.py                 | 71 ++++++++++++++++++++++
 tests/core/dataloaders/test_utils.py              | 39 ++++++++++++
 .../core/dataloaders/torch_dataloader/test_fdl.py | 33 ++++++++++
 5 files changed, 166 insertions(+), 14 deletions(-)
 create mode 100644 tests/core/dataloaders/test_utils.py

diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
index 6082fb1a..36f6588b 100644
--- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -47,6 +47,7 @@ from fastNLP.core.collators.collator import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
 from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
+from ..utils import _match_param
 
 
 class _PaddleDataset(Dataset):
@@ -154,14 +155,17 @@ class PaddleDataLoader(DataLoader):
         else:
             raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
 
-        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
-                                               return_list=return_list, batch_sampler=batch_sampler,
-                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
-                                               collate_fn=collate_fn, num_workers=num_workers,
-                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
-                                               timeout=timeout, worker_init_fn=worker_init_fn,
-                                               persistent_workers=persistent_workers)
-
+        dl_kwargs = _match_param(PaddleDataLoader.__init__, DataLoader.__init__, DataLoader.__name__)
+        if dl_kwargs is None:
+            super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
+                                                   return_list=return_list, batch_sampler=batch_sampler,
+                                                   batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
+                                                   collate_fn=collate_fn, num_workers=num_workers,
+                                                   use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
+                                                   timeout=timeout, worker_init_fn=worker_init_fn,
+                                                   persistent_workers=persistent_workers)
+        else:
+            super().__init__(**dl_kwargs)
         # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
         # if collate_fn is not None:
         #     _collate_fn.add_collator(collate_fn)

diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index 065504a6..f5e4af97 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -11,6 +11,7 @@ from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
+from ..utils import _match_param
 
 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader, Sampler
@@ -96,12 +97,16 @@ class TorchDataLoader(DataLoader):
         else:
             raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
 
-        super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
-                         batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
-                         pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                         multiprocessing_context=multiprocessing_context, generator=generator,
-                         prefetch_factor=prefetch_factor,
-                         persistent_workers=persistent_workers)
+        dl_kwargs = _match_param(TorchDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__)
+        if dl_kwargs is None:
+            super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
+                             batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+                             pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+                             multiprocessing_context=multiprocessing_context, generator=generator,
+                             prefetch_factor=prefetch_factor,
+                             persistent_workers=persistent_workers)
+        else:
+            super().__init__(**dl_kwargs)
 
         self.cur_batch_indices = None
 
diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index 495fb6d3..6c6118d9 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -1,4 +1,9 @@
 from typing import Callable
+import inspect
+import ast
+
+from ..log import logger
+from ..utils.cache_results import get_func_calls, truncate_start_blanks
 
 __all__ = [
     "indice_collate_wrapper"
 ]
@@ -25,6 +30,72 @@ def indice_collate_wrapper(func:Callable):
     return _indice_collate_wrapper
 
 
+def _match_param(fun, call_fn:Callable, fn_name:str=None):
+    """
+    The function that calls ``_match_param`` (that is, ``fun``) is about to call ``call_fn``. ``fun`` may accept
+    more parameters than ``call_fn`` does: for example, some parameters of :class:`~fastNLP.TorchDataLoader` are
+    not supported by the ``DataLoader`` of torch 1.6, although newer torch versions do support them. This helper
+    therefore inspects the signature of the currently available ``call_fn``, determines which arguments can be
+    passed to it for initialization, and emits a warning for every parameter that was explicitly set but is not
+    supported.
+
+    :param fun: the calling function itself
+    :param call_fn: the function that ``fun`` is about to call
+    :param fn_name: a readable name for ``call_fn``, used in warning and error messages
+    :return: a dict of keyword arguments suitable for calling ``call_fn``, or ``None`` if matching failed
+    """
+    try:
+        if fn_name is None:
+            try:
+                fn_name = call_fn.__name__
+            except:
+                fn_name = str(call_fn)
+
+        last_frame = inspect.currentframe().f_back
+
+        # collect the default values of the parameters declared by the calling function ( fun )
+        fun_default_params = {}
+        fun_parameters = inspect.signature(fun)
+        for name, fun_param in fun_parameters.parameters.items():
+            if fun_param.default is not fun_param.empty:
+                fun_default_params[name] = fun_param.default
+
+        # collect the argument values that were actually passed to fun
+        param_names, args_name, kwargs_name, values = inspect.getargvalues(last_frame)
+        if args_name is not None:
+            raise RuntimeError("Function does not support positional arguments, such as: fun(*args).")
+        kwargs = values.get(kwargs_name, {})
+        for param in param_names:
+            if param not in values:
+                value = fun_default_params.get(param)
+            else:
+                value = values[param]
+            kwargs[param] = value
+
+        # match against the parameters that call_fn actually accepts
+        call_fn_parameters = inspect.signature(call_fn)
+        call_fn_kwargs = {}
+        has_kwargs = False
+        for name, param in call_fn_parameters.parameters.items():
+            if name == 'self':
+                continue
+            if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):  # named parameters
+                call_fn_kwargs[name] = param.default
+            if param.kind == param.VAR_KEYWORD:
+                has_kwargs = True
+
+        # assemble the final keyword arguments
+        call_kwargs = {}
+        for name, value in kwargs.items():
+            if name in call_fn_kwargs or has_kwargs:  # forward it if call_fn accepts it, directly or via **kwargs
+                call_kwargs[name] = value
+            # not accepted by call_fn, and explicitly set to a non-default value
+            elif name not in call_fn_kwargs and name in fun_default_params and fun_default_params[name] != value:
+                logger.rank_zero_warning(f"Parameter:{name} is not supported for {fn_name}.")
+
+        return call_kwargs
+    except BaseException as e:
+        logger.debug(f"Exception happened when matching parameters for {fn_name}: {e}")
+        return None
+
+
 if __name__ == '__main__':
     def demo(*args, **kwargs):
         pass
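To make the behaviour of _match_param above concrete, here is a minimal, self-contained sketch of the same signature-matching idea. It is illustrative only: the names _accepted_kwargs and legacy_init are hypothetical and appear nowhere in the patch, and the real _match_param additionally reads the caller's frame, fills in defaults, and warns about dropped parameters.

    import inspect

    def _accepted_kwargs(call_fn, kwargs):
        # Keep only the keys that call_fn's signature can accept.
        params = inspect.signature(call_fn).parameters
        # A **kwargs parameter on call_fn accepts every name.
        if any(p.kind is p.VAR_KEYWORD for p in params.values()):
            return dict(kwargs)
        names = {n for n, p in params.items()
                 if n != 'self' and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)}
        return {k: v for k, v in kwargs.items() if k in names}

    def legacy_init(a, b):  # stands in for an older API that has no 'c'
        return a + b

    print(_accepted_kwargs(legacy_init, {'a': 1, 'b': 2, 'c': 3}))  # -> {'a': 1, 'b': 2}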

diff --git a/tests/core/dataloaders/test_utils.py b/tests/core/dataloaders/test_utils.py
new file mode 100644
index 00000000..e5a7cc9e
--- /dev/null
+++ b/tests/core/dataloaders/test_utils.py
@@ -0,0 +1,39 @@
+import pytest
+from fastNLP.core.dataloaders.utils import _match_param
+from fastNLP import logger
+from tests.helpers.utils import recover_logger, Capturing
+
+
+def demo():
+    pass
+
+
+def test_no_args():
+    def f(*args, a, b, **kwarg):
+        c = 100
+        call_kwargs = _match_param(f, demo)
+    with pytest.raises(RuntimeError):
+        f(a=1, b=2)
+
+    def f(a, *args, b, **kwarg):
+        c = 100
+        call_kwargs = _match_param(f, demo)
+    with pytest.raises(RuntimeError):
+        f(a=1, b=2)
+
+
+@recover_logger
+def test_warning():
+    logger.set_stdout('raw')
+    def f1(a, b):
+        return 1
+
+    def f2(a, b, c=2):
+        kwargs = _match_param(f2, f1)
+        return f1(**kwargs)
+
+    with Capturing() as out:
+        f2(a=1, b=2, c=3)
+    assert 'Parameter:c' in out[0]  # a warning is expected: c was explicitly set to a non-default value
+
+    assert f2(1, 2) == 1

diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index ff38b614..6d20754a 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -5,6 +5,9 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core import Trainer
+from pkg_resources import parse_version
+from tests.helpers.utils import Capturing, recover_logger
+from fastNLP import logger
 
 if _NEED_IMPORT_TORCH:
     import torch
@@ -128,3 +131,33 @@ class TestFdl:
         dl = DataLoader(MyDatset(), collate_fn=collate_batch)
         for batch in dl:
             print(batch)
+
+    @recover_logger
+    def test_version_16(self):
+        if parse_version(torch.__version__) >= parse_version('1.7'):
+            pytest.skip("torch version is 1.7 or newer")
+        logger.set_stdout()
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        with Capturing() as out:
+            dl = TorchDataLoader(ds, prefetch_factor=3, shuffle=False)
+            for idx, batch in enumerate(dl):
+                assert len(batch['x']) == 1
+                assert batch['x'][0].tolist() == ds[idx]['x']
+
+        # prefetch_factor is not supported by the torch 1.6 DataLoader, so a warning is expected
+        assert 'Parameter:prefetch_factor' in out[0]
+
+    @recover_logger
+    def test_version_111(self):
+        if parse_version(torch.__version__) <= parse_version('1.7'):
+            pytest.skip("torch version is 1.7 or older")
+        logger.set_stdout()
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        with Capturing() as out:
+            dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False)
+            for idx, batch in enumerate(dl):
+                assert len(batch['x']) == 1
+                assert batch['x'][0].tolist() == ds[idx]['x']
+
+        # newer torch versions accept prefetch_factor, so no warning should appear
+        assert 'Parameter:prefetch_factor' not in out[0]
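Taken together, the pattern that both PaddleDataLoader and TorchDataLoader now follow can be sketched as below, assuming the patched fastNLP is importable. Base and Sub are hypothetical stand-ins (Base plays the role of the version-dependent DataLoader); only _match_param itself comes from the patch:

    from fastNLP.core.dataloaders.utils import _match_param

    class Base:  # stands in for the DataLoader of the installed framework version
        def __init__(self, dataset):
            self.dataset = dataset

    class Sub(Base):
        def __init__(self, dataset, new_option=None):  # 'new_option' is unknown to Base
            dl_kwargs = _match_param(Sub.__init__, Base.__init__, Base.__name__)
            if dl_kwargs is None:
                # introspection failed, so fall back to the explicit call
                super().__init__(dataset=dataset)
            else:
                # 'new_option' was dropped with a warning because this Base does not
                # accept it; the remaining arguments are forwarded unchanged
                super().__init__(**dl_kwargs)

    Sub(dataset=[1, 2, 3], new_option=True)  # warns: Parameter:new_option is not supported for Base.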