Browse Source

feat(mge/traced_module): add argspec for top TracedModule

GitOrigin-RevId: 8e31a00c7e
release-1.7
Megvii Engine Team 3 years ago
parent
commit
0185c1a9b9
4 changed files with 63 additions and 5 deletions
  1. +2
    -0
      imperative/python/megengine/traced_module/pytree.py
  2. +10
    -2
      imperative/python/megengine/traced_module/traced_module.py
  3. +9
    -3
      imperative/python/megengine/traced_module/utils.py
  4. +42
    -0
      imperative/python/test/unit/traced_module/test_trace_module.py

+ 2
- 0
imperative/python/megengine/traced_module/pytree.py View File

@@ -9,6 +9,7 @@
import collections import collections
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from functools import partial from functools import partial
from inspect import FullArgSpec
from typing import Callable, NamedTuple from typing import Callable, NamedTuple


import numpy as np import numpy as np
@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = {
QuantMode, QuantMode,
ArgsIndex, ArgsIndex,
Group, Group,
FullArgSpec,
} }


USER_REGISTERED_LEAF_TYPE = [] USER_REGISTERED_LEAF_TYPE = []


+ 10
- 2
imperative/python/megengine/traced_module/traced_module.py View File

@@ -1928,8 +1928,11 @@ class TracedModule(Module):
self.watch_node_value = {} self.watch_node_value = {}
self.end_points = [] self.end_points = []
self.is_qat = is_qat self.is_qat = is_qat
self.argspec = None


def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
if hasattr(self, "argspec") and self.argspec is not None:
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
inputs, treedef = tree_flatten(((self, *args), kwargs)) inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map assert treedef in self.argdef_graph_map
inputs = filter( inputs = filter(
@@ -2422,8 +2425,12 @@ def trace_module(
NodeMixin.wrap_safe( NodeMixin.wrap_safe(
builder, Input.make(name="top", type=ModuleNode, qualname=net_name) builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
) )
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True)

forward_argspec = (
mod.argspec
if hasattr(mod, "argspec")
else inspect.getfullargspec(mod.forward)
)
args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True)
inputs, _ = tree_flatten((args, kwargs)) inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support " # assert isinstance(i, Tensor), "not support "
@@ -2439,6 +2446,7 @@ def trace_module(
builder(*args, **kwargs) builder(*args, **kwargs)
active_module_tracer().pop_scope() active_module_tracer().pop_scope()
traced_mod = builder.build() traced_mod = builder.build()
traced_mod.argspec = forward_argspec
traced_mod.graph._reset_ids() traced_mod.graph._reset_ids()
return traced_mod return traced_mod
finally: finally:


+ 9
- 3
imperative/python/megengine/traced_module/utils.py View File

@@ -9,7 +9,8 @@ import collections
import copy import copy
import inspect import inspect
from collections.abc import MutableMapping, MutableSequence from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence, Type
from inspect import FullArgSpec
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union


from .. import get_logger from .. import get_logger
from ..module import Module from ..module import Module
@@ -57,9 +58,14 @@ def replace_container_with_module_container(container):
return has_module, module_container return has_module, module_container




def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False):
def _convert_kwargs_to_args(
argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False
):
# is_bounded = True when func is a method and provided args don't include 'self' # is_bounded = True when func is a method and provided args don't include 'self'
arg_specs = inspect.getfullargspec(func)
arg_specs = (
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
)
assert isinstance(arg_specs, FullArgSpec)
arg_specs_args = arg_specs.args arg_specs_args = arg_specs.args
if is_bounded: if is_bounded:
arg_specs_args = arg_specs.args[1:] arg_specs_args = arg_specs.args[1:]


+ 42
- 0
imperative/python/test/unit/traced_module/test_trace_module.py View File

@@ -5,6 +5,7 @@ import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import Tensor from megengine import Tensor
from megengine.module.module import Module
from megengine.traced_module import TracedModule, trace_module from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction from megengine.traced_module.expr import CallFunction


@@ -89,5 +90,46 @@ def test_trace_module():


m4 = MyModule4() m4 = MyModule4()
tm4 = trace_module(m4, a, b) tm4 = trace_module(m4, a, b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)

tm4 = trace_module(m4, a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)

tm4 = trace_module(m4, x=a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)

tm5 = trace_module(tm4, a, b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)

tm5 = trace_module(tm4, a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)

tm5 = trace_module(tm4, x=a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)

assert len(tm4.graph._exprs) == 1 assert len(tm4.graph._exprs) == 1
assert isinstance(tm4.graph._exprs[0], CallFunction) assert isinstance(tm4.graph._exprs[0], CallFunction)

class MyModule5(Module):
def __init__(self):
super().__init__()
self.m1 = tm4

def forward(self, x, y):
return self.m1(x, y)

tm6 = trace_module(MyModule5(), a, b)
assert tm6.m1.argspec is None
assert tm6.m1._is_top is False

Loading…
Cancel
Save