BREAKING CHANGE:
GitOrigin-RevId: cd2a1acd11
release-1.4
@@ -12,6 +12,7 @@ import re | |||
from collections import namedtuple | |||
import numpy as np | |||
from tqdm import tqdm | |||
from megengine.core.tensor.dtype import is_quantize | |||
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | |||
@@ -37,10 +38,13 @@ logger = get_logger(__name__) | |||
def visualize( | |||
model_path: str, | |||
log_path: str, | |||
input: np.ndarray = None, | |||
inp_dict: dict = None, | |||
cal_params: bool = True, | |||
cal_flops: bool = True, | |||
cal_activations: bool = True, | |||
logging_to_stdout: bool = True, | |||
bar_length_max: int = 20, | |||
log_params: bool = True, | |||
log_flops: bool = True, | |||
log_activations: bool = True, | |||
): | |||
r""" | |||
Load megengine dumped model and visualize graph structure with tensorboard log files. | |||
@@ -48,10 +52,14 @@ def visualize( | |||
:param model_path: dir path for megengine dumped model. | |||
:param log_path: dir path for tensorboard graph log. | |||
:param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input. | |||
:param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used. | |||
:param cal_params: whether calculate and record params size. | |||
:param cal_flops: whether calculate and record op flops. | |||
:param cal_activations: whether calculate and record op activations. | |||
:param logging_to_stdout: whether print all calculated statistic details. | |||
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||
:param log_params: whether print and record params size. | |||
:param log_flops: whether print and record op flops. | |||
:param log_activations: whether print and record op activations. | |||
""" | |||
if log_path: | |||
try: | |||
@@ -78,6 +86,27 @@ def visualize( | |||
enable_receptive_field() | |||
graph = Network.load(model_path) | |||
graph.reset_batch_size(1) | |||
has_input = False | |||
if input is not None or inp_dict is not None: | |||
has_input = True | |||
repl_dict = {} | |||
inp_vars = graph.input_vars | |||
if inp_dict is not None: | |||
assert len(inp_dict) == len( | |||
inp_vars | |||
), "Inputs are not sufficient for calculation." | |||
for v in inp_vars: | |||
new_input = graph.make_const(inp_dict[v.name], name=v.name) | |||
repl_dict[v] = new_input | |||
else: | |||
assert len(inp_vars) == 1, "The graph needs more than one input." | |||
inp_var = inp_vars[0] | |||
repl_dict[inp_var] = graph.make_const(input, name=inp_var.name) | |||
graph.replace_vars(repl_dict=repl_dict) | |||
graph._compile() | |||
def process_name(name): | |||
# nodes that start with point or contain float const will lead to display bug | |||
@@ -93,7 +122,7 @@ def visualize( | |||
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | |||
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | |||
for node in graph.all_oprs: | |||
for node in tqdm(graph.all_oprs): | |||
if hasattr(node, "output_idx"): | |||
node_oup = node.outputs[node.output_idx] | |||
else: | |||
@@ -123,31 +152,35 @@ def visualize( | |||
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), | |||
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), | |||
} | |||
flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||
if flops_stats is not None: | |||
# add op flops attr | |||
if log_path and hasattr(flops_stats, "flops_num"): | |||
attr["flops"] = AttrValue( | |||
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||
) | |||
flops_stats["name"] = node.name | |||
flops_stats["class_name"] = node.type | |||
flops_list.append(flops_stats) | |||
acts = get_activation_stats(node_oup) | |||
if cal_flops: | |||
flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||
if flops_stats is not None: | |||
# add op flops attr | |||
if log_path and hasattr(flops_stats, "flops_num"): | |||
attr["flops"] = AttrValue( | |||
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||
) | |||
flops_stats["name"] = node.name | |||
flops_stats["class_name"] = node.type | |||
flops_list.append(flops_stats) | |||
if cal_activations: | |||
acts = get_activation_stats(node_oup.numpy(), has_input=has_input) | |||
acts["name"] = node.name | |||
acts["class_name"] = node.type | |||
activations_list.append(acts) | |||
if node.type == "ImmutableTensor": | |||
param_stats = get_param_stats(node_oup) | |||
# add tensor size attr | |||
if log_path: | |||
attr["size"] = AttrValue( | |||
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||
) | |||
param_stats["name"] = node.name | |||
params_list.append(param_stats) | |||
if cal_params: | |||
if node.type == "ImmutableTensor": | |||
param_stats = get_param_stats(node.numpy()) | |||
# add tensor size attr | |||
if log_path: | |||
attr["size"] = AttrValue( | |||
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||
) | |||
param_stats["name"] = node.name | |||
params_list.append(param_stats) | |||
if log_path: | |||
node_list.append( | |||
@@ -169,31 +202,37 @@ def visualize( | |||
total_param_dims, | |||
total_param_size, | |||
total_act_dims, | |||
total_param_size, | |||
total_act_size, | |||
) = (0, 0, 0, 0, 0) | |||
total_param_dims, total_param_size, params = sum_param_stats( | |||
params_list, bar_length_max | |||
) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
if log_params: | |||
print_param_stats(params) | |||
total_flops, flops = sum_op_stats(flops_list, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if log_flops: | |||
print_op_stats(flops) | |||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||
activations_list, bar_length_max | |||
) | |||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
if log_activations: | |||
print_activations_stats(activations) | |||
if cal_params: | |||
total_param_dims, total_param_size, params_list = sum_param_stats( | |||
params_list, bar_length_max | |||
) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
if logging_to_stdout: | |||
print_param_stats(params_list) | |||
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||
if cal_flops: | |||
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if logging_to_stdout: | |||
print_op_stats(flops_list) | |||
if cal_activations: | |||
total_act_dims, total_act_size, activations_list = sum_activations_stats( | |||
activations_list, bar_length_max | |||
) | |||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
if logging_to_stdout: | |||
print_activations_stats(activations_list, has_input=has_input) | |||
if cal_flops and cal_params: | |||
extra_info["flops/param_size"] = "{:3.3f}".format( | |||
total_flops / total_param_size | |||
) | |||
if log_path: | |||
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | |||
@@ -211,7 +250,9 @@ def visualize( | |||
total_stats( | |||
param_size=total_param_size, flops=total_flops, act_size=total_act_size, | |||
), | |||
stats_details(params=params, flops=flops, activations=activations), | |||
stats_details( | |||
params=params_list, flops=flops_list, activations=activations_list | |||
), | |||
) | |||
@@ -229,12 +270,24 @@ def main(): | |||
help="size of bar indicating max flops or parameter size in net stats.", | |||
) | |||
parser.add_argument( | |||
"--log_params", | |||
"--cal_params", | |||
action="store_true", | |||
help="whether calculate and record params size.", | |||
) | |||
parser.add_argument( | |||
"--cal_flops", | |||
action="store_true", | |||
help="whether calculate and record op flops.", | |||
) | |||
parser.add_argument( | |||
"--cal_activations", | |||
action="store_true", | |||
help="whether print and record params size.", | |||
help="whether calculate and record op activations.", | |||
) | |||
parser.add_argument( | |||
"--log_flops", action="store_true", help="whether print and record op flops.", | |||
"--logging_to_stdout", | |||
action="store_true", | |||
help="whether print all calculated statistic details.", | |||
) | |||
parser.add_argument( | |||
"--all", | |||
@@ -243,8 +296,10 @@ def main(): | |||
) | |||
args = parser.parse_args() | |||
if args.all: | |||
args.log_params = True | |||
args.log_flops = True | |||
args.cal_params = True | |||
args.cal_flops = True | |||
args.cal_activations = True | |||
args.logging_to_stdout = True | |||
if not args.log_path: | |||
args.log_path = "./log" | |||
kwargs = vars(args) | |||
@@ -5,8 +5,9 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from collections import namedtuple | |||
from collections import Iterable, namedtuple | |||
from functools import partial | |||
from typing import Iterable | |||
import numpy as np | |||
import tabulate | |||
@@ -19,6 +20,7 @@ from megengine import Tensor | |||
from megengine import functional as F | |||
from megengine.core.tensor.dtype import get_dtype_bit | |||
from megengine.functional.tensor import zeros | |||
from megengine.tensor import Tensor | |||
from .module_utils import set_module_mode_safe | |||
@@ -335,21 +337,23 @@ def print_param_stats(params): | |||
) | |||
def get_activation_stats(output: Tensor): | |||
def get_activation_stats(output: np.ndarray, has_input=False): | |||
out_shape = output.shape | |||
activations_dtype = np.dtype(output.dtype) | |||
nbits = get_dtype_bit(activations_dtype.name) | |||
act_dim = np.prod(out_shape) | |||
act_size = act_dim * nbits // 8 | |||
return { | |||
activation_stats = { | |||
"dtype": activations_dtype, | |||
"shape": out_shape, | |||
"act_dim": act_dim, | |||
"mean": "{:.3g}".format(_mean(output)), | |||
"std": "{:.3g}".format(_std(output)), | |||
"nbits": nbits, | |||
"size": act_size, | |||
} | |||
if has_input: | |||
activation_stats["mean"] = "{:.3g}".format(output.mean()) | |||
activation_stats["std"] = "{:.3g}".format(output.std()) | |||
return activation_stats | |||
def sum_activations_stats(activations, bar_length_max=20): | |||
@@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20): | |||
return total_act_dims, total_act_size, activations | |||
def print_activations_stats(activations): | |||
def print_activations_stats(activations, has_input=False): | |||
header = [ | |||
"name", | |||
"class_name", | |||
"dtype", | |||
"shape", | |||
"mean", | |||
"std", | |||
"nbits", | |||
"act_dim", | |||
"size", | |||
@@ -388,6 +390,9 @@ def print_activations_stats(activations): | |||
"percentage", | |||
"size_bar", | |||
] | |||
if has_input: | |||
header.insert(4, "mean") | |||
header.insert(5, "std") | |||
logger.info( | |||
"activations stats: \n" | |||
+ tabulate.tabulate(dict2table(activations, header=header)) | |||
@@ -402,56 +407,80 @@ def print_summary(**kwargs): | |||
def module_stats( | |||
model: m.Module, | |||
input_shapes: list, | |||
inputs: Iterable[np.ndarray] = None, | |||
input_shapes: list = None, | |||
cal_params: bool = True, | |||
cal_flops: bool = True, | |||
cal_activations: bool = True, | |||
logging_to_stdout: bool = True, | |||
bar_length_max: int = 20, | |||
log_params: bool = True, | |||
log_flops: bool = True, | |||
log_activations: bool = True, | |||
): | |||
r""" | |||
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | |||
:param model: model that need to get stats info. | |||
:param input_shapes: shapes of inputs for running model and calculating stats. | |||
:param inputs: user defined input data for running model and calculating stats, alternative with input_shapes. | |||
:param input_shapes: shapes to generate random inputs for running model and calculating stats, alternative with inputs. | |||
:param cal_params: whether calculate and record params size. | |||
:param cal_flops: whether calculate and record op flops. | |||
:param cal_activations: whether calculate and record op activations. | |||
:param logging_to_stdout: whether print all calculated statistic details. | |||
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||
:param log_params: whether print and record params size. | |||
:param log_flops: whether print and record op flops. | |||
:param log_activations: whether print and record op activations. | |||
""" | |||
has_inputs = False | |||
if inputs is not None: | |||
has_inputs = True | |||
if not isinstance(inputs, (tuple, list)): | |||
inputs = [inputs] | |||
inputs = [Tensor(input, dtype=np.float32) for input in inputs] | |||
else: | |||
if input_shapes: | |||
if not isinstance(input_shapes[0], tuple): | |||
input_shapes = [input_shapes] | |||
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||
else: | |||
logger.error( | |||
"Inputs or input_shapes is required for running model and calculating stats.", | |||
exc_info=True, | |||
) | |||
return | |||
if not cal_activations: | |||
log_activations = False | |||
disable_receptive_field() | |||
def module_stats_hook(module, inputs, outputs, name=""): | |||
class_name = str(module.__class__).split(".")[-1].split("'")[0] | |||
flops_stats = get_op_stats(module, inputs, outputs) | |||
if flops_stats is not None: | |||
flops_stats["name"] = name | |||
flops_stats["class_name"] = class_name | |||
flops.append(flops_stats) | |||
if hasattr(module, "weight") and module.weight is not None: | |||
w = module.weight | |||
param_stats = get_param_stats(w) | |||
param_stats["name"] = name + "-w" | |||
params.append(param_stats) | |||
if hasattr(module, "bias") and module.bias is not None: | |||
b = module.bias | |||
param_stats = get_param_stats(b) | |||
param_stats["name"] = name + "-b" | |||
params.append(param_stats) | |||
if not isinstance(outputs, tuple) or not isinstance(outputs, list): | |||
output = outputs | |||
else: | |||
output = outputs[0] | |||
activation_stats = get_activation_stats(output) | |||
activation_stats["name"] = name | |||
activation_stats["class_name"] = class_name | |||
activations.append(activation_stats) | |||
# multiple inputs to the network | |||
if not isinstance(input_shapes[0], tuple): | |||
input_shapes = [input_shapes] | |||
if cal_flops: | |||
flops_stats = get_op_stats(module, inputs, outputs) | |||
if flops_stats is not None: | |||
flops_stats["name"] = name | |||
flops_stats["class_name"] = class_name | |||
flops.append(flops_stats) | |||
if cal_params: | |||
if hasattr(module, "weight") and module.weight is not None: | |||
w = module.weight | |||
param_stats = get_param_stats(w.numpy()) | |||
param_stats["name"] = name + "-w" | |||
params.append(param_stats) | |||
if hasattr(module, "bias") and module.bias is not None: | |||
b = module.bias | |||
param_stats = get_param_stats(b.numpy()) | |||
param_stats["name"] = name + "-b" | |||
params.append(param_stats) | |||
if cal_activations: | |||
if not isinstance(outputs, (tuple, list)): | |||
output = outputs.numpy() | |||
else: | |||
output = outputs[0].numpy() | |||
activation_stats = get_activation_stats(output, has_inputs) | |||
activation_stats["name"] = name | |||
activation_stats["class_name"] = class_name | |||
activations.append(activation_stats) | |||
params = [] | |||
flops = [] | |||
@@ -466,7 +495,6 @@ def module_stats( | |||
module.register_forward_hook(partial(module_stats_hook, name=name)) | |||
) | |||
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||
with set_module_mode_safe(model, training=False) as model: | |||
model(*inputs) | |||
@@ -481,29 +509,37 @@ def module_stats( | |||
total_param_dims, | |||
total_param_size, | |||
total_act_dims, | |||
total_param_size, | |||
total_act_size, | |||
) = (0, 0, 0, 0, 0) | |||
total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
if log_params: | |||
print_param_stats(params) | |||
total_flops, flops = sum_op_stats(flops, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if log_flops: | |||
print_op_stats(flops) | |||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||
activations, bar_length_max | |||
) | |||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
if log_activations: | |||
print_activations_stats(activations) | |||
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||
if cal_params: | |||
total_param_dims, total_param_size, params = sum_param_stats( | |||
params, bar_length_max | |||
) | |||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
if logging_to_stdout: | |||
print_param_stats(params) | |||
if cal_flops: | |||
total_flops, flops = sum_op_stats(flops, bar_length_max) | |||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
if logging_to_stdout: | |||
print_op_stats(flops) | |||
if cal_activations: | |||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||
activations, bar_length_max | |||
) | |||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
if logging_to_stdout: | |||
print_activations_stats(activations, has_inputs) | |||
if cal_flops and cal_params: | |||
extra_info["flops/param_size"] = "{:3.3f}".format( | |||
total_flops / total_param_size | |||
) | |||
print_summary(**extra_info) | |||
@@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats | |||
def test_module_stats(): | |||
net = ResNet(BasicBlock, [2, 2, 2, 2]) | |||
input_shape = (1, 3, 224, 224) | |||
total_stats, stats_details = module_stats(net, input_shape) | |||
x1 = mge.tensor(np.zeros((1, 3, 224, 224))) | |||
gt_flops, gt_acts = net.get_stats(x1) | |||
total_stats, stats_details = module_stats(net, input_shapes=input_shape) | |||
x1 = np.random.random((1, 3, 224, 224)).astype("float32") | |||
gt_flops, gt_acts = net.get_stats(mge.tensor(x1)) | |||
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | |||
gt_flops, | |||
gt_acts, | |||
) | |||
total_stats, stats_details = module_stats(net, inputs=x1) | |||
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | |||
gt_flops, | |||
gt_acts, | |||