Browse Source

Merge pull request #459 from Qsingle:fix_overflow_of_flops_calculate

GitOrigin-RevId: c1333e2089
HuaHua404-patch-1
Megvii Engine Team 2 years ago
parent
commit
6c4c4ca6e6
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      imperative/python/megengine/utils/module_stats.py
  2. +1
    -1
      imperative/python/megengine/utils/network_node.py

+ 1
- 1
imperative/python/megengine/utils/module_stats.py View File

@@ -90,7 +90,7 @@ def flops_convNd(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return np.prod(outputs[0].shape) * (
module.in_channels // module.groups * np.prod(module.kernel_size) + bias
float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias
)




+ 1
- 1
imperative/python/megengine/utils/network_node.py View File

@@ -487,7 +487,7 @@ def flops_conv(opnode: ConvolutionForward, inputs, outputs):
NCHW = np.prod(outputs[0].shape)
bias = 1 if isinstance(opnode, ConvBiasForward) else 0
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh + bias)
return NCHW * (float(num_input * kw * kh) + bias)


@register_receptive_field(ConvolutionForward, ConvBiasForward)


Loading…
Cancel
Save