From 8d0602f5ded9b4c8c28b86c0e715ef8e5ea67c9f Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sun, 5 Jun 2022 20:07:36 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20torch=5Fdriver/utils/repla?= =?UTF-8?q?ce=5Fsampler=20=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index a874bf3b..13281cfe 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -189,11 +189,18 @@ def replace_sampler(dataloader: "DataLoader", sampler): # 中寻找; has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) if has_variadic_kwargs: + # 这里之所以这样写是因为用户自己定制的 Dataloader 中名字一样的参数所设置的默认值可能不同;因此不能直接使用 update 覆盖掉了; for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): if key not in init_params and key != 'self': init_params[key] = value - reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} + # 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; + non_default_params = {name for name, p in init_params.items() if + name in instance_attrs and p.default != instance_attrs[name]} + # add `dataset` as it might have been replaced with `*args` + non_default_params.add("dataset") + + reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) required_args = {