Browse Source

fix flops count bug for ConvTranspose2d

release-1.11.1
fanhqme2 GitHub 3 years ago
parent
commit
6ef1e12cfd
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 1 deletions
  1. +10
    -1
      imperative/python/megengine/utils/module_stats.py

+ 10
- 1
imperative/python/megengine/utils/module_stats.py View File

@@ -84,7 +84,7 @@ def disable_receptive_field():




@register_flops( @register_flops(
M.Conv1d, M.Conv2d, M.Conv3d, M.ConvTranspose2d, M.LocalConv2d, M.DeformableConv2d
M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d
) )
def flops_convNd(module: M.Conv2d, inputs, outputs): def flops_convNd(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0 bias = 1 if module.bias is not None else 0
@@ -93,6 +93,15 @@ def flops_convNd(module: M.Conv2d, inputs, outputs):
module.in_channels // module.groups * np.prod(module.kernel_size) + bias module.in_channels // module.groups * np.prod(module.kernel_size) + bias
) )


@register_flops(
M.ConvTranspose2d
)
def flops_convNdTranspose(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return np.prod(inputs[0].shape) * (
module.out_channels // module.groups * np.prod(module.kernel_size)
) + np.prod(outputs[0].shape) * bias


@register_flops( @register_flops(
M.batchnorm._BatchNorm, M.SyncBatchNorm, M.GroupNorm, M.LayerNorm, M.InstanceNorm, M.batchnorm._BatchNorm, M.SyncBatchNorm, M.GroupNorm, M.LayerNorm, M.InstanceNorm,


Loading…
Cancel
Save