Browse Source

fix(mge/tools): rename `net_stats` in function and examples to match file name

GitOrigin-RevId: 82a1377d66
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
786c36ff42
2 changed files with 4 additions and 4 deletions
  1. +1
    -1
      imperative/python/megengine/tools/network_visualize.py
  2. +3
    -3
      imperative/python/megengine/utils/module_stats.py

+ 1
- 1
imperative/python/megengine/tools/network_visualize.py View File

@@ -31,7 +31,7 @@ def visualize(
): ):
r""" r"""
Load megengine dumped model and visualize graph structure with tensorboard log files. Load megengine dumped model and visualize graph structure with tensorboard log files.
Can also record and print model's statistics like :func:`~.net_stats`
Can also record and print model's statistics like :func:`~.module_stats`


:param model_path: dir path for megengine dumped model. :param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log. :param log_path: dir path for tensorboard graph log.


+ 3
- 3
imperative/python/megengine/utils/module_stats.py View File

@@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20):
return total_param_size return total_param_size




def net_stats(
def module_stats(
model: m.Module, model: m.Module,
input_size: int, input_size: int,
bar_length_max: int = 20, bar_length_max: int = 20,
@@ -212,7 +212,7 @@ def net_stats(
else: else:
return 4 return 4


def net_stats_hook(module, input, output, name=""):
def module_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0] class_name = str(module.__class__).split(".")[-1].split("'")[0]


flops_fun = CALC_FLOPS.get(type(module)) flops_fun = CALC_FLOPS.get(type(module))
@@ -280,7 +280,7 @@ def net_stats(
for (name, module) in model.named_modules(): for (name, module) in model.named_modules():
if isinstance(module, hook_modules): if isinstance(module, hook_modules):
hooks.append( hooks.append(
module.register_forward_hook(partial(net_stats_hook, name=name))
module.register_forward_hook(partial(module_stats_hook, name=name))
) )


inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]


Loading…
Cancel
Save