|
|
@@ -189,11 +189,18 @@ def replace_sampler(dataloader: "DataLoader", sampler): |
|
|
|
# 中寻找; |
|
|
|
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) |
|
|
|
if has_variadic_kwargs: |
|
|
|
# 这里之所以这样写是因为用户自己定制的 Dataloader 中名字一样的参数所设置的默认值可能不同;因此不能直接使用 update 覆盖掉了; |
|
|
|
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): |
|
|
|
if key not in init_params and key != 'self': |
|
|
|
init_params[key] = value |
|
|
|
|
|
|
|
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in init_params} |
|
|
|
# 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; |
|
|
|
non_default_params = {name for name, p in init_params.items() if |
|
|
|
name in instance_attrs and p.default != instance_attrs[name]} |
|
|
|
# add `dataset` as it might have been replaced with `*args` |
|
|
|
non_default_params.add("dataset") |
|
|
|
|
|
|
|
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} |
|
|
|
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) |
|
|
|
|
|
|
|
required_args = { |
|
|
|