Browse Source

修复 torch_driver/utils/replace_sampler 的bug

tags/v1.0.0alpha
YWMditto 3 years ago
parent
commit
8d0602f5de
1 changed files with 8 additions and 1 deletions
  1. +8
    -1
      fastNLP/core/drivers/torch_driver/utils.py

+ 8
- 1
fastNLP/core/drivers/torch_driver/utils.py View File

@@ -189,11 +189,18 @@ def replace_sampler(dataloader: "DataLoader", sampler):
# 中寻找;
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if has_variadic_kwargs:
# 这里之所以这样写是因为用户自己定制的 Dataloader 中名字一样的参数所设置的默认值可能不同;因此不能直接使用 update 覆盖掉了;
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
if key not in init_params and key != 'self':
init_params[key] = value

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params}
# 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置;
non_default_params = {name for name, p in init_params.items() if
name in instance_attrs and p.default != instance_attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_default_params.add("dataset")

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))

required_args = {


Loading…
Cancel
Save