GitOrigin-RevId: 8e31a00c7e
release-1.7
@@ -9,6 +9,7 @@ | |||
import collections | |||
from collections import OrderedDict, defaultdict | |||
from functools import partial | |||
from inspect import FullArgSpec | |||
from typing import Callable, NamedTuple | |||
import numpy as np | |||
@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = { | |||
QuantMode, | |||
ArgsIndex, | |||
Group, | |||
FullArgSpec, | |||
} | |||
USER_REGISTERED_LEAF_TYPE = [] | |||
@@ -1928,8 +1928,11 @@ class TracedModule(Module): | |||
self.watch_node_value = {} | |||
self.end_points = [] | |||
self.is_qat = is_qat | |||
self.argspec = None | |||
def forward(self, *args, **kwargs): | |||
if hasattr(self, "argspec") and self.argspec is not None: | |||
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True) | |||
inputs, treedef = tree_flatten(((self, *args), kwargs)) | |||
assert treedef in self.argdef_graph_map | |||
inputs = filter( | |||
@@ -2422,8 +2425,12 @@ def trace_module( | |||
NodeMixin.wrap_safe( | |||
builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | |||
) | |||
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) | |||
forward_argspec = ( | |||
mod.argspec | |||
if hasattr(mod, "argspec") | |||
else inspect.getfullargspec(mod.forward) | |||
) | |||
args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True) | |||
inputs, _ = tree_flatten((args, kwargs)) | |||
for _, i in enumerate(inputs): | |||
# assert isinstance(i, Tensor), "not support " | |||
@@ -2439,6 +2446,7 @@ def trace_module( | |||
builder(*args, **kwargs) | |||
active_module_tracer().pop_scope() | |||
traced_mod = builder.build() | |||
traced_mod.argspec = forward_argspec | |||
traced_mod.graph._reset_ids() | |||
return traced_mod | |||
finally: | |||
@@ -9,7 +9,8 @@ import collections | |||
import copy | |||
import inspect | |||
from collections.abc import MutableMapping, MutableSequence | |||
from typing import Dict, Iterable, List, Optional, Sequence, Type | |||
from inspect import FullArgSpec | |||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | |||
from .. import get_logger | |||
from ..module import Module | |||
@@ -57,9 +58,14 @@ def replace_container_with_module_container(container): | |||
return has_module, module_container | |||
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): | |||
def _convert_kwargs_to_args( | |||
argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False | |||
): | |||
# is_bounded = True when func is a method and provided args don't include 'self' | |||
arg_specs = inspect.getfullargspec(func) | |||
arg_specs = ( | |||
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs | |||
) | |||
assert isinstance(arg_specs, FullArgSpec) | |||
arg_specs_args = arg_specs.args | |||
if is_bounded: | |||
arg_specs_args = arg_specs.args[1:] | |||
@@ -5,6 +5,7 @@ import numpy as np | |||
import megengine.functional as F | |||
import megengine.module as M | |||
from megengine import Tensor | |||
from megengine.module.module import Module | |||
from megengine.traced_module import TracedModule, trace_module | |||
from megengine.traced_module.expr import CallFunction | |||
@@ -89,5 +90,46 @@ def test_trace_module(): | |||
m4 = MyModule4() | |||
tm4 = trace_module(m4, a, b) | |||
np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||
np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||
tm4 = trace_module(m4, a, y=b) | |||
np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||
np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||
tm4 = trace_module(m4, x=a, y=b) | |||
np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||
np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||
tm5 = trace_module(tm4, a, b) | |||
np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||
np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||
tm5 = trace_module(tm4, a, y=b) | |||
np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||
np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||
tm5 = trace_module(tm4, x=a, y=b) | |||
np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||
np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||
assert len(tm4.graph._exprs) == 1 | |||
assert isinstance(tm4.graph._exprs[0], CallFunction) | |||
class MyModule5(Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.m1 = tm4 | |||
def forward(self, x, y): | |||
return self.m1(x, y) | |||
tm6 = trace_module(MyModule5(), a, b) | |||
assert tm6.m1.argspec is None | |||
assert tm6.m1._is_top is False |