Browse Source

fix(mge/tools): fix module_stats for duplicated module

GitOrigin-RevId: a15f17d616
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent
commit
19af26887e
2 changed files with 72 additions and 3 deletions
  1. +13
    -3
      imperative/python/megengine/utils/module_stats.py
  2. +59
    -0
      imperative/python/test/unit/utils/test_module_stats.py

+ 13
- 3
imperative/python/megengine/utils/module_stats.py View File

@@ -457,6 +457,7 @@ def module_stats(
log_activations = False log_activations = False


disable_receptive_field() disable_receptive_field()
recorded_parameters = set()


def module_stats_hook(module, inputs, outputs, name=""): def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0] class_name = str(module.__class__).split(".")[-1].split("'")[0]
@@ -468,17 +469,27 @@ def module_stats(
flops.append(flops_stats) flops.append(flops_stats)


if cal_params: if cal_params:
if hasattr(module, "weight") and module.weight is not None:
if (
hasattr(module, "weight")
and (module.weight is not None)
and module.weight not in recorded_parameters
):
w = module.weight w = module.weight
param_stats = get_param_stats(w) param_stats = get_param_stats(w)
param_stats["name"] = name + "-w" param_stats["name"] = name + "-w"
params.append(param_stats) params.append(param_stats)
recorded_parameters.add(w)


if hasattr(module, "bias") and module.bias is not None:
if (
hasattr(module, "bias")
and module.bias is not None
and module.bias not in recorded_parameters
):
b = module.bias b = module.bias
param_stats = get_param_stats(b) param_stats = get_param_stats(b)
param_stats["name"] = name + "-b" param_stats["name"] = name + "-b"
params.append(param_stats) params.append(param_stats)
recorded_parameters.add(b)


if cal_activations: if cal_activations:
if not isinstance(outputs, (tuple, list)): if not isinstance(outputs, (tuple, list)):
@@ -504,7 +515,6 @@ def module_stats(
hooks.append( hooks.append(
module.register_forward_hook(partial(module_stats_hook, name=name)) module.register_forward_hook(partial(module_stats_hook, name=name))
) )

with set_module_mode_safe(model, training=False) as model: with set_module_mode_safe(model, training=False) as model:
model(*inputs) model(*inputs)




+ 59
- 0
imperative/python/test/unit/utils/test_module_stats.py View File

@@ -42,6 +42,65 @@ def test_other_input_module_state():
net(_nt) net(_nt)




@pytest.mark.skipif(
use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_duplicated_module():
input_shape = (1, 3, 224, 224)

net0 = TestNet0()
net0_stats, _ = module_stats(net0, input_shapes=input_shape)

net1 = TestNet1()
net1_stats, _ = module_stats(net1, input_shapes=input_shape)

net2 = TestNet2()
net2_stats, _ = module_stats(net2, input_shapes=input_shape)

assert net0_stats.param_dims == net1_stats.param_dims
assert net0_stats.param_size == net1_stats.param_size

assert net0_stats.param_dims == net2_stats.param_dims
assert net0_stats.param_size == net2_stats.param_size


class TestNet0(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(3, 3, 3, padding=(1, 1))
self.conv.bias = mge.Parameter(
np.random.random(self.conv.bias.shape).astype(np.float32)
)

def forward(self, x):
x = self.conv(x)
return x


class TestNet1(TestNet0):
def __init__(self):
super().__init__()
self.conv1 = self.conv

def forward(self, x):
x = self.conv(x)
x = self.conv1(x)
return x


class TestNet2(TestNet0):
def __init__(self):
super().__init__()
self.conv1 = M.Conv2d(3, 3, 3, padding=(1, 1))
self.conv1.weight = self.conv.weight
self.conv1.bias = self.conv.bias

def forward(self, x):
x = self.conv(x)
x = self.conv1(x)
return x


class FakeNet(M.Module): class FakeNet(M.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()


Loading…
Cancel
Save