Browse Source

fix(mge/utils): fix module_status error

GitOrigin-RevId: 9e004d98a1
tags/v1.7.2.m1
Megvii Engine Team 3 years ago
parent
commit
f75261d76f
2 changed files with 27 additions and 1 deletions
  1. +1
    -1
      imperative/python/megengine/utils/module_stats.py
  2. +26
    -0
      imperative/python/test/unit/utils/test_module_stats.py

+ 1
- 1
imperative/python/megengine/utils/module_stats.py View File

@@ -443,7 +443,7 @@ def module_stats(
if isinstance(x, np.ndarray):
return Tensor(x)
elif isinstance(x, collections.abc.Mapping):
return {k: load_tensor(x) for k, v in x.items()}
return {k: load_tensor(v) for k, v in x.items()}
elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple
return type(x)(*(load_tensor(value) for value in x))
elif isinstance(x, collections.abc.Sequence):


+ 26
- 0
imperative/python/test/unit/utils/test_module_stats.py View File

@@ -1,3 +1,4 @@
import collections
import math
from copy import deepcopy

@@ -27,6 +28,31 @@ def test_module_stats():
assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)


@pytest.mark.skipif(
use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_other_input_module_state():
a = [1, 2]
b = {"1": 1, "2": 2}
nt = collections.namedtuple("nt", ["n", "t"])
_nt = nt(n=1, t=2)
net = FakeNet()
net(a)
net(b)
net(_nt)


class FakeNet(M.Module):
def __init__(self):
super().__init__()

def forward(self, x):
assert isinstance(
x,
(np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor),
) or (isinstance(x, tuple) and hasattr(x, "_fields"))


class BasicBlock(M.Module):
expansion = 1



Loading…
Cancel
Save