@@ -14,6 +14,7 @@ from collections import namedtuple
 import numpy as np
 from tqdm import tqdm
 import megengine as mge
+from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
@@ -119,7 +120,9 @@ def visualize(
     flops_list = []
     params_list = []
     activations_list = []
-    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
+    total_stats = namedtuple(
+        "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
+    )
     stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

     for node in tqdm(graph.all_oprs):
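
Note that `param_dims` and `act_dims` are inserted between the existing fields, so positional construction of the old three-field tuple would silently bind values to the wrong slots; keyword construction, which every call site in this patch uses, stays safe. A minimal sanity check (pure Python, names mirror the diff):

```python
from collections import namedtuple

total_stats = namedtuple(
    "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
)
# Keyword construction keeps call sites correct as fields are inserted;
# positionally, the old second argument (flops) would now land in param_dims.
s = total_stats(param_size=0, param_dims=0, flops=1, act_size=0, act_dims=0)
assert s.flops == 1
```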
@@ -166,14 +169,14 @@ def visualize(
             flops_list.append(flops_stats)

         if cal_activations:
-            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
+            acts = get_activation_stats(node_oup, has_input=has_input)
             acts["name"] = node.name
             acts["class_name"] = node.type
             activations_list.append(acts)

         if cal_params:
             if node.type == "ImmutableTensor":
-                param_stats = get_param_stats(node.numpy())
+                param_stats = get_param_stats(node_oup)
                 # add tensor size attr
                 if log_path:
                     attr["size"] = AttrValue(
@@ -248,7 +251,11 @@ def visualize(
     return (
         total_stats(
-            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
+            param_size=total_param_size,
+            param_dims=total_param_dims,
+            flops=total_flops,
+            act_size=total_act_size,
+            act_dims=total_act_dims,
         ),
         stats_details(
             params=params_list, flops=flops_list, activations=activations_list
@@ -264,6 +271,10 @@ def main():
     parser.add_argument("model_path", help="dumped model path.")
     parser.add_argument("--log_path", help="tensorboard log path.")
+    parser.add_argument(
+        "--load_input_data",
+        help="load input data from pickle file; it should be a numpy array or a dict of numpy arrays",
+    )
     parser.add_argument(
         "--bar_length_max",
         type=int,
         default=20,
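
For reference, a minimal sketch of producing a file `--load_input_data` accepts: either a single ndarray or a dict keyed by input name (the key "data" and the tool's module path in the comment are assumptions, not part of this patch):

```python
import numpy as np
import megengine as mge

# Either a bare ndarray or a dict of named inputs works.
inp = {"data": np.random.random((1, 3, 224, 224)).astype("float32")}
mge.save(inp, "input_data.pkl")
# Hypothetical invocation:
#   python -m megengine.tools.network_visualize model.mge \
#       --load_input_data input_data.pkl --all
```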
@@ -295,6 +306,19 @@ def main():
         help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
     )
     args = parser.parse_args()
+    if args.load_input_data:
+        logger.info("load data from {}".format(args.load_input_data))
+        data = mge.load(args.load_input_data)
+        if isinstance(data, dict):
+            for v in data.values():
+                assert isinstance(
+                    v, np.ndarray
+                ), "data should provide ndarray; got {} instead".format(v)
+            args.inp_dict = data
+        elif isinstance(data, np.ndarray):
+            args.input = data
+        else:
+            logger.error("input data should be a numpy array or a dict of numpy arrays")
     if args.all:
         args.cal_params = True
         args.cal_flops = True
@@ -304,6 +328,7 @@ def main():
         args.log_path = "./log"

     kwargs = vars(args)
     kwargs.pop("all")
+    kwargs.pop("load_input_data")

     visualize(**kwargs)
@@ -113,7 +113,12 @@ def flops_norm(module: m.Linear, inputs, outputs):

 @register_flops(m.AvgPool2d, m.MaxPool2d)
 def flops_pool(module: m.AvgPool2d, inputs, outputs):
-    return np.prod(outputs[0].shape) * (module.kernel_size ** 2)
+    kernel_sum = 0
+    if isinstance(module.kernel_size, tuple) and len(module.kernel_size) == 2:
+        kernel_sum = np.prod(module.kernel_size)
+    else:
+        kernel_sum = module.kernel_size ** 2
+    return np.prod(outputs[0].shape) * kernel_sum


 @register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
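
The old formula assumed a square integer kernel; on a tuple kernel, `module.kernel_size ** 2` raises `TypeError`, and squaring one side would be wrong for rectangular kernels anyway. A quick numeric check of the fixed formula:

```python
import numpy as np

out_shape = (1, 16, 56, 56)  # N, C, H, W of the pooled output
# Tuple kernel (3, 2): multiply by the product of both sides, not a square.
assert np.prod(out_shape) * np.prod((3, 2)) == 301056
# Plain int kernel 3 keeps the original square behaviour.
assert np.prod(out_shape) * 3 ** 2 == 451584
```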
@@ -157,12 +162,12 @@ hook_modules = (

 def _mean(inp):
-    inp = mge.tensor(inp)
+    inp = mge.tensor(inp).astype(np.float32)
     return F.mean(inp).numpy()


 def _std(inp):
-    inp = mge.tensor(inp)
+    inp = mge.tensor(inp).astype(np.float32)
     return F.std(inp).numpy()
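
Casting to float32 before the reduction keeps the statistics well defined when activations arrive in integer or quantized dtypes instead of floats. A small check, assuming a working MegEngine install:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = np.arange(6, dtype=np.int8)
m = F.mean(mge.tensor(x).astype(np.float32)).numpy()
assert np.isclose(m, 2.5)  # true mean of 0..5, no integer truncation
```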
@@ -337,7 +342,7 @@ def print_param_stats(params):
     )


-def get_activation_stats(output: np.ndarray, has_input=False):
+def get_activation_stats(output: Tensor, has_input=False):
     out_shape = output.shape
     activations_dtype = np.dtype(output.dtype)
     nbits = get_dtype_bit(activations_dtype.name)
@@ -351,8 +356,8 @@ def get_activation_stats(output: np.ndarray, has_input=False):
         "size": act_size,
     }
     if has_input:
-        activation_stats["mean"] = "{:.3g}".format(output.mean())
-        activation_stats["std"] = "{:.3g}".format(output.std())
+        activation_stats["mean"] = "{:.3g}".format(_mean(output))
+        activation_stats["std"] = "{:.3g}".format(_std(output))
     return activation_stats
@@ -462,21 +467,21 @@ def module_stats(
         if cal_params:
             if hasattr(module, "weight") and module.weight is not None:
                 w = module.weight
-                param_stats = get_param_stats(w.numpy())
+                param_stats = get_param_stats(w)
                 param_stats["name"] = name + "-w"
                 params.append(param_stats)

             if hasattr(module, "bias") and module.bias is not None:
                 b = module.bias
-                param_stats = get_param_stats(b.numpy())
+                param_stats = get_param_stats(b)
                 param_stats["name"] = name + "-b"
                 params.append(param_stats)

         if cal_activations:
             if not isinstance(outputs, (tuple, list)):
-                output = outputs.numpy()
+                output = outputs
             else:
-                output = outputs[0].numpy()
+                output = outputs[0]
             activation_stats = get_activation_stats(output, has_inputs)
             activation_stats["name"] = name
             activation_stats["class_name"] = class_name
@@ -486,7 +491,9 @@ def module_stats(
     flops = []
     hooks = []
     activations = []
-    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
+    total_stats = namedtuple(
+        "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
+    )
     stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

     for (name, module) in model.named_modules():
@@ -536,7 +543,7 @@ def module_stats(
     if logging_to_stdout:
         print_activations_stats(activations, has_inputs)

-    if cal_flops and cal_params:
+    if cal_flops and cal_params and total_param_size != 0:
         extra_info["flops/param_size"] = "{:3.3f}".format(
             total_flops / total_param_size
         )
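
The added `total_param_size != 0` condition matters for parameter-free models, where the flops/param_size ratio would otherwise raise `ZeroDivisionError`. A sketch of such a model (class and attribute names here are illustrative):

```python
import megengine.module as M
from megengine.utils.module_stats import module_stats

class PoolOnly(M.Module):
    def __init__(self):
        super().__init__()
        self.pool = M.MaxPool2d(kernel_size=2)

    def forward(self, x):
        return self.pool(x)

# No weights or biases anywhere, so total_param_size == 0 and the
# flops/param_size line is now skipped instead of crashing.
module_stats(PoolOnly(), input_shapes=(1, 3, 8, 8))
```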
@@ -545,7 +552,11 @@ def module_stats(

     return (
         total_stats(
-            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
+            param_size=total_param_size,
+            param_dims=total_param_dims,
+            flops=total_flops,
+            act_size=total_act_size,
+            act_dims=total_act_dims,
         ),
         stats_details(params=params, flops=flops, activations=activations),
     )
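
With the widened tuple, callers read the dim totals straight off the first return value instead of digging through `stats_details`, which is exactly the simplification the updated test below makes. A usage sketch with a throwaway model:

```python
import numpy as np
import megengine.module as M
from megengine.utils.module_stats import module_stats

net = M.Sequential(M.Conv2d(3, 8, 3), M.ReLU())
x1 = np.random.random((1, 3, 32, 32)).astype("float32")
total, details = module_stats(net, inputs=x1)
print(total.param_dims, total.act_dims)  # new fields next to the old ones
```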
@@ -21,16 +21,10 @@ def test_module_stats():
     total_stats, stats_details = module_stats(net, input_shapes=input_shape)
     x1 = np.random.random((1, 3, 224, 224)).astype("float32")
     gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
-    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
-        gt_flops,
-        gt_acts,
-    )
+    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)

     total_stats, stats_details = module_stats(net, inputs=x1)
-    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
-        gt_flops,
-        gt_acts,
-    )
+    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)


 class BasicBlock(M.Module):