@@ -14,6 +14,7 @@ from collections import namedtuple
 import numpy as np
 from tqdm import tqdm

+import megengine as mge
 from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
@@ -119,7 +120,9 @@ def visualize(
     flops_list = []
     params_list = []
     activations_list = []
-    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
+    total_stats = namedtuple(
+        "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
+    )
     stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

     for node in tqdm(graph.all_oprs):
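Both `visualize` here and `module_stats` further down now return two extra totals, `param_dims` and `act_dims`, alongside the existing `param_size`, `flops`, and `act_size`. A minimal sketch of consuming the widened tuple (the model is a placeholder; the field names come from the namedtuple above, and the dims fields are presumably total element counts):

```python
import megengine.module as M
from megengine.utils.module_stats import module_stats

net = M.Sequential(M.Conv2d(3, 8, 3), M.ReLU())  # any hook-able module works
total_stats, stats_details = module_stats(net, input_shapes=(1, 3, 224, 224))
print(total_stats.param_size, total_stats.param_dims)  # parameter bytes / element counts
print(total_stats.flops, total_stats.act_size, total_stats.act_dims)
```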
@@ -166,14 +169,14 @@ def visualize(
             flops_list.append(flops_stats)

         if cal_activations:
-            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
+            acts = get_activation_stats(node_oup, has_input=has_input)
             acts["name"] = node.name
             acts["class_name"] = node.type
             activations_list.append(acts)

         if cal_params:
             if node.type == "ImmutableTensor":
-                param_stats = get_param_stats(node.numpy())
+                param_stats = get_param_stats(node_oup)
                 # add tensor size attr
                 if log_path:
                     attr["size"] = AttrValue(
@@ -248,7 +251,11 @@ def visualize(
     return (
         total_stats(
-            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
+            param_size=total_param_size,
+            param_dims=total_param_dims,
+            flops=total_flops,
+            act_size=total_act_size,
+            act_dims=total_act_dims,
         ),
         stats_details(
             params=params_list, flops=flops_list, activations=activations_list
         ),
@@ -264,6 +271,10 @@ def main():
     parser.add_argument("model_path", help="dumped model path.")
     parser.add_argument("--log_path", help="tensorboard log path.")
     parser.add_argument(
+        "--load_input_data",
+        help="load input data from a pickle file; it should contain a numpy array or a dict of numpy arrays",
+    )
+    parser.add_argument(
         "--bar_length_max",
         type=int,
         default=20,
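Since `mge.load` is pickle-based, the file for `--load_input_data` can be produced with `mge.save` (or the `pickle` module directly). A sketch; the file names and the `data` input name are illustrative:

```python
import numpy as np
import megengine as mge

x = np.random.random((1, 3, 224, 224)).astype("float32")
mge.save(x, "single_input.pkl")          # one ndarray, for a single-input model
mge.save({"data": x}, "dict_input.pkl")  # dict of ndarrays keyed by input name
```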
@@ -295,6 +306,19 @@ def main():
         help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
     )
     args = parser.parse_args()
+    if args.load_input_data:
+        logger.info("load data from {}".format(args.load_input_data))
+        data = mge.load(args.load_input_data)
+        if isinstance(data, dict):
+            for v in data.values():
+                assert isinstance(
+                    v, np.ndarray
+                ), "data values should be numpy arrays; got {} instead".format(type(v))
+            args.inp_dict = data
+        elif isinstance(data, np.ndarray):
+            args.input = data
+        else:
+            logger.error("input data should be a numpy array or a dict of numpy arrays")
     if args.all:
         args.cal_params = True
         args.cal_flops = True
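A dict is forwarded to `visualize` as `inp_dict` and a bare array as `input`; any other payload only logs an error, and the tool presumably falls back to the shapes recorded in the dumped graph. Usage sketch, assuming the script is invoked as a module (the exact module path is an assumption):

```
python -m megengine.tools.network_visualize model.mge --load_input_data dict_input.pkl --all
```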
@@ -304,6 +328,7 @@ def main():
         args.log_path = "./log"
     kwargs = vars(args)
     kwargs.pop("all")
+    kwargs.pop("load_input_data")
     visualize(**kwargs)
@@ -113,7 +113,12 @@ def flops_norm(module: m.Linear, inputs, outputs):

 @register_flops(m.AvgPool2d, m.MaxPool2d)
 def flops_pool(module: m.AvgPool2d, inputs, outputs):
-    return np.prod(outputs[0].shape) * (module.kernel_size ** 2)
+    kernel_sum = 0
+    if isinstance(module.kernel_size, tuple) and len(module.kernel_size) == 2:
+        kernel_sum = np.prod(module.kernel_size)
+    else:
+        kernel_sum = module.kernel_size ** 2
+    return np.prod(outputs[0].shape) * kernel_sum


 @register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
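Pooling modules accept either an int or an `(h, w)` tuple for `kernel_size`, and the old `module.kernel_size ** 2` raised a `TypeError` for tuples. A quick sketch exercising both forms (module and input shape are arbitrary):

```python
import megengine.module as M
from megengine.utils.module_stats import module_stats

for ks in (2, (3, 2)):  # int and tuple kernels both work now
    total_stats, _ = module_stats(M.MaxPool2d(kernel_size=ks), input_shapes=(1, 3, 224, 224))
    print(ks, total_stats.flops)
```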
@@ -157,12 +162,12 @@ hook_modules = (
 def _mean(inp):
-    inp = mge.tensor(inp)
+    inp = mge.tensor(inp).astype(np.float32)
     return F.mean(inp).numpy()


 def _std(inp):
-    inp = mge.tensor(inp)
+    inp = mge.tensor(inp).astype(np.float32)
     return F.std(inp).numpy()
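The `astype(np.float32)` guard matters for low-bit and quantized activations, where reducing in the native dtype could truncate or be unsupported. A minimal illustration:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = np.arange(8, dtype="int8")  # stand-in for a quantized activation
print(F.mean(mge.tensor(x).astype(np.float32)).numpy())  # 3.5, computed in float32
```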
@@ -337,7 +342,7 @@ def print_param_stats(params):
     )


-def get_activation_stats(output: np.ndarray, has_input=False):
+def get_activation_stats(output: Tensor, has_input=False):
     out_shape = output.shape
     activations_dtype = np.dtype(output.dtype)
     nbits = get_dtype_bit(activations_dtype.name)
@@ -351,8 +356,8 @@ def get_activation_stats(output: np.ndarray, has_input=False):
         "size": act_size,
     }
     if has_input:
-        activation_stats["mean"] = "{:.3g}".format(output.mean())
-        activation_stats["std"] = "{:.3g}".format(output.std())
+        activation_stats["mean"] = "{:.3g}".format(_mean(output))
+        activation_stats["std"] = "{:.3g}".format(_std(output))
     return activation_stats
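Because `output` may now be a MegEngine `Tensor` (possibly quantized) rather than an ndarray, the numpy `mean()`/`std()` methods are replaced by the `_mean`/`_std` helpers above, which route through `megengine.functional` after the float32 cast.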
@@ -462,21 +467,21 @@ def module_stats(
     if cal_params:
         if hasattr(module, "weight") and module.weight is not None:
             w = module.weight
-            param_stats = get_param_stats(w.numpy())
+            param_stats = get_param_stats(w)
             param_stats["name"] = name + "-w"
             params.append(param_stats)
         if hasattr(module, "bias") and module.bias is not None:
             b = module.bias
-            param_stats = get_param_stats(b.numpy())
+            param_stats = get_param_stats(b)
             param_stats["name"] = name + "-b"
             params.append(param_stats)
     if cal_activations:
         if not isinstance(outputs, (tuple, list)):
-            output = outputs.numpy()
+            output = outputs
         else:
-            output = outputs[0].numpy()
+            output = outputs[0]
         activation_stats = get_activation_stats(output, has_inputs)
         activation_stats["name"] = name
         activation_stats["class_name"] = class_name
@@ -486,7 +491,9 @@ def module_stats(
     flops = []
     hooks = []
     activations = []
-    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
+    total_stats = namedtuple(
+        "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
+    )
     stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

     for (name, module) in model.named_modules():
@@ -536,7 +543,7 @@ def module_stats(
     if logging_to_stdout:
         print_activations_stats(activations, has_inputs)

-    if cal_flops and cal_params:
+    if cal_flops and cal_params and total_param_size != 0:
         extra_info["flops/param_size"] = "{:3.3f}".format(
             total_flops / total_param_size
         )
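The extra `total_param_size != 0` condition avoids a `ZeroDivisionError` for parameter-free models. A sketch, assuming the default flags compute both FLOPs and params:

```python
import megengine.module as M
from megengine.utils.module_stats import module_stats

# A model with no parameters previously crashed on the flops/param_size ratio.
module_stats(M.AvgPool2d(kernel_size=2), input_shapes=(1, 3, 32, 32))
```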
@@ -545,7 +552,11 @@ def module_stats(

     return (
         total_stats(
-            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
+            param_size=total_param_size,
+            param_dims=total_param_dims,
+            flops=total_flops,
+            act_size=total_act_size,
+            act_dims=total_act_dims,
         ),
         stats_details(params=params, flops=flops, activations=activations),
     )
@@ -21,16 +21,10 @@ def test_module_stats():
     total_stats, stats_details = module_stats(net, input_shapes=input_shape)
     x1 = np.random.random((1, 3, 224, 224)).astype("float32")
     gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
-    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
-        gt_flops,
-        gt_acts,
-    )
+    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)

     total_stats, stats_details = module_stats(net, inputs=x1)
-    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
-        gt_flops,
-        gt_acts,
-    )
+    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)


 class BasicBlock(M.Module):