diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index a874bf3b..13281cfe 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -189,11 +189,18 @@ def replace_sampler(dataloader: "DataLoader", sampler): # 中寻找; has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) if has_variadic_kwargs: + # 这里之所以这样写是因为用户自己定制的 Dataloader 中名字一样的参数所设置的默认值可能不同;因此不能直接使用 update 覆盖掉了; for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): if key not in init_params and key != 'self': init_params[key] = value - reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} + # 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; + non_default_params = {name for name, p in init_params.items() if + name in instance_attrs and p.default != instance_attrs[name]} + # add `dataset` as it might have been replaced with `*args` + non_default_params.add("dataset") + + reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) required_args = {