BREAKING CHANGE:
GitOrigin-RevId: cd2a1acd11
release-1.5
@@ -12,6 +12,7 @@ import re | |||||
from collections import namedtuple | from collections import namedtuple | ||||
import numpy as np | import numpy as np | ||||
from tqdm import tqdm | |||||
from megengine.core.tensor.dtype import is_quantize | from megengine.core.tensor.dtype import is_quantize | ||||
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | ||||
@@ -37,10 +38,13 @@ logger = get_logger(__name__) | |||||
def visualize( | def visualize( | ||||
model_path: str, | model_path: str, | ||||
log_path: str, | log_path: str, | ||||
input: np.ndarray = None, | |||||
inp_dict: dict = None, | |||||
cal_params: bool = True, | |||||
cal_flops: bool = True, | |||||
cal_activations: bool = True, | |||||
logging_to_stdout: bool = True, | |||||
bar_length_max: int = 20, | bar_length_max: int = 20, | ||||
log_params: bool = True, | |||||
log_flops: bool = True, | |||||
log_activations: bool = True, | |||||
): | ): | ||||
r""" | r""" | ||||
Load megengine dumped model and visualize graph structure with tensorboard log files. | Load megengine dumped model and visualize graph structure with tensorboard log files. | ||||
@@ -48,10 +52,14 @@ def visualize( | |||||
:param model_path: dir path for megengine dumped model. | :param model_path: dir path for megengine dumped model. | ||||
:param log_path: dir path for tensorboard graph log. | :param log_path: dir path for tensorboard graph log. | ||||
:param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input. | |||||
:param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used. | |||||
:param cal_params: whether calculate and record params size. | |||||
:param cal_flops: whether calculate and record op flops. | |||||
:param cal_activations: whether calculate and record op activations. | |||||
:param logging_to_stdout: whether print all calculated statistic details. | |||||
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. | :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | ||||
:param log_params: whether print and record params size. | |||||
:param log_flops: whether print and record op flops. | |||||
:param log_activations: whether print and record op activations. | |||||
""" | """ | ||||
if log_path: | if log_path: | ||||
try: | try: | ||||
@@ -78,6 +86,27 @@ def visualize( | |||||
enable_receptive_field() | enable_receptive_field() | ||||
graph = Network.load(model_path) | graph = Network.load(model_path) | ||||
graph.reset_batch_size(1) | |||||
has_input = False | |||||
if input is not None or inp_dict is not None: | |||||
has_input = True | |||||
repl_dict = {} | |||||
inp_vars = graph.input_vars | |||||
if inp_dict is not None: | |||||
assert len(inp_dict) == len( | |||||
inp_vars | |||||
), "Inputs are not sufficient for calculation." | |||||
for v in inp_vars: | |||||
new_input = graph.make_const(inp_dict[v.name], name=v.name) | |||||
repl_dict[v] = new_input | |||||
else: | |||||
assert len(inp_vars) == 1, "The graph needs more than one input." | |||||
inp_var = inp_vars[0] | |||||
repl_dict[inp_var] = graph.make_const(input, name=inp_var.name) | |||||
graph.replace_vars(repl_dict=repl_dict) | |||||
graph._compile() | |||||
def process_name(name): | def process_name(name): | ||||
# nodes that start with point or contain float const will lead to display bug | # nodes that start with point or contain float const will lead to display bug | ||||
@@ -93,7 +122,7 @@ def visualize( | |||||
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | ||||
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | ||||
for node in graph.all_oprs: | |||||
for node in tqdm(graph.all_oprs): | |||||
if hasattr(node, "output_idx"): | if hasattr(node, "output_idx"): | ||||
node_oup = node.outputs[node.output_idx] | node_oup = node.outputs[node.output_idx] | ||||
else: | else: | ||||
@@ -123,31 +152,35 @@ def visualize( | |||||
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), | "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), | ||||
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), | "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), | ||||
} | } | ||||
flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||||
if flops_stats is not None: | |||||
# add op flops attr | |||||
if log_path and hasattr(flops_stats, "flops_num"): | |||||
attr["flops"] = AttrValue( | |||||
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||||
) | |||||
flops_stats["name"] = node.name | |||||
flops_stats["class_name"] = node.type | |||||
flops_list.append(flops_stats) | |||||
acts = get_activation_stats(node_oup) | |||||
if cal_flops: | |||||
flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||||
if flops_stats is not None: | |||||
# add op flops attr | |||||
if log_path and hasattr(flops_stats, "flops_num"): | |||||
attr["flops"] = AttrValue( | |||||
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||||
) | |||||
flops_stats["name"] = node.name | |||||
flops_stats["class_name"] = node.type | |||||
flops_list.append(flops_stats) | |||||
if cal_activations: | |||||
acts = get_activation_stats(node_oup.numpy(), has_input=has_input) | |||||
acts["name"] = node.name | acts["name"] = node.name | ||||
acts["class_name"] = node.type | acts["class_name"] = node.type | ||||
activations_list.append(acts) | activations_list.append(acts) | ||||
if node.type == "ImmutableTensor": | |||||
param_stats = get_param_stats(node_oup) | |||||
# add tensor size attr | |||||
if log_path: | |||||
attr["size"] = AttrValue( | |||||
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||||
) | |||||
param_stats["name"] = node.name | |||||
params_list.append(param_stats) | |||||
if cal_params: | |||||
if node.type == "ImmutableTensor": | |||||
param_stats = get_param_stats(node.numpy()) | |||||
# add tensor size attr | |||||
if log_path: | |||||
attr["size"] = AttrValue( | |||||
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||||
) | |||||
param_stats["name"] = node.name | |||||
params_list.append(param_stats) | |||||
if log_path: | if log_path: | ||||
node_list.append( | node_list.append( | ||||
@@ -169,31 +202,37 @@ def visualize( | |||||
total_param_dims, | total_param_dims, | ||||
total_param_size, | total_param_size, | ||||
total_act_dims, | total_act_dims, | ||||
total_param_size, | |||||
total_act_size, | |||||
) = (0, 0, 0, 0, 0) | ) = (0, 0, 0, 0, 0) | ||||
total_param_dims, total_param_size, params = sum_param_stats( | |||||
params_list, bar_length_max | |||||
) | |||||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
if log_params: | |||||
print_param_stats(params) | |||||
total_flops, flops = sum_op_stats(flops_list, bar_length_max) | |||||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
if log_flops: | |||||
print_op_stats(flops) | |||||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||||
activations_list, bar_length_max | |||||
) | |||||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
if log_activations: | |||||
print_activations_stats(activations) | |||||
if cal_params: | |||||
total_param_dims, total_param_size, params_list = sum_param_stats( | |||||
params_list, bar_length_max | |||||
) | |||||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
if logging_to_stdout: | |||||
print_param_stats(params_list) | |||||
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||||
if cal_flops: | |||||
total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) | |||||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
if logging_to_stdout: | |||||
print_op_stats(flops_list) | |||||
if cal_activations: | |||||
total_act_dims, total_act_size, activations_list = sum_activations_stats( | |||||
activations_list, bar_length_max | |||||
) | |||||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
if logging_to_stdout: | |||||
print_activations_stats(activations_list, has_input=has_input) | |||||
if cal_flops and cal_params: | |||||
extra_info["flops/param_size"] = "{:3.3f}".format( | |||||
total_flops / total_param_size | |||||
) | |||||
if log_path: | if log_path: | ||||
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | ||||
@@ -211,7 +250,9 @@ def visualize( | |||||
total_stats( | total_stats( | ||||
param_size=total_param_size, flops=total_flops, act_size=total_act_size, | param_size=total_param_size, flops=total_flops, act_size=total_act_size, | ||||
), | ), | ||||
stats_details(params=params, flops=flops, activations=activations), | |||||
stats_details( | |||||
params=params_list, flops=flops_list, activations=activations_list | |||||
), | |||||
) | ) | ||||
@@ -229,12 +270,24 @@ def main(): | |||||
help="size of bar indicating max flops or parameter size in net stats.", | help="size of bar indicating max flops or parameter size in net stats.", | ||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--log_params", | |||||
"--cal_params", | |||||
action="store_true", | |||||
help="whether calculate and record params size.", | |||||
) | |||||
parser.add_argument( | |||||
"--cal_flops", | |||||
action="store_true", | |||||
help="whether calculate and record op flops.", | |||||
) | |||||
parser.add_argument( | |||||
"--cal_activations", | |||||
action="store_true", | action="store_true", | ||||
help="whether print and record params size.", | |||||
help="whether calculate and record op activations.", | |||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--log_flops", action="store_true", help="whether print and record op flops.", | |||||
"--logging_to_stdout", | |||||
action="store_true", | |||||
help="whether print all calculated statistic details.", | |||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--all", | "--all", | ||||
@@ -243,8 +296,10 @@ def main(): | |||||
) | ) | ||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.all: | if args.all: | ||||
args.log_params = True | |||||
args.log_flops = True | |||||
args.cal_params = True | |||||
args.cal_flops = True | |||||
args.cal_activations = True | |||||
args.logging_to_stdout = True | |||||
if not args.log_path: | if not args.log_path: | ||||
args.log_path = "./log" | args.log_path = "./log" | ||||
kwargs = vars(args) | kwargs = vars(args) | ||||
@@ -5,8 +5,9 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from collections import namedtuple | |||||
from collections import Iterable, namedtuple | |||||
from functools import partial | from functools import partial | ||||
from typing import Iterable | |||||
import numpy as np | import numpy as np | ||||
import tabulate | import tabulate | ||||
@@ -19,6 +20,7 @@ from megengine import Tensor | |||||
from megengine import functional as F | from megengine import functional as F | ||||
from megengine.core.tensor.dtype import get_dtype_bit | from megengine.core.tensor.dtype import get_dtype_bit | ||||
from megengine.functional.tensor import zeros | from megengine.functional.tensor import zeros | ||||
from megengine.tensor import Tensor | |||||
from .module_utils import set_module_mode_safe | from .module_utils import set_module_mode_safe | ||||
@@ -335,21 +337,23 @@ def print_param_stats(params): | |||||
) | ) | ||||
def get_activation_stats(output: Tensor): | |||||
def get_activation_stats(output: np.ndarray, has_input=False): | |||||
out_shape = output.shape | out_shape = output.shape | ||||
activations_dtype = np.dtype(output.dtype) | activations_dtype = np.dtype(output.dtype) | ||||
nbits = get_dtype_bit(activations_dtype.name) | nbits = get_dtype_bit(activations_dtype.name) | ||||
act_dim = np.prod(out_shape) | act_dim = np.prod(out_shape) | ||||
act_size = act_dim * nbits // 8 | act_size = act_dim * nbits // 8 | ||||
return { | |||||
activation_stats = { | |||||
"dtype": activations_dtype, | "dtype": activations_dtype, | ||||
"shape": out_shape, | "shape": out_shape, | ||||
"act_dim": act_dim, | "act_dim": act_dim, | ||||
"mean": "{:.3g}".format(_mean(output)), | |||||
"std": "{:.3g}".format(_std(output)), | |||||
"nbits": nbits, | "nbits": nbits, | ||||
"size": act_size, | "size": act_size, | ||||
} | } | ||||
if has_input: | |||||
activation_stats["mean"] = "{:.3g}".format(output.mean()) | |||||
activation_stats["std"] = "{:.3g}".format(output.std()) | |||||
return activation_stats | |||||
def sum_activations_stats(activations, bar_length_max=20): | def sum_activations_stats(activations, bar_length_max=20): | ||||
@@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20): | |||||
return total_act_dims, total_act_size, activations | return total_act_dims, total_act_size, activations | ||||
def print_activations_stats(activations): | |||||
def print_activations_stats(activations, has_input=False): | |||||
header = [ | header = [ | ||||
"name", | "name", | ||||
"class_name", | "class_name", | ||||
"dtype", | "dtype", | ||||
"shape", | "shape", | ||||
"mean", | |||||
"std", | |||||
"nbits", | "nbits", | ||||
"act_dim", | "act_dim", | ||||
"size", | "size", | ||||
@@ -388,6 +390,9 @@ def print_activations_stats(activations): | |||||
"percentage", | "percentage", | ||||
"size_bar", | "size_bar", | ||||
] | ] | ||||
if has_input: | |||||
header.insert(4, "mean") | |||||
header.insert(5, "std") | |||||
logger.info( | logger.info( | ||||
"activations stats: \n" | "activations stats: \n" | ||||
+ tabulate.tabulate(dict2table(activations, header=header)) | + tabulate.tabulate(dict2table(activations, header=header)) | ||||
@@ -402,56 +407,80 @@ def print_summary(**kwargs): | |||||
def module_stats( | def module_stats( | ||||
model: m.Module, | model: m.Module, | ||||
input_shapes: list, | |||||
inputs: Iterable[np.ndarray] = None, | |||||
input_shapes: list = None, | |||||
cal_params: bool = True, | |||||
cal_flops: bool = True, | |||||
cal_activations: bool = True, | |||||
logging_to_stdout: bool = True, | |||||
bar_length_max: int = 20, | bar_length_max: int = 20, | ||||
log_params: bool = True, | |||||
log_flops: bool = True, | |||||
log_activations: bool = True, | |||||
): | ): | ||||
r""" | r""" | ||||
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | ||||
:param model: model that need to get stats info. | :param model: model that need to get stats info. | ||||
:param input_shapes: shapes of inputs for running model and calculating stats. | |||||
:param inputs: user defined input data for running model and calculating stats, alternative with input_shapes. | |||||
:param input_shapes: shapes to generate random inputs for running model and calculating stats, alternative with inputs. | |||||
:param cal_params: whether calculate and record params size. | |||||
:param cal_flops: whether calculate and record op flops. | |||||
:param cal_activations: whether calculate and record op activations. | |||||
:param logging_to_stdout: whether print all calculated statistic details. | |||||
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. | :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | ||||
:param log_params: whether print and record params size. | |||||
:param log_flops: whether print and record op flops. | |||||
:param log_activations: whether print and record op activations. | |||||
""" | """ | ||||
has_inputs = False | |||||
if inputs is not None: | |||||
has_inputs = True | |||||
if not isinstance(inputs, (tuple, list)): | |||||
inputs = [inputs] | |||||
inputs = [Tensor(input, dtype=np.float32) for input in inputs] | |||||
else: | |||||
if input_shapes: | |||||
if not isinstance(input_shapes[0], tuple): | |||||
input_shapes = [input_shapes] | |||||
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||||
else: | |||||
logger.error( | |||||
"Inputs or input_shapes is required for running model and calculating stats.", | |||||
exc_info=True, | |||||
) | |||||
return | |||||
if not cal_activations: | |||||
log_activations = False | |||||
disable_receptive_field() | disable_receptive_field() | ||||
def module_stats_hook(module, inputs, outputs, name=""): | def module_stats_hook(module, inputs, outputs, name=""): | ||||
class_name = str(module.__class__).split(".")[-1].split("'")[0] | class_name = str(module.__class__).split(".")[-1].split("'")[0] | ||||
flops_stats = get_op_stats(module, inputs, outputs) | |||||
if flops_stats is not None: | |||||
flops_stats["name"] = name | |||||
flops_stats["class_name"] = class_name | |||||
flops.append(flops_stats) | |||||
if hasattr(module, "weight") and module.weight is not None: | |||||
w = module.weight | |||||
param_stats = get_param_stats(w) | |||||
param_stats["name"] = name + "-w" | |||||
params.append(param_stats) | |||||
if hasattr(module, "bias") and module.bias is not None: | |||||
b = module.bias | |||||
param_stats = get_param_stats(b) | |||||
param_stats["name"] = name + "-b" | |||||
params.append(param_stats) | |||||
if not isinstance(outputs, tuple) or not isinstance(outputs, list): | |||||
output = outputs | |||||
else: | |||||
output = outputs[0] | |||||
activation_stats = get_activation_stats(output) | |||||
activation_stats["name"] = name | |||||
activation_stats["class_name"] = class_name | |||||
activations.append(activation_stats) | |||||
# multiple inputs to the network | |||||
if not isinstance(input_shapes[0], tuple): | |||||
input_shapes = [input_shapes] | |||||
if cal_flops: | |||||
flops_stats = get_op_stats(module, inputs, outputs) | |||||
if flops_stats is not None: | |||||
flops_stats["name"] = name | |||||
flops_stats["class_name"] = class_name | |||||
flops.append(flops_stats) | |||||
if cal_params: | |||||
if hasattr(module, "weight") and module.weight is not None: | |||||
w = module.weight | |||||
param_stats = get_param_stats(w.numpy()) | |||||
param_stats["name"] = name + "-w" | |||||
params.append(param_stats) | |||||
if hasattr(module, "bias") and module.bias is not None: | |||||
b = module.bias | |||||
param_stats = get_param_stats(b.numpy()) | |||||
param_stats["name"] = name + "-b" | |||||
params.append(param_stats) | |||||
if cal_activations: | |||||
if not isinstance(outputs, (tuple, list)): | |||||
output = outputs.numpy() | |||||
else: | |||||
output = outputs[0].numpy() | |||||
activation_stats = get_activation_stats(output, has_inputs) | |||||
activation_stats["name"] = name | |||||
activation_stats["class_name"] = class_name | |||||
activations.append(activation_stats) | |||||
params = [] | params = [] | ||||
flops = [] | flops = [] | ||||
@@ -466,7 +495,6 @@ def module_stats( | |||||
module.register_forward_hook(partial(module_stats_hook, name=name)) | module.register_forward_hook(partial(module_stats_hook, name=name)) | ||||
) | ) | ||||
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||||
with set_module_mode_safe(model, training=False) as model: | with set_module_mode_safe(model, training=False) as model: | ||||
model(*inputs) | model(*inputs) | ||||
@@ -481,29 +509,37 @@ def module_stats( | |||||
total_param_dims, | total_param_dims, | ||||
total_param_size, | total_param_size, | ||||
total_act_dims, | total_act_dims, | ||||
total_param_size, | |||||
total_act_size, | |||||
) = (0, 0, 0, 0, 0) | ) = (0, 0, 0, 0, 0) | ||||
total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) | |||||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
if log_params: | |||||
print_param_stats(params) | |||||
total_flops, flops = sum_op_stats(flops, bar_length_max) | |||||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
if log_flops: | |||||
print_op_stats(flops) | |||||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||||
activations, bar_length_max | |||||
) | |||||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
if log_activations: | |||||
print_activations_stats(activations) | |||||
extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||||
if cal_params: | |||||
total_param_dims, total_param_size, params = sum_param_stats( | |||||
params, bar_length_max | |||||
) | |||||
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
if logging_to_stdout: | |||||
print_param_stats(params) | |||||
if cal_flops: | |||||
total_flops, flops = sum_op_stats(flops, bar_length_max) | |||||
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
if logging_to_stdout: | |||||
print_op_stats(flops) | |||||
if cal_activations: | |||||
total_act_dims, total_act_size, activations = sum_activations_stats( | |||||
activations, bar_length_max | |||||
) | |||||
extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
if logging_to_stdout: | |||||
print_activations_stats(activations, has_inputs) | |||||
if cal_flops and cal_params: | |||||
extra_info["flops/param_size"] = "{:3.3f}".format( | |||||
total_flops / total_param_size | |||||
) | |||||
print_summary(**extra_info) | print_summary(**extra_info) | ||||
@@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats | |||||
def test_module_stats(): | def test_module_stats(): | ||||
net = ResNet(BasicBlock, [2, 2, 2, 2]) | net = ResNet(BasicBlock, [2, 2, 2, 2]) | ||||
input_shape = (1, 3, 224, 224) | input_shape = (1, 3, 224, 224) | ||||
total_stats, stats_details = module_stats(net, input_shape) | |||||
x1 = mge.tensor(np.zeros((1, 3, 224, 224))) | |||||
gt_flops, gt_acts = net.get_stats(x1) | |||||
total_stats, stats_details = module_stats(net, input_shapes=input_shape) | |||||
x1 = np.random.random((1, 3, 224, 224)).astype("float32") | |||||
gt_flops, gt_acts = net.get_stats(mge.tensor(x1)) | |||||
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | |||||
gt_flops, | |||||
gt_acts, | |||||
) | |||||
total_stats, stats_details = module_stats(net, inputs=x1) | |||||
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | ||||
gt_flops, | gt_flops, | ||||
gt_acts, | gt_acts, | ||||