@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
 from functools import partial
 
 import numpy as np
@@ -87,30 +88,20 @@ def disable_receptive_field():
 
 
 @register_flops(
-    m.Conv1d, m.Conv2d, m.Conv3d,
+    m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d
 )
 def flops_convNd(module: m.Conv2d, inputs, outputs):
     bias = 1 if module.bias is not None else 0
-    group = module.groups
-    ic = inputs[0].shape[1]
-    oc = outputs[0].shape[1]
-    goc = oc // group
-    gic = ic // group
-    N = outputs[0].shape[0]
-    HW = np.prod(outputs[0].shape[2:])
     # N x Cout x H x W x (Cin x Kw x Kh + bias)
-    return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
-
-
-@register_flops(m.ConvTranspose2d)
-def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
-    return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
+    return np.prod(outputs[0].shape) * (
+        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
+    )
 
 
 @register_flops(m.Linear)
 def flops_linear(module: m.Linear, inputs, outputs):
-    bias = 1 if module.bias is not None else 0
-    return np.prod(outputs[0].shape) * module.in_features
+    bias = module.out_features if module.bias is not None else 0
+    return np.prod(outputs[0].shape) * module.in_features + bias
 
 
 @register_flops(m.BatchMatMulActivation)
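For a quick sanity check of the consolidated formula: each of the N x Cout x H x W output elements costs in_channels // groups * Kh * Kw multiply-accumulates, plus one when bias is present, so grouped convolutions are no longer undercounted by a factor of groups as they were under the old goc/gic expression. A minimal numpy-only sketch; FakeConv and the shapes are illustrative stand-ins, not part of this patch:

```python
import numpy as np

class FakeConv:
    # hypothetical stand-in mirroring the m.Conv2d attributes read above
    in_channels, groups, kernel_size = 16, 4, (3, 3)
    bias = object()  # any non-None value counts as +1 per output element

out_shape = (2, 32, 8, 8)  # N x Cout x H x W
bias = 1 if FakeConv.bias is not None else 0
flops = np.prod(out_shape) * (
    FakeConv.in_channels // FakeConv.groups * np.prod(FakeConv.kernel_size) + bias
)
# 2 * 32 * 8 * 8 * (16 // 4 * 9 + 1) = 4096 * 37 = 151552
assert flops == 151552
```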
@@ -340,6 +331,31 @@ def module_stats(
             param_stats["name"] = name + "-b"
             params.append(param_stats)
 
+    @contextlib.contextmanager
+    def adjust_stats(module, training=False):
+        """Adjust module to training/eval mode temporarily.
+
+        Args:
+            module (M.Module): module to adjust.
+            training (bool): training mode. True for train mode, False for eval mode.
+        """
+
+        def recursive_backup_stats(module, mode):
+            for m in module.modules():
+                # save prev status to _prev_training
+                m._prev_training = m.training
+                m.train(mode, recursive=False)
+
+        def recursive_recover_stats(module):
+            for m in module.modules():
+                # recover prev status and delete attribute
+                m.training = m._prev_training
+                delattr(m, "_prev_training")
+
+        recursive_backup_stats(module, mode=training)
+        yield module
+        recursive_recover_stats(module)
+
     # multiple inputs to the network
     if not isinstance(input_size[0], tuple):
         input_size = [input_size]
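The point of a context manager here, rather than the bare model.eval() that the next hunk removes, is that every submodule's previous training flag is saved and restored, so collecting stats on a half-trained model does not silently leave it in eval mode. A hedged usage sketch, assuming adjust_stats has been hoisted to importable scope (in this patch it is defined inside module_stats):

```python
import megengine.module as M

class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.fc = M.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = Net()
net.train()  # pretend the model is mid-training

with adjust_stats(net, training=False) as net:
    assert not net.fc.training  # submodules run in eval mode here
assert net.fc.training  # previous per-module modes are restored on exit
```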
@@ -355,8 +371,9 @@ def module_stats(
     )
 
     inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
-    model.eval()
-    model(*inputs)
+    with adjust_stats(model, training=False) as model:
+        model(*inputs)
 
     for h in hooks:
         h.remove()
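Put together, module_stats builds zero-filled inputs from input_size, runs one forward pass under adjust_stats(model, training=False), then removes its hooks, handing the model back in whatever mode it started in. A hedged end-to-end sketch; the import path and everything beyond the model and input_size arguments are assumptions, not shown in this patch:

```python
import megengine.module as M
from megengine.utils.module_stats import module_stats

net = M.Conv2d(3, 8, kernel_size=3)
net.train()
# input_size is (N, C, H, W); pass a list of tuples for multi-input models
module_stats(net, input_size=(1, 3, 32, 32))
assert net.training  # profiling no longer leaves the model in eval mode
```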