|
|
@@ -83,9 +83,7 @@ def disable_receptive_field(): |
|
|
|
_receptive_field_enabled = False |
|
|
|
|
|
|
|
|
|
|
|
@register_flops( |
|
|
|
M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d |
|
|
|
) |
|
|
|
@register_flops(M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d) |
|
|
|
def flops_convNd(module: M.Conv2d, inputs, outputs): |
|
|
|
bias = 1 if module.bias is not None else 0 |
|
|
|
# N x Cout x H x W x (Cin x Kw x Kh + bias) |
|
|
@@ -93,13 +91,16 @@ def flops_convNd(module: M.Conv2d, inputs, outputs): |
|
|
|
float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@register_flops(M.ConvTranspose2d) |
|
|
|
def flops_convNdTranspose(module: M.Conv2d, inputs, outputs): |
|
|
|
bias = 1 if module.bias is not None else 0 |
|
|
|
# N x Cout x H x W x (Cin x Kw x Kh + bias) |
|
|
|
return np.prod(inputs[0].shape) * ( |
|
|
|
module.out_channels // module.groups * np.prod(module.kernel_size) |
|
|
|
) + np.prod(outputs[0].shape) * bias |
|
|
|
return ( |
|
|
|
np.prod(inputs[0].shape) |
|
|
|
* (module.out_channels // module.groups * np.prod(module.kernel_size)) |
|
|
|
+ np.prod(outputs[0].shape) * bias |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@register_flops( |
|
|
|