Browse Source

feat(mge/traced_module): add argspec for top TracedModule

GitOrigin-RevId: 8e31a00c7e
release-1.7
Megvii Engine Team 3 years ago
parent
commit
0185c1a9b9
4 changed files with 63 additions and 5 deletions
  1. +2
    -0
      imperative/python/megengine/traced_module/pytree.py
  2. +10
    -2
      imperative/python/megengine/traced_module/traced_module.py
  3. +9
    -3
      imperative/python/megengine/traced_module/utils.py
  4. +42
    -0
      imperative/python/test/unit/traced_module/test_trace_module.py

+ 2
- 0
imperative/python/megengine/traced_module/pytree.py View File

@@ -9,6 +9,7 @@
import collections
from collections import OrderedDict, defaultdict
from functools import partial
from inspect import FullArgSpec
from typing import Callable, NamedTuple

import numpy as np
@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = {
QuantMode,
ArgsIndex,
Group,
FullArgSpec,
}

USER_REGISTERED_LEAF_TYPE = []


+ 10
- 2
imperative/python/megengine/traced_module/traced_module.py View File

@@ -1928,8 +1928,11 @@ class TracedModule(Module):
self.watch_node_value = {}
self.end_points = []
self.is_qat = is_qat
self.argspec = None

def forward(self, *args, **kwargs):
if hasattr(self, "argspec") and self.argspec is not None:
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map
inputs = filter(
@@ -2422,8 +2425,12 @@ def trace_module(
NodeMixin.wrap_safe(
builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
)
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True)

forward_argspec = (
mod.argspec
if hasattr(mod, "argspec")
else inspect.getfullargspec(mod.forward)
)
args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True)
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support "
@@ -2439,6 +2446,7 @@ def trace_module(
builder(*args, **kwargs)
active_module_tracer().pop_scope()
traced_mod = builder.build()
traced_mod.argspec = forward_argspec
traced_mod.graph._reset_ids()
return traced_mod
finally:


+ 9
- 3
imperative/python/megengine/traced_module/utils.py View File

@@ -9,7 +9,8 @@ import collections
import copy
import inspect
from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence, Type
from inspect import FullArgSpec
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union

from .. import get_logger
from ..module import Module
@@ -57,9 +58,14 @@ def replace_container_with_module_container(container):
return has_module, module_container


def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False):
def _convert_kwargs_to_args(
argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False
):
# is_bounded = True when func is a method and provided args don't include 'self'
arg_specs = inspect.getfullargspec(func)
arg_specs = (
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
)
assert isinstance(arg_specs, FullArgSpec)
arg_specs_args = arg_specs.args
if is_bounded:
arg_specs_args = arg_specs.args[1:]


+ 42
- 0
imperative/python/test/unit/traced_module/test_trace_module.py View File

@@ -5,6 +5,7 @@ import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.module.module import Module
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction

@@ -89,5 +90,46 @@ def test_trace_module():

m4 = MyModule4()
tm4 = trace_module(m4, a, b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)

tm4 = trace_module(m4, a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)

tm4 = trace_module(m4, x=a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)

tm5 = trace_module(tm4, a, b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)

tm5 = trace_module(tm4, a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)

tm5 = trace_module(tm4, x=a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)

assert len(tm4.graph._exprs) == 1
assert isinstance(tm4.graph._exprs[0], CallFunction)

class MyModule5(Module):
def __init__(self):
super().__init__()
self.m1 = tm4

def forward(self, x, y):
return self.m1(x, y)

tm6 = trace_module(MyModule5(), a, b)
assert tm6.m1.argspec is None
assert tm6.m1._is_top is False

Loading…
Cancel
Save