|
@@ -457,6 +457,7 @@ def module_stats( |
|
|
log_activations = False |
|
|
log_activations = False |
|
|
|
|
|
|
|
|
disable_receptive_field() |
|
|
disable_receptive_field() |
|
|
|
|
|
recorded_parameters = set() |
|
|
|
|
|
|
|
|
def module_stats_hook(module, inputs, outputs, name=""): |
|
|
def module_stats_hook(module, inputs, outputs, name=""): |
|
|
class_name = str(module.__class__).split(".")[-1].split("'")[0] |
|
|
class_name = str(module.__class__).split(".")[-1].split("'")[0] |
|
@@ -468,17 +469,27 @@ def module_stats( |
|
|
flops.append(flops_stats) |
|
|
flops.append(flops_stats) |
|
|
|
|
|
|
|
|
if cal_params: |
|
|
if cal_params: |
|
|
if hasattr(module, "weight") and module.weight is not None: |
|
|
|
|
|
|
|
|
if ( |
|
|
|
|
|
hasattr(module, "weight") |
|
|
|
|
|
and (module.weight is not None) |
|
|
|
|
|
and module.weight not in recorded_parameters |
|
|
|
|
|
): |
|
|
w = module.weight |
|
|
w = module.weight |
|
|
param_stats = get_param_stats(w) |
|
|
param_stats = get_param_stats(w) |
|
|
param_stats["name"] = name + "-w" |
|
|
param_stats["name"] = name + "-w" |
|
|
params.append(param_stats) |
|
|
params.append(param_stats) |
|
|
|
|
|
recorded_parameters.add(w) |
|
|
|
|
|
|
|
|
if hasattr(module, "bias") and module.bias is not None: |
|
|
|
|
|
|
|
|
if ( |
|
|
|
|
|
hasattr(module, "bias") |
|
|
|
|
|
and module.bias is not None |
|
|
|
|
|
and module.bias not in recorded_parameters |
|
|
|
|
|
): |
|
|
b = module.bias |
|
|
b = module.bias |
|
|
param_stats = get_param_stats(b) |
|
|
param_stats = get_param_stats(b) |
|
|
param_stats["name"] = name + "-b" |
|
|
param_stats["name"] = name + "-b" |
|
|
params.append(param_stats) |
|
|
params.append(param_stats) |
|
|
|
|
|
recorded_parameters.add(b) |
|
|
|
|
|
|
|
|
if cal_activations: |
|
|
if cal_activations: |
|
|
if not isinstance(outputs, (tuple, list)): |
|
|
if not isinstance(outputs, (tuple, list)): |
|
@@ -504,7 +515,6 @@ def module_stats( |
|
|
hooks.append( |
|
|
hooks.append( |
|
|
module.register_forward_hook(partial(module_stats_hook, name=name)) |
|
|
module.register_forward_hook(partial(module_stats_hook, name=name)) |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
with set_module_mode_safe(model, training=False) as model: |
|
|
with set_module_mode_safe(model, training=False) as model: |
|
|
model(*inputs) |
|
|
model(*inputs) |
|
|
|
|
|
|
|
|