GitOrigin-RevId: 8e31a00c7e
release-1.7
@@ -9,6 +9,7 @@ | |||||
import collections | import collections | ||||
from collections import OrderedDict, defaultdict | from collections import OrderedDict, defaultdict | ||||
from functools import partial | from functools import partial | ||||
from inspect import FullArgSpec | |||||
from typing import Callable, NamedTuple | from typing import Callable, NamedTuple | ||||
import numpy as np | import numpy as np | ||||
@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = { | |||||
QuantMode, | QuantMode, | ||||
ArgsIndex, | ArgsIndex, | ||||
Group, | Group, | ||||
FullArgSpec, | |||||
} | } | ||||
USER_REGISTERED_LEAF_TYPE = [] | USER_REGISTERED_LEAF_TYPE = [] | ||||
@@ -1928,8 +1928,11 @@ class TracedModule(Module): | |||||
self.watch_node_value = {} | self.watch_node_value = {} | ||||
self.end_points = [] | self.end_points = [] | ||||
self.is_qat = is_qat | self.is_qat = is_qat | ||||
self.argspec = None | |||||
def forward(self, *args, **kwargs): | def forward(self, *args, **kwargs): | ||||
if hasattr(self, "argspec") and self.argspec is not None: | |||||
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True) | |||||
inputs, treedef = tree_flatten(((self, *args), kwargs)) | inputs, treedef = tree_flatten(((self, *args), kwargs)) | ||||
assert treedef in self.argdef_graph_map | assert treedef in self.argdef_graph_map | ||||
inputs = filter( | inputs = filter( | ||||
@@ -2422,8 +2425,12 @@ def trace_module( | |||||
NodeMixin.wrap_safe( | NodeMixin.wrap_safe( | ||||
builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | ||||
) | ) | ||||
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) | |||||
forward_argspec = ( | |||||
mod.argspec | |||||
if hasattr(mod, "argspec") | |||||
else inspect.getfullargspec(mod.forward) | |||||
) | |||||
args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True) | |||||
inputs, _ = tree_flatten((args, kwargs)) | inputs, _ = tree_flatten((args, kwargs)) | ||||
for _, i in enumerate(inputs): | for _, i in enumerate(inputs): | ||||
# assert isinstance(i, Tensor), "not support " | # assert isinstance(i, Tensor), "not support " | ||||
@@ -2439,6 +2446,7 @@ def trace_module( | |||||
builder(*args, **kwargs) | builder(*args, **kwargs) | ||||
active_module_tracer().pop_scope() | active_module_tracer().pop_scope() | ||||
traced_mod = builder.build() | traced_mod = builder.build() | ||||
traced_mod.argspec = forward_argspec | |||||
traced_mod.graph._reset_ids() | traced_mod.graph._reset_ids() | ||||
return traced_mod | return traced_mod | ||||
finally: | finally: | ||||
@@ -9,7 +9,8 @@ import collections | |||||
import copy | import copy | ||||
import inspect | import inspect | ||||
from collections.abc import MutableMapping, MutableSequence | from collections.abc import MutableMapping, MutableSequence | ||||
from typing import Dict, Iterable, List, Optional, Sequence, Type | |||||
from inspect import FullArgSpec | |||||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | |||||
from .. import get_logger | from .. import get_logger | ||||
from ..module import Module | from ..module import Module | ||||
@@ -57,9 +58,14 @@ def replace_container_with_module_container(container): | |||||
return has_module, module_container | return has_module, module_container | ||||
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): | |||||
def _convert_kwargs_to_args( | |||||
argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False | |||||
): | |||||
# is_bounded = True when func is a method and provided args don't include 'self' | # is_bounded = True when func is a method and provided args don't include 'self' | ||||
arg_specs = inspect.getfullargspec(func) | |||||
arg_specs = ( | |||||
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs | |||||
) | |||||
assert isinstance(arg_specs, FullArgSpec) | |||||
arg_specs_args = arg_specs.args | arg_specs_args = arg_specs.args | ||||
if is_bounded: | if is_bounded: | ||||
arg_specs_args = arg_specs.args[1:] | arg_specs_args = arg_specs.args[1:] | ||||
@@ -5,6 +5,7 @@ import numpy as np | |||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.module as M | import megengine.module as M | ||||
from megengine import Tensor | from megengine import Tensor | ||||
from megengine.module.module import Module | |||||
from megengine.traced_module import TracedModule, trace_module | from megengine.traced_module import TracedModule, trace_module | ||||
from megengine.traced_module.expr import CallFunction | from megengine.traced_module.expr import CallFunction | ||||
@@ -89,5 +90,46 @@ def test_trace_module(): | |||||
m4 = MyModule4() | m4 = MyModule4() | ||||
tm4 = trace_module(m4, a, b) | tm4 = trace_module(m4, a, b) | ||||
np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||||
np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||||
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||||
tm4 = trace_module(m4, a, y=b) | |||||
np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||||
np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||||
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||||
tm4 = trace_module(m4, x=a, y=b) | |||||
np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||||
np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||||
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||||
tm5 = trace_module(tm4, a, b) | |||||
np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||||
np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||||
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||||
tm5 = trace_module(tm4, a, y=b) | |||||
np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||||
np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||||
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||||
tm5 = trace_module(tm4, x=a, y=b) | |||||
np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||||
np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||||
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||||
assert len(tm4.graph._exprs) == 1 | assert len(tm4.graph._exprs) == 1 | ||||
assert isinstance(tm4.graph._exprs[0], CallFunction) | assert isinstance(tm4.graph._exprs[0], CallFunction) | ||||
class MyModule5(Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.m1 = tm4 | |||||
def forward(self, x, y): | |||||
return self.m1(x, y) | |||||
tm6 = trace_module(MyModule5(), a, b) | |||||
assert tm6.m1.argspec is None | |||||
assert tm6.m1._is_top is False |