|
@@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20): |
|
|
return total_param_size |
|
|
return total_param_size |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def net_stats( |
|
|
|
|
|
|
|
|
def module_stats( |
|
|
model: m.Module, |
|
|
model: m.Module, |
|
|
input_size: int, |
|
|
input_size: int, |
|
|
bar_length_max: int = 20, |
|
|
bar_length_max: int = 20, |
|
@@ -212,7 +212,7 @@ def net_stats( |
|
|
else: |
|
|
else: |
|
|
return 4 |
|
|
return 4 |
|
|
|
|
|
|
|
|
def net_stats_hook(module, input, output, name=""): |
|
|
|
|
|
|
|
|
def module_stats_hook(module, input, output, name=""): |
|
|
class_name = str(module.__class__).split(".")[-1].split("'")[0] |
|
|
class_name = str(module.__class__).split(".")[-1].split("'")[0] |
|
|
|
|
|
|
|
|
flops_fun = CALC_FLOPS.get(type(module)) |
|
|
flops_fun = CALC_FLOPS.get(type(module)) |
|
@@ -280,7 +280,7 @@ def net_stats( |
|
|
for (name, module) in model.named_modules(): |
|
|
for (name, module) in model.named_modules(): |
|
|
if isinstance(module, hook_modules): |
|
|
if isinstance(module, hook_modules): |
|
|
hooks.append( |
|
|
hooks.append( |
|
|
module.register_forward_hook(partial(net_stats_hook, name=name)) |
|
|
|
|
|
|
|
|
module.register_forward_hook(partial(module_stats_hook, name=name)) |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] |
|
|
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] |
|
|