GitOrigin-RevId: a15f17d616
HuaHua404-patch-4
@@ -457,6 +457,7 @@ def module_stats( | |||
log_activations = False | |||
disable_receptive_field() | |||
recorded_parameters = set() | |||
def module_stats_hook(module, inputs, outputs, name=""): | |||
class_name = str(module.__class__).split(".")[-1].split("'")[0] | |||
@@ -468,17 +469,27 @@ def module_stats( | |||
flops.append(flops_stats) | |||
if cal_params: | |||
if hasattr(module, "weight") and module.weight is not None: | |||
if ( | |||
hasattr(module, "weight") | |||
and (module.weight is not None) | |||
and module.weight not in recorded_parameters | |||
): | |||
w = module.weight | |||
param_stats = get_param_stats(w) | |||
param_stats["name"] = name + "-w" | |||
params.append(param_stats) | |||
recorded_parameters.add(w) | |||
if hasattr(module, "bias") and module.bias is not None: | |||
if ( | |||
hasattr(module, "bias") | |||
and module.bias is not None | |||
and module.bias not in recorded_parameters | |||
): | |||
b = module.bias | |||
param_stats = get_param_stats(b) | |||
param_stats["name"] = name + "-b" | |||
params.append(param_stats) | |||
recorded_parameters.add(b) | |||
if cal_activations: | |||
if not isinstance(outputs, (tuple, list)): | |||
@@ -504,7 +515,6 @@ def module_stats( | |||
hooks.append( | |||
module.register_forward_hook(partial(module_stats_hook, name=name)) | |||
) | |||
with set_module_mode_safe(model, training=False) as model: | |||
model(*inputs) | |||
@@ -42,6 +42,65 @@ def test_other_input_module_state(): | |||
net(_nt) | |||
@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_duplicated_module():
    """Parameters shared between modules must be counted only once.

    TestNet1 aliases one conv module twice; TestNet2 uses two distinct conv
    modules that share the same weight/bias tensors.  Either way the reported
    parameter dims/size must match the single-conv baseline (TestNet0).
    """
    shape = (1, 3, 224, 224)
    baseline, _ = module_stats(TestNet0(), input_shapes=shape)
    for variant_cls in (TestNet1, TestNet2):
        variant_stats, _ = module_stats(variant_cls(), input_shapes=shape)
        assert variant_stats.param_dims == baseline.param_dims
        assert variant_stats.param_size == baseline.param_size
class TestNet0(M.Module):
    """Baseline network: a single 3x3 conv whose bias is a fresh Parameter."""

    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 3, 3, padding=(1, 1))
        # Replace the default bias with a newly constructed random Parameter.
        self.conv.bias = mge.Parameter(
            np.random.random(self.conv.bias.shape).astype(np.float32)
        )

    def forward(self, x):
        return self.conv(x)
class TestNet1(TestNet0):
    """Applies the very same conv module object twice in one forward pass."""

    def __init__(self):
        super().__init__()
        # conv1 is an alias of conv — the identical module object, not a copy.
        self.conv1 = self.conv

    def forward(self, x):
        out = self.conv(x)
        out = self.conv1(out)
        return out
class TestNet2(TestNet0):
    """Two distinct conv modules whose weight/bias tensors are shared."""

    def __init__(self):
        super().__init__()
        second = M.Conv2d(3, 3, 3, padding=(1, 1))
        # Point the second conv at the first conv's parameter tensors, so the
        # two modules differ as objects but share every parameter.
        second.weight = self.conv.weight
        second.bias = self.conv.bias
        self.conv1 = second

    def forward(self, x):
        return self.conv1(self.conv(x))
class FakeNet(M.Module): | |||
def __init__(self): | |||
super().__init__() | |||