Browse Source

fix(mge/tools): rename `net_stats` in function and examples to match file name

GitOrigin-RevId: 82a1377d66
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
786c36ff42
2 changed files with 4 additions and 4 deletions
  1. +1
    -1
      imperative/python/megengine/tools/network_visualize.py
  2. +3
    -3
      imperative/python/megengine/utils/module_stats.py

+ 1
- 1
imperative/python/megengine/tools/network_visualize.py View File

@@ -31,7 +31,7 @@ def visualize(
):
r"""
Load megengine dumped model and visualize graph structure with tensorboard log files.
Can also record and print model's statistics like :func:`~.net_stats`
Can also record and print model's statistics like :func:`~.module_stats`

:param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log.


+ 3
- 3
imperative/python/megengine/utils/module_stats.py View File

@@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20):
return total_param_size


def net_stats(
def module_stats(
model: m.Module,
input_size: int,
bar_length_max: int = 20,
@@ -212,7 +212,7 @@ def net_stats(
else:
return 4

def net_stats_hook(module, input, output, name=""):
def module_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]

flops_fun = CALC_FLOPS.get(type(module))
@@ -280,7 +280,7 @@ def net_stats(
for (name, module) in model.named_modules():
if isinstance(module, hook_modules):
hooks.append(
module.register_forward_hook(partial(net_stats_hook, name=name))
module.register_forward_hook(partial(module_stats_hook, name=name))
)

inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]


Loading…
Cancel
Save