Browse Source

fix(traced_module): fix error message

GitOrigin-RevId: 3046225e30
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
30e565e5b8
1 changed files with 23 additions and 10 deletions
  1. +23
    -10
      imperative/python/megengine/traced_module/utils.py

+ 23
- 10
imperative/python/megengine/traced_module/utils.py View File

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy import copy
import contextlib
import inspect import inspect
from collections.abc import MutableMapping, MutableSequence from collections.abc import MutableMapping, MutableSequence
from inspect import FullArgSpec from inspect import FullArgSpec
@@ -65,8 +66,12 @@ def _convert_kwargs_to_args(
arg_specs = ( arg_specs = (
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
) )
func_name = argspecs.__qualname__ if isinstance(argspecs, Callable) else "function"
assert isinstance(arg_specs, FullArgSpec) assert isinstance(arg_specs, FullArgSpec)
arg_specs_args = arg_specs.args arg_specs_args = arg_specs.args
arg_specs_defaults = arg_specs.defaults if arg_specs.defaults else []
arg_specs_kwonlyargs = arg_specs.kwonlyargs
arg_specs_kwonlydefaults = arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict()
if is_bounded: if is_bounded:
arg_specs_args = arg_specs.args[1:] arg_specs_args = arg_specs.args[1:]
new_args = [] new_args = []
@@ -76,31 +81,39 @@ def _convert_kwargs_to_args(
repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()) repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys())
raise TypeError( raise TypeError(
"{} got multiple values for argument {}".format( "{} got multiple values for argument {}".format(
func.__qualname__, ", ".join(repeated_arg_name)
func_name, ", ".join(repeated_arg_name)
) )
) )
if len(new_args) < len(arg_specs.args):
if len(new_args) < len(arg_specs_args):
for ind in range(len(new_args), len(arg_specs_args)): for ind in range(len(new_args), len(arg_specs_args)):
arg_name = arg_specs_args[ind] arg_name = arg_specs_args[ind]
if arg_name in kwargs: if arg_name in kwargs:
new_args.append(kwargs[arg_name]) new_args.append(kwargs[arg_name])
else: else:
index = ind - len(arg_specs_args) + len(arg_specs.defaults)
assert index < len(arg_specs.defaults) and index >= 0
new_args.append(arg_specs.defaults[index])
index = ind - len(arg_specs_args) + len(arg_specs_defaults)
if index >= len(arg_specs_defaults) or index < 0:
raise TypeError(
"{} missing required positional arguments: {}".format(
func_name, arg_name
)
)
new_args.append(arg_specs_defaults[index])


for kwarg_name in arg_specs.kwonlyargs:
for kwarg_name in arg_specs_kwonlyargs:
if kwarg_name in kwargs: if kwarg_name in kwargs:
new_kwargs[kwarg_name] = kwargs[kwarg_name] new_kwargs[kwarg_name] = kwargs[kwarg_name]
else: else:
assert kwarg_name in arg_specs.kwonlydefaults
new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name]
if kwarg_name not in arg_specs_kwonlydefaults:
raise TypeError("{} missing required keyword-only argument: {}".format(
func_name, kwarg_name
))
new_kwargs[kwarg_name] = arg_specs_kwonlydefaults[kwarg_name]
for k, v in kwargs.items(): for k, v in kwargs.items():
if k not in arg_specs.args and k not in arg_specs.kwonlyargs:
if k not in arg_specs_args and k not in arg_specs_kwonlyargs:
if arg_specs.varkw is None: if arg_specs.varkw is None:
raise TypeError( raise TypeError(
"{} got an unexpected keyword argument {}".format( "{} got an unexpected keyword argument {}".format(
func.__qualname__, k
func_name, k
) )
) )
new_kwargs[k] = v new_kwargs[k] = v


Loading…
Cancel
Save