From f75261d76f74277ca2487c24189ff0e18ab88a91 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 20 Dec 2021 19:03:55 +0800 Subject: [PATCH] fix(mge/utils): fix module_status error GitOrigin-RevId: 9e004d98a17e408fd63b162f4f8fe868aaad57bd --- imperative/python/megengine/utils/module_stats.py | 2 +- .../python/test/unit/utils/test_module_stats.py | 26 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index b3a2589d..b298153f 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -443,7 +443,7 @@ def module_stats( if isinstance(x, np.ndarray): return Tensor(x) elif isinstance(x, collections.abc.Mapping): - return {k: load_tensor(x) for k, v in x.items()} + return {k: load_tensor(v) for k, v in x.items()} elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple return type(x)(*(load_tensor(value) for value in x)) elif isinstance(x, collections.abc.Sequence): diff --git a/imperative/python/test/unit/utils/test_module_stats.py b/imperative/python/test/unit/utils/test_module_stats.py index d4848000..f7d748a5 100644 --- a/imperative/python/test/unit/utils/test_module_stats.py +++ b/imperative/python/test/unit/utils/test_module_stats.py @@ -1,3 +1,4 @@ +import collections import math from copy import deepcopy @@ -27,6 +28,31 @@ def test_module_stats(): assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,) +@pytest.mark.skipif( + use_symbolic_shape(), reason="This test do not support symbolic shape.", +) +def test_other_input_module_state(): + a = [1, 2] + b = {"1": 1, "2": 2} + nt = collections.namedtuple("nt", ["n", "t"]) + _nt = nt(n=1, t=2) + net = FakeNet() + net(a) + net(b) + net(_nt) + + +class FakeNet(M.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + assert isinstance( + x, + (np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor), + ) or (isinstance(x, tuple) and hasattr(x, "_fields")) + + class BasicBlock(M.Module): expansion = 1