Browse Source

fix(mge): update format

GitOrigin-RevId: 618164e2cd
release-1.11.1
Megvii Engine Team 2 years ago
parent
commit
d27a4456f9
1 changed files with 7 additions and 6 deletions
  1. +7
    -6
      imperative/python/megengine/utils/module_stats.py

+ 7
- 6
imperative/python/megengine/utils/module_stats.py View File

@@ -83,9 +83,7 @@ def disable_receptive_field():
_receptive_field_enabled = False _receptive_field_enabled = False




@register_flops(
M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d
)
@register_flops(M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d)
def flops_convNd(module: M.Conv2d, inputs, outputs): def flops_convNd(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0 bias = 1 if module.bias is not None else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias) # N x Cout x H x W x (Cin x Kw x Kh + bias)
@@ -93,13 +91,16 @@ def flops_convNd(module: M.Conv2d, inputs, outputs):
float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias
) )



@register_flops(M.ConvTranspose2d) @register_flops(M.ConvTranspose2d)
def flops_convNdTranspose(module: M.Conv2d, inputs, outputs): def flops_convNdTranspose(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0 bias = 1 if module.bias is not None else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias) # N x Cout x H x W x (Cin x Kw x Kh + bias)
return np.prod(inputs[0].shape) * (
module.out_channels // module.groups * np.prod(module.kernel_size)
) + np.prod(outputs[0].shape) * bias
return (
np.prod(inputs[0].shape)
* (module.out_channels // module.groups * np.prod(module.kernel_size))
+ np.prod(outputs[0].shape) * bias
)




@register_flops( @register_flops(


Loading…
Cancel
Save