From 30e565e5b8f1a3b23eb7bdd28afa093e3e2b427f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 1 Dec 2021 14:26:05 +0800 Subject: [PATCH] fix(traced_module): fix error message GitOrigin-RevId: 3046225e30757d26d1e3423e2e57f46db725f958 --- imperative/python/megengine/traced_module/utils.py | 33 +++++++++++++++------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index d93b658f..8fe9dd87 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -6,6 +6,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import copy +import contextlib import inspect from collections.abc import MutableMapping, MutableSequence from inspect import FullArgSpec @@ -65,8 +66,12 @@ def _convert_kwargs_to_args( arg_specs = ( inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs ) + func_name = argspecs.__qualname__ if isinstance(argspecs, Callable) else "function" assert isinstance(arg_specs, FullArgSpec) arg_specs_args = arg_specs.args + arg_specs_defaults = arg_specs.defaults if arg_specs.defaults else [] + arg_specs_kwonlyargs = arg_specs.kwonlyargs + arg_specs_kwonlydefaults = arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict() if is_bounded: arg_specs_args = arg_specs.args[1:] new_args = [] @@ -76,31 +81,39 @@ def _convert_kwargs_to_args( repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()) raise TypeError( "{} got multiple values for argument {}".format( - func.__qualname__, ", ".join(repeated_arg_name) + func_name, ", ".join(repeated_arg_name) ) ) - if len(new_args) < len(arg_specs.args): + if len(new_args) < len(arg_specs_args): for ind in range(len(new_args), len(arg_specs_args)): arg_name = arg_specs_args[ind] if arg_name in kwargs: new_args.append(kwargs[arg_name]) else: - index = ind - len(arg_specs_args) + len(arg_specs.defaults) - assert index < len(arg_specs.defaults) and index >= 0 - new_args.append(arg_specs.defaults[index]) + index = ind - len(arg_specs_args) + len(arg_specs_defaults) + if index >= len(arg_specs_defaults) or index < 0: + raise TypeError( + "{} missing required positional arguments: {}".format( + func_name, arg_name + ) + ) + new_args.append(arg_specs_defaults[index]) - for kwarg_name in arg_specs.kwonlyargs: + for kwarg_name in arg_specs_kwonlyargs: if kwarg_name in kwargs: new_kwargs[kwarg_name] = kwargs[kwarg_name] else: - assert kwarg_name in arg_specs.kwonlydefaults - new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name] + if kwarg_name not in arg_specs_kwonlydefaults: + raise TypeError("{} missing required keyword-only argument: {}".format( + func_name, kwarg_name + )) + new_kwargs[kwarg_name] = arg_specs_kwonlydefaults[kwarg_name] for k, v in kwargs.items(): - if k not in arg_specs.args and k not in arg_specs.kwonlyargs: + if k not in arg_specs_args and k not in arg_specs_kwonlyargs: if arg_specs.varkw is None: raise TypeError( "{} got an unexpected keyword argument {}".format( - func.__qualname__, k + func_name, k ) ) new_kwargs[k] = v