|
|
@@ -13,24 +13,20 @@ from typing import Iterable
 import numpy as np
 import tabulate

-import megengine as mge
-import megengine.module as m
-import megengine.module.qat as qatm
-import megengine.module.quantized as qm
-from megengine import Tensor
-from megengine import functional as F
-from megengine.core.tensor.dtype import get_dtype_bit
-from megengine.functional.tensor import zeros
-from megengine.tensor import Tensor
+from .. import Tensor
+from .. import functional as F
+from .. import get_logger
+from .. import module as M
+from ..core.tensor.dtype import get_dtype_bit
+from ..logger import MegEngineLogFormatter
+from .module_utils import set_module_mode_safe

 try:
-    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
+    MegEngineLogFormatter.max_lines = float("inf")
 except AttributeError as e:
     raise ValueError("set logger max lines failed")

-logger = mge.get_logger(__name__)
+logger = get_logger(__name__)
 logger.setLevel("INFO")
@@ -95,9 +91,9 @@ def disable_receptive_field():


 @register_flops(
-    m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d
+    M.Conv1d, M.Conv2d, M.Conv3d, M.ConvTranspose2d, M.LocalConv2d, M.DeformableConv2d
 )
-def flops_convNd(module: m.Conv2d, inputs, outputs):
+def flops_convNd(module: M.Conv2d, inputs, outputs):
     bias = 1 if module.bias is not None else 0
     # N x Cout x H x W x (Cin x Kw x Kh + bias)
     return np.prod(outputs[0].shape) * (
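
Reviewer note (not part of the patch): a quick sanity check of the formula in the comment, N x Cout x H x W x (Cin x Kw x Kh + bias); the full return expression is cut off by the hunk context, and groups/dilation are ignored here. All sizes below are made up:

    import numpy as np

    # hypothetical conv: Cin=3, Cout=16, 3x3 kernel with bias, output shape (1, 16, 32, 32)
    out_shape = (1, 16, 32, 32)
    cin, kh, kw, bias = 3, 3, 3, 1
    print(int(np.prod(out_shape)) * (cin * kh * kw + bias))  # 16384 * 28 = 458752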
|
|
@@ -106,14 +102,14 @@ def flops_convNd(module: m.Conv2d, inputs, outputs):


 @register_flops(
-    m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm,
+    M.batchnorm._BatchNorm, M.SyncBatchNorm, M.GroupNorm, M.LayerNorm, M.InstanceNorm,
 )
-def flops_norm(module: m.Linear, inputs, outputs):
+def flops_norm(module: M.Linear, inputs, outputs):
     return np.prod(inputs[0].shape) * 7


-@register_flops(m.AvgPool2d, m.MaxPool2d)
-def flops_pool(module: m.AvgPool2d, inputs, outputs):
+@register_flops(M.AvgPool2d, M.MaxPool2d)
+def flops_pool(module: M.AvgPool2d, inputs, outputs):
     kernel_sum = 0
     if isinstance(module.kernel_size, tuple) and len(module.kernel_size) == 2:
         kernel_sum = np.prod(module.kernel_size)
@@ -122,8 +118,8 @@ def flops_pool(module: m.AvgPool2d, inputs, outputs):
     return np.prod(outputs[0].shape) * kernel_sum


-@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
-def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
+@register_flops(M.AdaptiveAvgPool2d, M.AdaptiveMaxPool2d)
+def flops_adaptivePool(module: M.AdaptiveAvgPool2d, inputs, outputs):
     stride_h = np.floor(inputs[0].shape[2] / (inputs[0].shape[2] - 1))
     kernel_h = inputs[0].shape[2] - (inputs[0].shape[2] - 1) * stride_h
     stride_w = np.floor(inputs[0].shape[3] / (inputs[0].shape[3] - 1))
@@ -131,14 +127,14 @@ def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
     return np.prod(outputs[0].shape) * kernel_h * kernel_w


-@register_flops(m.Linear)
-def flops_linear(module: m.Linear, inputs, outputs):
+@register_flops(M.Linear)
+def flops_linear(module: M.Linear, inputs, outputs):
     bias = module.out_features if module.bias is not None else 0
     return np.prod(outputs[0].shape) * module.in_features + bias


-@register_flops(m.BatchMatMulActivation)
-def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
+@register_flops(M.BatchMatMulActivation)
+def flops_batchmatmul(module: M.BatchMatMulActivation, inputs, outputs):
     bias = 1 if module.bias is not None else 0
     x = inputs[0]
     w = module.weight
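
Reviewer note (not part of the patch): a quick worked example of flops_linear with made-up sizes: for in_features=128, out_features=64 and a single input row, np.prod(outputs[0].shape) is 64, so the count is 64 * 128 + 64 = 8256. Note that the bias contribution stays at module.out_features even when the batch dimension grows, unlike the matmul term.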
|
|
@@ -150,25 +146,25 @@ def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):

 # does not need import qat and quantized module since they inherit from float module.
 hook_modules = (
-    m.conv._ConvNd,
-    m.Linear,
-    m.BatchMatMulActivation,
-    m.batchnorm._BatchNorm,
-    m.LayerNorm,
-    m.GroupNorm,
-    m.InstanceNorm,
-    m.pooling._PoolNd,
-    m.adaptive_pooling._AdaptivePoolNd,
+    M.conv._ConvNd,
+    M.Linear,
+    M.BatchMatMulActivation,
+    M.batchnorm._BatchNorm,
+    M.LayerNorm,
+    M.GroupNorm,
+    M.InstanceNorm,
+    M.pooling._PoolNd,
+    M.adaptive_pooling._AdaptivePoolNd,
 )


 def _mean(inp):
-    inp = mge.tensor(inp).astype(np.float32)
+    inp = Tensor(inp).astype(np.float32)
     return F.mean(inp).numpy()


 def _std(inp):
-    inp = mge.tensor(inp).astype(np.float32)
+    inp = Tensor(inp).astype(np.float32)
     return F.std(inp).numpy()
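
Reviewer note (not part of the patch): the hunk above shows the register_flops decorator pattern and the hook_modules tuple that controls which module types actually get hooked. A minimal sketch of registering a count for another module type, assuming register_flops is importable from megengine.utils.module_stats; the M.Softmax target and its per-element cost are made up, and a matching entry in hook_modules would presumably also be needed:

    import numpy as np

    from megengine import module as M
    from megengine.utils.module_stats import register_flops  # assumed import path

    @register_flops(M.Softmax)
    def flops_softmax(module: M.Softmax, inputs, outputs):
        # loose estimate: one exp, one sum contribution and one division per element
        return np.prod(inputs[0].shape) * 3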
@@ -412,7 +408,7 @@ def print_summary(**kwargs):


 def module_stats(
-    model: m.Module,
+    model: M.Module,
     inputs: Iterable[np.ndarray] = None,
     input_shapes: list = None,
     cal_params: bool = True,
@@ -457,7 +453,7 @@ def module_stats(
     if input_shapes:
         if not isinstance(input_shapes[0], tuple):
             input_shapes = [input_shapes]
-        inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
+        inputs = [F.zeros(in_size, dtype=np.float32) for in_size in input_shapes]
     else:
         logger.error(
             "Inputs or input_shapes is required for running model and calculating stats.",
|
|
|
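Reviewer note (not part of the patch): a minimal usage sketch of the input_shapes branch touched above. The keyword names come from the module_stats signature shown earlier in the diff; the import path and the toy model are assumptions, and the structure of the return value depends on the MegEngine version:

    import megengine.module as M
    from megengine.utils.module_stats import module_stats  # assumed import path

    # toy model; any M.Module should work
    model = M.Sequential(
        M.Conv2d(3, 16, kernel_size=3, padding=1),
        M.ReLU(),
        M.AdaptiveAvgPool2d((1, 1)),
    )

    # no concrete inputs are given, so zero-filled tensors are built from input_shapes
    stats = module_stats(model, input_shapes=[(1, 3, 224, 224)], cal_params=True)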