Browse Source

fix(mge/utils): fix module stats calculate flops bug for group conv and remove model status change

GitOrigin-RevId: 647dc6eb66
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
da167cbc05
1 changed files with 35 additions and 18 deletions
  1. +35
    -18
      imperative/python/megengine/utils/module_stats.py

+ 35
- 18
imperative/python/megengine/utils/module_stats.py View File

@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from functools import partial from functools import partial


import numpy as np import numpy as np
@@ -87,30 +88,20 @@ def disable_receptive_field():




@register_flops( @register_flops(
m.Conv1d, m.Conv2d, m.Conv3d,
m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d
) )
def flops_convNd(module: m.Conv2d, inputs, outputs): def flops_convNd(module: m.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0 bias = 1 if module.bias is not None else 0
group = module.groups
ic = inputs[0].shape[1]
oc = outputs[0].shape[1]
goc = oc // group
gic = ic // group
N = outputs[0].shape[0]
HW = np.prod(outputs[0].shape[2:])
# N x Cout x H x W x (Cin x Kw x Kh + bias) # N x Cout x H x W x (Cin x Kw x Kh + bias)
return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)


@register_flops(m.ConvTranspose2d)
def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
return np.prod(outputs[0].shape) * (
module.in_channels // module.groups * np.prod(module.kernel_size) + bias
)




@register_flops(m.Linear) @register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs): def flops_linear(module: m.Linear, inputs, outputs):
bias = 1 if module.bias is not None else 0
return np.prod(outputs[0].shape) * module.in_features
bias = module.out_features if module.bias is not None else 0
return np.prod(outputs[0].shape) * module.in_features + bias




@register_flops(m.BatchMatMulActivation) @register_flops(m.BatchMatMulActivation)
@@ -340,6 +331,31 @@ def module_stats(
param_stats["name"] = name + "-b" param_stats["name"] = name + "-b"
params.append(param_stats) params.append(param_stats)


@contextlib.contextmanager
def adjust_stats(module, training=False):
    """Temporarily put *module* (and every submodule) into the requested
    train/eval mode, restoring each submodule's previous mode on exit.

    Args:
        module (M.Module): module whose mode is adjusted.
        training (bool): mode to apply. True for train mode, False for eval mode.
    """
    # Remember each submodule's current mode, then force the requested one
    # (non-recursively, since we visit every submodule ourselves).
    for submodule in module.modules():
        submodule._prev_training = submodule.training
        submodule.train(training, recursive=False)
    yield module
    # Restore the saved mode and drop the temporary bookkeeping attribute.
    for submodule in module.modules():
        submodule.training = submodule._prev_training
        delattr(submodule, "_prev_training")

# multiple inputs to the network # multiple inputs to the network
if not isinstance(input_size[0], tuple): if not isinstance(input_size[0], tuple):
input_size = [input_size] input_size = [input_size]
@@ -355,8 +371,9 @@ def module_stats(
) )


inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
model.eval()
model(*inputs)
with adjust_stats(model, training=False) as model:
model(*inputs)

for h in hooks: for h in hooks:
h.remove() h.remove()




Loading…
Cancel
Save