diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index 11d4d483..4c03533b 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -84,7 +84,7 @@ def disable_receptive_field(): @register_flops( - M.Conv1d, M.Conv2d, M.Conv3d, M.ConvTranspose2d, M.LocalConv2d, M.DeformableConv2d + M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d ) def flops_convNd(module: M.Conv2d, inputs, outputs): bias = 1 if module.bias is not None else 0 @@ -93,6 +93,15 @@ def flops_convNd(module: M.Conv2d, inputs, outputs): module.in_channels // module.groups * np.prod(module.kernel_size) + bias ) +@register_flops( + M.ConvTranspose2d +) +def flops_convNdTranspose(module: M.Conv2d, inputs, outputs): + bias = 1 if module.bias is not None else 0 + # N x Cout x H x W x (Cin x Kw x Kh + bias) + return np.prod(inputs[0].shape) * ( + module.out_channels // module.groups * np.prod(module.kernel_size) + ) + np.prod(outputs[0].shape) * bias @register_flops( M.batchnorm._BatchNorm, M.SyncBatchNorm, M.GroupNorm, M.LayerNorm, M.InstanceNorm,