BREAKING CHANGE:
GitOrigin-RevId: ced3da3a12
release-1.5
@@ -9,6 +9,7 @@ | |||
import argparse | |||
import logging | |||
import re | |||
from collections import namedtuple | |||
import numpy as np | |||
@@ -16,12 +17,17 @@ from megengine.core.tensor.dtype import is_quantize | |||
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | |||
from megengine.utils.module_stats import ( | |||
enable_receptive_field, | |||
get_activation_stats, | |||
get_op_stats, | |||
get_param_stats, | |||
print_activations_stats, | |||
print_op_stats, | |||
print_param_stats, | |||
print_summary, | |||
sizeof_fmt, | |||
sum_activations_stats, | |||
sum_op_stats, | |||
sum_param_stats, | |||
) | |||
from megengine.utils.network import Network | |||
@@ -34,6 +40,7 @@ def visualize( | |||
bar_length_max: int = 20, | |||
log_params: bool = True, | |||
log_flops: bool = True, | |||
log_activations: bool = True, | |||
): | |||
r""" | |||
Load a MegEngine dumped model and visualize its graph structure with TensorBoard log files. | |||
@@ -44,6 +51,7 @@ def visualize( | |||
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||
:param log_params: whether to print and record parameter sizes. | |||
:param log_flops: whether to print and record op flops. | |||
:param log_activations: whether to print and record op activations. | |||
""" | |||
if log_path: | |||
try: | |||
@@ -83,6 +91,10 @@ def visualize( | |||
node_list = [] | |||
flops_list = [] | |||
params_list = [] | |||
activations_list = [] | |||
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | |||
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | |||
for node in graph.all_oprs: | |||
if hasattr(node, "output_idx"): | |||
node_oup = node.outputs[node.output_idx] | |||
@@ -124,6 +136,11 @@ def visualize( | |||
flops_stats["class_name"] = node.type | |||
flops_list.append(flops_stats) | |||
acts = get_activation_stats(node_oup.numpy()) | |||
acts["name"] = node.name | |||
acts["class_name"] = node.type | |||
activations_list.append(acts) | |||
if node.type == "ImmutableTensor": | |||
param_stats = get_param_stats(node.numpy()) | |||
# add tensor size attr | |||
@@ -149,20 +166,36 @@ def visualize( | |||
"#params": len(params_list), | |||
} | |||
total_flops, total_param_dims, total_param_size = 0, 0, 0 | |||
( | |||
total_flops, | |||
total_param_dims, | |||
total_param_size, | |||
total_act_dims, | |||
total_act_size, | |||
) = (0, 0, 0, 0, 0) | |||
total_param_dims, total_param_size, params = sum_param_stats( | |||
params_list, bar_length_max | |||
) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
if log_params: | |||
total_param_dims, total_param_size = print_param_stats( | |||
params_list, bar_length_max | |||
) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
print_param_stats(params) | |||
total_flops, flops = sum_op_stats(flops_list, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if log_flops: | |||
total_flops = print_op_stats(flops_list, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if log_params and log_flops: | |||
extra_info["flops/param_size"] = "{:3.3f}".format( | |||
total_flops / total_param_size | |||
) | |||
print_op_stats(flops) | |||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||
activations_list, bar_length_max | |||
) | |||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
if log_activations: | |||
print_activations_stats(activations) | |||
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||
if log_path: | |||
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | |||
@@ -179,7 +212,12 @@ def visualize( | |||
# FIXME: remove this after resolving "span dist too large" warning | |||
_imperative_rt_logger.set_log_level(old_level) | |||
return total_param_size, total_flops | |||
return ( | |||
total_stats( | |||
param_size=total_param_size, flops=total_flops, act_size=total_act_size, | |||
), | |||
stats_details(params=params, flops=flops, activations=activations), | |||
) | |||
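For reference, a minimal sketch of how the new two-namedtuple return value of visualize might be consumed; the model path, log directory, and the assumption that the tool is importable as megengine.tools.network_visualize are illustrative, not taken from this diff:

# hypothetical usage; "resnet18.mge" and "./logs" are placeholder paths
from megengine.tools.network_visualize import visualize

total_stats, stats_details = visualize(
    "resnet18.mge",        # dumped model path (assumed positional argument)
    log_path="./logs",     # tensorboard log dir (assumed keyword argument)
    log_activations=True,
)
print(total_stats.param_size, total_stats.flops, total_stats.act_size)
print(len(stats_details.params), len(stats_details.flops), len(stats_details.activations))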
def main(): | |||
@@ -5,7 +5,7 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import contextlib | |||
from collections import namedtuple | |||
from functools import partial | |||
import numpy as np | |||
@@ -18,6 +18,8 @@ import megengine.module.quantized as qm | |||
from megengine.core.tensor.dtype import get_dtype_bit | |||
from megengine.functional.tensor import zeros | |||
from .module_utils import set_module_mode_safe | |||
try: | |||
mge.logger.MegEngineLogFormatter.max_lines = float("inf") | |||
except AttributeError as e: | |||
@@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs): | |||
) | |||
@register_flops( | |||
m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm, | |||
) | |||
def flops_norm(module: m.Linear, inputs, outputs): | |||
return np.prod(inputs[0].shape) * 7 | |||
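For concreteness, the flat per-element cost model used by flops_norm applied to a typical feature map (the shape is chosen purely for illustration):

# flops_norm charges a flat 7 ops per input element, regardless of norm type
import numpy as np
flops = np.prod((1, 64, 56, 56)) * 7   # 200704 * 7 = 1404928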
@register_flops(m.AvgPool2d, m.MaxPool2d) | |||
def flops_pool(module: m.AvgPool2d, inputs, outputs): | |||
return np.prod(outputs[0].shape) * (module.kernel_size ** 2) | |||
@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d) | |||
def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs): | |||
stride_h = np.floor(inputs[0].shape[2] / outputs[0].shape[2]) | |||
kernel_h = inputs[0].shape[2] - (outputs[0].shape[2] - 1) * stride_h | |||
stride_w = np.floor(inputs[0].shape[3] / outputs[0].shape[3]) | |||
kernel_w = inputs[0].shape[3] - (outputs[0].shape[3] - 1) * stride_w | |||
return np.prod(outputs[0].shape) * kernel_h * kernel_w | |||
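A quick hand check of the kernel-size estimate above, with the kernel derived from both the input and output spatial sizes (the concrete shapes are illustrative):

# adaptive pooling (N, C, 7, 7) -> (N, C, 1, 1)
import numpy as np
in_h, out_h = 7, 1
stride_h = np.floor(in_h / out_h)             # 7.0
kernel_h = in_h - (out_h - 1) * stride_h      # 7.0
flops = np.prod((1, 512, 1, 1)) * kernel_h * kernel_h   # 512 * 49 = 25088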
@register_flops(m.Linear) | |||
def flops_linear(module: m.Linear, inputs, outputs): | |||
bias = module.out_features if module.bias is not None else 0 | |||
@@ -120,6 +143,12 @@ hook_modules = ( | |||
m.conv._ConvNd, | |||
m.Linear, | |||
m.BatchMatMulActivation, | |||
m.batchnorm._BatchNorm, | |||
m.LayerNorm, | |||
m.GroupNorm, | |||
m.InstanceNorm, | |||
m.pooling._PoolNd, | |||
m.adaptive_pooling._AdaptivePoolNd, | |||
) | |||
@@ -137,12 +166,16 @@ def dict2table(list_of_dict, header): | |||
def sizeof_fmt(num, suffix="B"): | |||
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: | |||
if abs(num) < 1024.0: | |||
if suffix == "B": | |||
scale = 1024.0 | |||
units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"] | |||
else: | |||
scale = 1000.0 | |||
units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"] | |||
for unit in units: | |||
if abs(num) < scale or unit == units[-1]: | |||
return "{:3.3f} {}{}".format(num, unit, suffix) | |||
num /= 1024.0 | |||
sign_str = "-" if num < 0 else "" | |||
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) | |||
num /= scale | |||
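The rewritten sizeof_fmt now picks binary (1024-step, "Ki"/"Mi"/...) units for byte sizes and decimal (1000-step, "K"/"M"/"G") units for everything else, such as op counts; a quick sanity check with illustrative values:

print(sizeof_fmt(100352))               # bytes use 1024 steps -> "98.000 KiB"
print(sizeof_fmt(1.8e9, suffix="OPs"))  # op counts use 1000 steps -> "1.800 GOPs"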
def preprocess_receptive_field(module, inputs, outputs): | |||
@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs): | |||
def get_op_stats(module, inputs, outputs): | |||
if not isinstance(outputs, tuple) and not isinstance(outputs, list): | |||
outputs = (outputs,) | |||
rst = { | |||
"input_shapes": [i.shape for i in inputs], | |||
"output_shapes": [o.shape for o in outputs], | |||
@@ -189,7 +224,7 @@ def get_op_stats(module, inputs, outputs): | |||
return | |||
def print_op_stats(flops, bar_length_max=20): | |||
def sum_op_stats(flops, bar_length_max=20): | |||
max_flops_num = max([i["flops_num"] for i in flops] + [0]) | |||
total_flops_num = 0 | |||
for d in flops: | |||
@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20): | |||
d["bar"] = "#" * bar_length | |||
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs") | |||
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | |||
total_var_size = sum( | |||
sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops | |||
) | |||
flops.append( | |||
dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | |||
) | |||
return total_flops_num, flops | |||
def print_op_stats(flops): | |||
header = [ | |||
"name", | |||
"class_name", | |||
@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20): | |||
if _receptive_field_enabled: | |||
header.insert(4, "receptive_field") | |||
header.insert(5, "stride") | |||
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | |||
total_var_size = sum( | |||
sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops | |||
) | |||
flops.append( | |||
dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | |||
) | |||
logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))) | |||
return total_flops_num | |||
def get_param_stats(param: np.ndarray): | |||
nbits = get_dtype_bit(param.dtype.name) | |||
@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray): | |||
} | |||
def print_param_stats(params, bar_length_max=20): | |||
def sum_param_stats(params, bar_length_max=20): | |||
max_size = max([d["size"] for d in params] + [0]) | |||
total_param_dims, total_param_size = 0, 0 | |||
for d in params: | |||
@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20): | |||
param_size = sizeof_fmt(total_param_size) | |||
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) | |||
return total_param_dims, total_param_size, params | |||
def print_param_stats(params): | |||
header = [ | |||
"name", | |||
"dtype", | |||
@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20): | |||
"mean", | |||
"std", | |||
"param_dim", | |||
"bits", | |||
"nbits", | |||
"size", | |||
"size_cum", | |||
"percentage", | |||
"size_bar", | |||
] | |||
logger.info( | |||
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) | |||
) | |||
return total_param_dims, total_param_size | |||
def get_activation_stats(output: np.ndarray): | |||
out_shape = output.shape | |||
activations_dtype = output.dtype | |||
nbits = get_dtype_bit(activations_dtype.name) | |||
act_dim = np.prod(out_shape) | |||
act_size = act_dim * nbits // 8 | |||
return { | |||
"dtype": activations_dtype, | |||
"shape": out_shape, | |||
"act_dim": act_dim, | |||
"mean": "{:.3g}".format(output.mean()), | |||
"std": "{:.3g}".format(output.std()), | |||
"nbits": nbits, | |||
"size": act_size, | |||
} | |||
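To make the fields concrete, here is what get_activation_stats reports for a small float32 activation (shape and values are illustrative; it assumes get_dtype_bit returns 32 for float32):

import numpy as np
out = np.zeros((1, 64, 56, 56), dtype=np.float32)
stats = get_activation_stats(out)
assert stats["act_dim"] == 200704           # 1 * 64 * 56 * 56 elements
assert stats["nbits"] == 32                 # float32
assert stats["size"] == 200704 * 32 // 8    # 802816 bytes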
def sum_activations_stats(activations, bar_length_max=20): | |||
max_act_size = max([i["size"] for i in activations] + [0]) | |||
total_act_dims, total_act_size = 0, 0 | |||
for d in activations: | |||
total_act_size += int(d["size"]) | |||
total_act_dims += int(d["act_dim"]) | |||
d["size_cum"] = sizeof_fmt(total_act_size) | |||
for d in activations: | |||
ratio = d["ratio"] = d["size"] / total_act_size | |||
d["percentage"] = "{:.2f}%".format(ratio * 100) | |||
bar_length = int(d["size"] / max_act_size * bar_length_max) | |||
d["size_bar"] = "#" * bar_length | |||
d["size"] = sizeof_fmt(d["size"]) | |||
act_size = sizeof_fmt(total_act_size) | |||
activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,)) | |||
return total_act_dims, total_act_size, activations | |||
def print_activations_stats(activations): | |||
header = [ | |||
"name", | |||
"class_name", | |||
"dtype", | |||
"shape", | |||
"mean", | |||
"std", | |||
"nbits", | |||
"act_dim", | |||
"size", | |||
"size_cum", | |||
"percentage", | |||
"size_bar", | |||
] | |||
logger.info( | |||
"activations stats: \n" | |||
+ tabulate.tabulate(dict2table(activations, header=header)) | |||
) | |||
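The split of each print_*_stats function into a sum_*_stats / print_*_stats pair means aggregation no longer forces logging; a minimal sketch (the activations_list and verbose names are illustrative):

total_act_dims, total_act_size, acts = sum_activations_stats(activations_list, bar_length_max=20)
if verbose:
    print_activations_stats(acts)   # logging is now an optional, separate step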
def print_summary(**kwargs): | |||
@@ -294,25 +390,26 @@ def print_summary(**kwargs): | |||
def module_stats( | |||
model: m.Module, | |||
input_size: int, | |||
input_shapes: list, | |||
bar_length_max: int = 20, | |||
log_params: bool = True, | |||
log_flops: bool = True, | |||
log_activations: bool = True, | |||
): | |||
r""" | |||
Calculate and print ``model``'s statistics by adding hooks and recording the sizes of each module's inputs and outputs. | |||
:param model: model whose statistics are to be collected. | |||
:param input_size: size of input for running model and calculating stats. | |||
:param input_shapes: shapes of inputs for running model and calculating stats. | |||
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||
:param log_params: whether to print and record parameter sizes. | |||
:param log_flops: whether to print and record op flops. | |||
:param log_activations: whether to print and record op activations. | |||
""" | |||
disable_receptive_field() | |||
def module_stats_hook(module, inputs, outputs, name=""): | |||
class_name = str(module.__class__).split(".")[-1].split("'")[0] | |||
flops_stats = get_op_stats(module, inputs, outputs) | |||
if flops_stats is not None: | |||
flops_stats["name"] = name | |||
@@ -331,38 +428,25 @@ def module_stats( | |||
param_stats["name"] = name + "-b" | |||
params.append(param_stats) | |||
@contextlib.contextmanager | |||
def adjust_stats(module, training=False): | |||
"""Adjust module to training/eval mode temporarily. | |||
Args: | |||
module (M.Module): used module. | |||
training (bool): training mode. True for train mode, False for eval mode. | |||
""" | |||
def recursive_backup_stats(module, mode): | |||
for m in module.modules(): | |||
# save prev status to _prev_training | |||
m._prev_training = m.training | |||
m.train(mode, recursive=False) | |||
def recursive_recover_stats(module): | |||
for m in module.modules(): | |||
# recover prev status and delete attribute | |||
m.training = m._prev_training | |||
delattr(m, "_prev_training") | |||
recursive_backup_stats(module, mode=training) | |||
yield module | |||
recursive_recover_stats(module) | |||
if not isinstance(outputs, (tuple, list)): | |||
output = outputs.numpy() | |||
else: | |||
output = outputs[0].numpy() | |||
activation_stats = get_activation_stats(output) | |||
activation_stats["name"] = name | |||
activation_stats["class_name"] = class_name | |||
activations.append(activation_stats) | |||
# multiple inputs to the network | |||
if not isinstance(input_size[0], tuple): | |||
input_size = [input_size] | |||
if not isinstance(input_shapes[0], tuple): | |||
input_shapes = [input_shapes] | |||
params = [] | |||
flops = [] | |||
hooks = [] | |||
activations = [] | |||
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | |||
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | |||
for (name, module) in model.named_modules(): | |||
if isinstance(module, hook_modules): | |||
@@ -370,8 +454,8 @@ def module_stats( | |||
module.register_forward_hook(partial(module_stats_hook, name=name)) | |||
) | |||
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] | |||
with adjust_stats(model, training=False) as model: | |||
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||
with set_module_mode_safe(model, training=False) as model: | |||
model(*inputs) | |||
for h in hooks: | |||
@@ -380,19 +464,40 @@ def module_stats( | |||
extra_info = { | |||
"#params": len(params), | |||
} | |||
total_flops, total_param_dims, total_param_size = 0, 0, 0 | |||
( | |||
total_flops, | |||
total_param_dims, | |||
total_param_size, | |||
total_act_dims, | |||
total_act_size, | |||
) = (0, 0, 0, 0, 0) | |||
total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
if log_params: | |||
total_param_dims, total_param_size = print_param_stats(params, bar_length_max) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
print_param_stats(params) | |||
total_flops, flops = sum_op_stats(flops, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if log_flops: | |||
total_flops = print_op_stats(flops, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if log_params and log_flops: | |||
extra_info["flops/param_size"] = "{:3.3f}".format( | |||
total_flops / total_param_size | |||
) | |||
print_op_stats(flops) | |||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||
activations, bar_length_max | |||
) | |||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
if log_activations: | |||
print_activations_stats(activations) | |||
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||
print_summary(**extra_info) | |||
return total_param_size, total_flops | |||
return ( | |||
total_stats( | |||
param_size=total_param_size, flops=total_flops, act_size=total_act_size, | |||
), | |||
stats_details(params=params, flops=flops, activations=activations), | |||
) |
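An illustrative usage sketch of the reworked module_stats API with its new input_shapes argument and namedtuple return; the model construction is a placeholder, not part of this diff:

import megengine.hub as hub
from megengine.utils.module_stats import module_stats

net = hub.load("megengine/models", "resnet18", pretrained=False)   # placeholder model
total_stats, stats_details = module_stats(
    net, input_shapes=(1, 3, 224, 224), log_params=True, log_flops=True, log_activations=True,
)
print(total_stats.param_size, total_stats.flops, total_stats.act_size)
print(stats_details.activations[-1]["act_dim"])   # the appended "total" row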
@@ -5,6 +5,7 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import contextlib | |||
from collections import Iterable | |||
from ..module import Sequential | |||
@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value): | |||
parent[key] = value | |||
_access_structure(obj, key, callback=f) | |||
@contextlib.contextmanager | |||
def set_module_mode_safe( | |||
module: Module, training: bool = False, | |||
): | |||
"""Adjust module to training/eval mode temporarily. | |||
:param module: module to adjust. | |||
:param training: training mode. True for train mode, False for eval mode. | |||
""" | |||
backup_stats = {} | |||
def recursive_backup_stats(module, mode): | |||
for m in module.modules(): | |||
backup_stats[m] = m.training | |||
m.train(mode, recursive=False) | |||
def recursive_recover_stats(module): | |||
for m in module.modules(): | |||
m.training = backup_stats.pop(m) | |||
recursive_backup_stats(module, mode=training) | |||
yield module | |||
recursive_recover_stats(module) |
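An illustrative use of the new set_module_mode_safe helper, which runs a forward pass in eval mode and then restores each submodule's previous training flag (the model and inputs objects are assumed to exist):

from megengine.utils.module_utils import set_module_mode_safe

with set_module_mode_safe(model, training=False) as model:
    model(*inputs)
# every submodule's .training is restored to its previous value here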
@@ -0,0 +1,377 @@ | |||
import math | |||
from copy import deepcopy | |||
import numpy as np | |||
import pytest | |||
import megengine as mge | |||
import megengine.functional as F | |||
import megengine.hub as hub | |||
import megengine.module as M | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.utils.module_stats import module_stats | |||
@pytest.mark.skipif( | |||
use_symbolic_shape(), reason="This test does not support symbolic shape.", | |||
) | |||
def test_module_stats(): | |||
net = ResNet(BasicBlock, [2, 2, 2, 2]) | |||
input_shape = (1, 3, 224, 224) | |||
total_stats, stats_details = module_stats(net, input_shape) | |||
x1 = mge.tensor(np.zeros((1, 3, 224, 224))) | |||
gt_flops, gt_acts = net.get_stats(x1) | |||
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | |||
gt_flops, | |||
gt_acts, | |||
) | |||
class BasicBlock(M.Module): | |||
expansion = 1 | |||
def __init__( | |||
self, | |||
in_channels, | |||
channels, | |||
stride=1, | |||
groups=1, | |||
base_width=64, | |||
dilation=1, | |||
norm=M.BatchNorm2d, | |||
): | |||
super().__init__() | |||
self.tmp_in_channels = in_channels | |||
self.tmp_channels = channels | |||
self.stride = stride | |||
if groups != 1 or base_width != 64: | |||
raise ValueError("BasicBlock only supports groups=1 and base_width=64") | |||
if dilation > 1: | |||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | |||
self.conv1 = M.Conv2d( | |||
in_channels, channels, 3, stride, padding=dilation, bias=False | |||
) | |||
self.bn1 = norm(channels) | |||
self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False) | |||
self.bn2 = norm(channels) | |||
self.downsample_id = M.Identity() | |||
self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False) | |||
self.downsample_norm = norm(channels) | |||
def forward(self, x): | |||
identity = x | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = F.relu(x) | |||
x = self.conv2(x) | |||
x = self.bn2(x) | |||
if self.tmp_in_channels == self.tmp_channels and self.stride == 1: | |||
identity = self.downsample_id(identity) | |||
else: | |||
identity = self.downsample_conv(identity) | |||
identity = self.downsample_norm(identity) | |||
x += identity | |||
x = F.relu(x) | |||
return x | |||
def get_stats(self, x): | |||
activations, flops = 0, 0 | |||
identity = x | |||
in_x = deepcopy(x) | |||
x = self.conv1(x) | |||
tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
in_x = deepcopy(x) | |||
x = self.bn1(x) | |||
tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x = F.relu(x) | |||
in_x = deepcopy(x) | |||
x = self.conv2(x) | |||
tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
in_x = deepcopy(x) | |||
x = self.bn2(x) | |||
tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
if self.tmp_in_channels == self.tmp_channels and self.stride == 1: | |||
identity = self.downsample_id(identity) | |||
else: | |||
in_x = deepcopy(identity) | |||
identity = self.downsample_conv(identity) | |||
tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
in_x = deepcopy(identity) | |||
identity = self.downsample_norm(identity) | |||
tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x += identity | |||
x = F.relu(x) | |||
return x, flops, activations | |||
class ResNet(M.Module): | |||
def __init__( | |||
self, | |||
block, | |||
layers=[2, 2, 2, 2], | |||
num_classes=1000, | |||
zero_init_residual=False, | |||
groups=1, | |||
width_per_group=64, | |||
replace_stride_with_dilation=None, | |||
norm=M.BatchNorm2d, | |||
): | |||
super().__init__() | |||
self.in_channels = 64 | |||
self.dilation = 1 | |||
if replace_stride_with_dilation is None: | |||
# each element in the tuple indicates if we should replace | |||
# the 2x2 stride with a dilated convolution instead | |||
replace_stride_with_dilation = [False, False, False] | |||
if len(replace_stride_with_dilation) != 3: | |||
raise ValueError( | |||
"replace_stride_with_dilation should be None " | |||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation) | |||
) | |||
self.groups = groups | |||
self.base_width = width_per_group | |||
self.conv1 = M.Conv2d( | |||
3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False | |||
) | |||
self.bn1 = norm(self.in_channels) | |||
self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
self.layer1_0 = BasicBlock( | |||
self.in_channels, | |||
64, | |||
stride=1, | |||
groups=self.groups, | |||
base_width=self.base_width, | |||
dilation=self.dilation, | |||
norm=M.BatchNorm2d, | |||
) | |||
self.layer1_1 = BasicBlock( | |||
self.in_channels, | |||
64, | |||
stride=1, | |||
groups=self.groups, | |||
base_width=self.base_width, | |||
dilation=self.dilation, | |||
norm=M.BatchNorm2d, | |||
) | |||
self.layer2_0 = BasicBlock(64, 128, stride=2) | |||
self.layer2_1 = BasicBlock(128, 128) | |||
self.layer3_0 = BasicBlock(128, 256, stride=2) | |||
self.layer3_1 = BasicBlock(256, 256) | |||
self.layer4_0 = BasicBlock(256, 512, stride=2) | |||
self.layer4_1 = BasicBlock(512, 512) | |||
self.layer1 = self._make_layer(block, 64, layers[0], norm=norm) | |||
self.layer2 = self._make_layer( | |||
block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm | |||
) | |||
self.layer3 = self._make_layer( | |||
block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm | |||
) | |||
self.layer4 = self._make_layer( | |||
block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm | |||
) | |||
self.fc = M.Linear(512, num_classes) | |||
for m in self.modules(): | |||
if isinstance(m, M.Conv2d): | |||
M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu") | |||
if m.bias is not None: | |||
fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight) | |||
bound = 1 / math.sqrt(fan_in) | |||
M.init.uniform_(m.bias, -bound, bound) | |||
elif isinstance(m, M.BatchNorm2d): | |||
M.init.ones_(m.weight) | |||
M.init.zeros_(m.bias) | |||
elif isinstance(m, M.Linear): | |||
M.init.msra_uniform_(m.weight, a=math.sqrt(5)) | |||
if m.bias is not None: | |||
fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight) | |||
bound = 1 / math.sqrt(fan_in) | |||
M.init.uniform_(m.bias, -bound, bound) | |||
if zero_init_residual: | |||
for m in self.modules(): | |||
if isinstance(m, BasicBlock): | |||
M.init.zeros_(m.bn2.weight) | |||
def _make_layer( | |||
self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d | |||
): | |||
previous_dilation = self.dilation | |||
if dilate: | |||
self.dilation *= stride | |||
stride = 1 | |||
layers = [] | |||
layers.append( | |||
block( | |||
self.in_channels, | |||
channels, | |||
stride, | |||
groups=self.groups, | |||
base_width=self.base_width, | |||
dilation=previous_dilation, | |||
norm=norm, | |||
) | |||
) | |||
self.in_channels = channels * block.expansion | |||
for _ in range(1, blocks): | |||
layers.append( | |||
block( | |||
self.in_channels, | |||
channels, | |||
groups=self.groups, | |||
base_width=self.base_width, | |||
dilation=self.dilation, | |||
norm=norm, | |||
) | |||
) | |||
return M.Sequential(*layers) | |||
def extract_features(self, x): | |||
outputs = {} | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = F.relu(x) | |||
x = self.maxpool(x) | |||
outputs["stem"] = x | |||
x = self.layer1(x) | |||
outputs["res2"] = x | |||
x = self.layer2(x) | |||
outputs["res3"] = x | |||
x = self.layer3(x) | |||
outputs["res4"] = x | |||
x = self.layer4(x) | |||
outputs["res5"] = x | |||
return outputs | |||
def forward(self, x): | |||
x = self.extract_features(x)["res5"] | |||
x = F.avg_pool2d(x, 7) | |||
x = F.flatten(x, 1) | |||
x = self.fc(x) | |||
return x | |||
def get_stats(self, x): | |||
flops, activations = 0, 0 | |||
in_x = deepcopy(x) | |||
x = self.conv1(x) | |||
tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
in_x = deepcopy(x) | |||
x = self.bn1(x) | |||
tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x = F.relu(x) | |||
in_x = deepcopy(x) | |||
x = self.maxpool(x) | |||
tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
x = F.avg_pool2d(x, 7) | |||
x = F.flatten(x, 1) | |||
in_x = deepcopy(x) | |||
x = self.fc(x) | |||
tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x) | |||
activations += tmp_acts | |||
flops += tmp_flops | |||
return flops, activations | |||
def cal_conv_stats(module, input, output): | |||
bias = 1 if module.bias is not None else 0 | |||
flops = np.prod(output[0].shape) * ( | |||
module.in_channels // module.groups * np.prod(module.kernel_size) + bias | |||
) | |||
acts = np.prod(output[0].shape) | |||
return flops, acts | |||
def cal_norm_stats(module, input, output): | |||
return np.prod(input[0].shape) * 7, np.prod(output[0].shape) | |||
def cal_linear_stats(module, inputs, outputs): | |||
bias = module.out_features if module.bias is not None else 0 | |||
return ( | |||
np.prod(outputs[0].shape) * module.in_features + bias, | |||
np.prod(outputs[0].shape), | |||
) | |||
def cal_pool_stats(module, inputs, outputs): | |||
return ( | |||
np.prod(outputs[0].shape) * (module.kernel_size ** 2), | |||
np.prod(outputs[0].shape), | |||
) |
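Closing with a hand check of the reference formulas used by the test, applied to the stem convolution of the ResNet above (3 -> 64 channels, 7x7 kernel, no bias, output (1, 64, 112, 112)); the numbers follow directly from cal_conv_stats:

import numpy as np
out_shape = (64, 112, 112)                      # output[0].shape after batch indexing
flops = np.prod(out_shape) * (3 // 1 * 7 * 7)   # 802816 * 147 = 118013952
acts = np.prod(out_shape)                       # 802816 activations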