@@ -47,6 +47,7 @@ from fastNLP.core.collators.collator import Collator | |||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||
from fastNLP.core.dataset import DataSet as FDataSet | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler | |||
from ..utils import _match_param | |||
class _PaddleDataset(Dataset): | |||
@@ -154,14 +155,17 @@ class PaddleDataLoader(DataLoader): | |||
else: | |||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | |||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||
return_list=return_list, batch_sampler=batch_sampler, | |||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||
collate_fn=collate_fn, num_workers=num_workers, | |||
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
persistent_workers=persistent_workers) | |||
dl_kwargs = _match_param(PaddleDataLoader.__init__, DataLoader.__init__, DataLoader.__name__) | |||
if dl_kwargs is None: | |||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||
return_list=return_list, batch_sampler=batch_sampler, | |||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||
collate_fn=collate_fn, num_workers=num_workers, | |||
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
persistent_workers=persistent_workers) | |||
else: | |||
super().__init__(**dl_kwargs) | |||
# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) | |||
# if collate_fn is not None: | |||
# _collate_fn.add_collator(collate_fn) | |||
@@ -11,6 +11,7 @@ from fastNLP.core.collators import Collator | |||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | |||
from ..utils import _match_param | |||
if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader, Sampler | |||
@@ -96,12 +97,16 @@ class TorchDataLoader(DataLoader): | |||
else: | |||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | |||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers) | |||
dl_kwargs = _match_param(TorchDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__) | |||
if dl_kwargs is None: | |||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers) | |||
else: | |||
super().__init__(**dl_kwargs) | |||
self.cur_batch_indices = None | |||
@@ -1,4 +1,9 @@ | |||
from typing import Callable | |||
import inspect | |||
import ast | |||
from ..log import logger | |||
from ..utils.cache_results import get_func_calls, truncate_start_blanks | |||
__all__ = [ | |||
"indice_collate_wrapper" | |||
] | |||
@@ -25,6 +30,72 @@ def indice_collate_wrapper(func:Callable): | |||
return _indice_collate_wrapper | |||
def _match_param(fun, call_fn:Callable, fn_name:str=None): | |||
""" | |||
在调用 _match_param 的函数(就是 fun )中会调用 call_fn 这个函数。由于 fun 中支持的函数比 call_fn 更多,例如低版本的 | |||
:class:`~.fastNLP.TorchDataLoader` 中支持的参数,在torch 1.6 版本的 DataLoader 就不支持,但在高版本的 torch 中是支持的 | |||
因此,这里需要根据当前版本的 DataLoader 判定出适合传入 DataLoader 进行初始化的参数,并且在不支持但又被设置的参数上进行 | |||
warning 。 | |||
:param fun: 调用函数本身 | |||
:param call_fn: | |||
:param fn_name: 方便报错的用的函数 | |||
:return: | |||
""" | |||
try: | |||
if fn_name is None: | |||
try: | |||
fn_name = call_fn.__name__ | |||
except: | |||
fn_name = str(call_fn) | |||
last_frame = inspect.currentframe().f_back | |||
# 调用 _match_param 的函数名称,获取默认的参数值 | |||
fun_default_params = {} | |||
fun_parameters = inspect.signature(fun) | |||
for name, fun_param in fun_parameters.parameters.items(): | |||
if fun_param.default is not fun_param.empty: | |||
fun_default_params[name] = fun_param.default | |||
# 获取实际传入的参数值 | |||
param_names, args_name, kwargs_name, values = inspect.getargvalues(last_frame) | |||
if args_name is not None: | |||
raise RuntimeError("Function does not support positional arguments, such as: fun(*args).") | |||
kwargs = values.get(kwargs_name, {}) | |||
for param in param_names: | |||
if param not in values: | |||
value = fun_default_params.get(param) | |||
else: | |||
value = values[param] | |||
kwargs[param] = value | |||
# 根据需要实际需要调用的 call_fn 的参数进行匹配 | |||
call_fn_parameters = inspect.signature(call_fn) | |||
call_fn_kwargs = {} | |||
has_kwargs = False | |||
for name, param in call_fn_parameters.parameters.items(): | |||
if name == 'self': | |||
continue | |||
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY): # 最前面的 args | |||
call_fn_kwargs[name] = param.default | |||
if param.kind == param.VAR_KEYWORD: | |||
has_kwargs = True | |||
# 组装得到最终的参数 | |||
call_kwargs = {} | |||
for name, value in kwargs.items(): | |||
if name in call_fn_kwargs or has_kwargs: # 如果存在在里面,或者包含了 kwargs 就直接运行 | |||
call_kwargs[name] = value | |||
# 如果不在需要调用的函数里面,同时又是非默认值 | |||
elif name not in call_fn_kwargs and name in fun_default_params and fun_default_params[name]!=value: | |||
logger.rank_zero_warning(f"Parameter:{name} is not supported for {fn_name}.") | |||
return call_kwargs | |||
except BaseException as e: | |||
logger.debug(f"Exception happens when match parameters for {fn_name}: {e}") | |||
return None | |||
if __name__ == '__main__': | |||
def demo(*args, **kwargs): | |||
pass | |||
@@ -0,0 +1,39 @@ | |||
import pytest | |||
from fastNLP.core.dataloaders.utils import _match_param | |||
from fastNLP import logger | |||
from tests.helpers.utils import recover_logger, Capturing | |||
def demo(): | |||
pass | |||
def test_no_args(): | |||
def f(*args, a, b, **kwarg): | |||
c = 100 | |||
call_kwargs = _match_param(f, demo) | |||
with pytest.raises(RuntimeError): | |||
f(a=1, b=2) | |||
def f(a, *args, b, **kwarg): | |||
c = 100 | |||
call_kwargs = _match_param(f, demo) | |||
with pytest.raises(RuntimeError): | |||
f(a=1, b=2) | |||
@recover_logger | |||
def test_warning(): | |||
logger.set_stdout('raw') | |||
def f1(a, b): | |||
return 1 | |||
def f2(a, b, c=2): | |||
kwargs = _match_param(f2, f1) | |||
return f1(*kwargs) | |||
with Capturing() as out: | |||
f2(a=1, b=2, c=3) | |||
assert 'Parameter:c' in out[0] # 传入了需要 warning | |||
assert f2(1, 2) == 1 |
@@ -5,6 +5,9 @@ from fastNLP.core.dataset import DataSet | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core import Trainer | |||
from pkg_resources import parse_version | |||
from tests.helpers.utils import Capturing, recover_logger | |||
from fastNLP import logger | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
@@ -128,3 +131,33 @@ class TestFdl: | |||
dl = DataLoader(MyDatset(), collate_fn=collate_batch) | |||
for batch in dl: | |||
print(batch) | |||
@recover_logger | |||
def test_version_16(self): | |||
if parse_version(torch.__version__) >= parse_version('1.7'): | |||
pytest.skip("Torch version larger than 1.7") | |||
logger.set_stdout() | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
with Capturing() as out: | |||
dl = TorchDataLoader(ds, prefetch_factor=3, shuffle=False) | |||
for idx, batch in enumerate(dl): | |||
assert len(batch['x'])==1 | |||
assert batch['x'][0].tolist() == ds[idx]['x'] | |||
assert 'Parameter:prefetch_factor' in out[0] | |||
@recover_logger | |||
def test_version_111(self): | |||
if parse_version(torch.__version__) <= parse_version('1.7'): | |||
pytest.skip("Torch version smaller than 1.7") | |||
logger.set_stdout() | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
with Capturing() as out: | |||
dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False) | |||
for idx, batch in enumerate(dl): | |||
assert len(batch['x'])==1 | |||
assert batch['x'][0].tolist() == ds[idx]['x'] | |||
assert 'Parameter:prefetch_factor' not in out[0] | |||