GitOrigin-RevId: 11ef335468
tags/v1.3.1
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import argparse
import json
import logging
import numpy as np
@@ -14,6 +15,7 @@ import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
get_flops_stats,
get_param_stats,
print_flops_stats,
print_params_stats,
@@ -89,6 +91,7 @@ def visualize(
inp_list = [process_name(var.owner.name) for var in node.inputs]
if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
@@ -101,24 +104,20 @@ def visualize(
]
)
),
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
flops_stats = get_flops_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path:
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_num).encode(encoding="utf-8")
)
flops_list.append(
dict(
name=node.name,
class_name=node.type,
input_shapes=[i.shape for i in node.inputs],
output_shapes=[o.shape for o in node.outputs],
flops_num=flops_num,
flops_cum=0,
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node.numpy())
# add tensor size attr
@@ -132,6 +131,7 @@ def visualize(
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
if log_path:
node_list.append(
NodeDef(
@@ -141,14 +141,26 @@ def visualize(
attr=attr,
)
)
# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
}
total_flops, total_params = None, None
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -160,21 +172,12 @@ def visualize(
writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats))
# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)
return total_params, total_flops
return total_param_size, total_flops
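For illustration (not part of the patch), a hedged usage sketch of the visualize entry point; the positional argument names and the module path are assumptions based on this file, so check the actual signature:

# hedged sketch: "model_path"/"log_path" names and the import path are assumptions
from megengine.tools.network_visualize import visualize

total_param_size, total_flops = visualize(
    "./resnet18.mge",        # dumped model file (hypothetical path)
    log_path="./vis_logs",   # where the TensorBoard event file is written
    log_params=True,
    log_flops=True,
)
# then inspect the graph with: tensorboard --logdir ./vis_logs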
def main():
@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__)
logger.setLevel("INFO")
CALC_FLOPS = {}
def _register_modules(*modules):
_calc_flops_dict = {}
_calc_receptive_field_dict = {}
def _receptive_field_fallback(module, inputs, outputs):
assert not hasattr(module, "_rf")
assert not hasattr(module, "_stride")
if len(inputs) == 0:
# TODO: support other dimension
module._rf = (1, 1)
module._stride = (1, 1)
return module._rf, module._stride
rf, stride = preprocess_receptive_field(module, inputs, outputs)
module._rf = rf
module._stride = stride
return rf, stride
# key tuple, impl_dict, fallback
_iter_list = [
("flops_num", _calc_flops_dict, None),
(
("receptive_field", "stride"),
_calc_receptive_field_dict,
_receptive_field_fallback,
),
]
def _register_dict(*modules, dict=None):
def callback(impl):
for module in modules:
CALC_FLOPS[module] = impl
dict[module] = impl
return impl
return callback
@_register_modules(
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
qm.Conv2d,
qm.ConvRelu2d,
qm.ConvBn2d,
qm.ConvBnRelu2d,
qatm.Conv2d,
qatm.ConvRelu2d,
qatm.ConvBn2d,
qatm.ConvBnRelu2d,
def register_flops(*modules):
return _register_dict(*modules, dict=_calc_flops_dict)
def register_receptive_field(*modules):
return _register_dict(*modules, dict=_calc_receptive_field_dict)
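For illustration (not part of the patch): a minimal sketch of extending the new registry with a user-defined module; GlobalAvgPool and its FLOPs rule are hypothetical.

import numpy as np
import megengine.module as m

class GlobalAvgPool(m.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        return x.mean(axis=(2, 3), keepdims=True)

@register_flops(GlobalAvgPool)  # lands in _calc_flops_dict via _register_dict
def flops_global_avg_pool(module, inputs, outputs):
    # roughly one addition per input element for the spatial mean
    return int(np.prod(inputs[0].shape))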
@register_flops(
m.Conv1d, m.Conv2d, m.Conv3d,
)
def count_convNd(module, input, output):
def flops_convNd(module: m.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
group = module.groups
ic = input[0].shape[1]
oc = output[0].shape[1]
ic = inputs[0].shape[1]
oc = outputs[0].shape[1]
goc = oc // group
gic = ic // group
N = output[0].shape[0]
HW = np.prod(output[0].shape[2:])
N = outputs[0].shape[0]
HW = np.prod(outputs[0].shape[2:])
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
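A standalone numeric check of the formula in the comment above (illustration only, shapes are made up):

import numpy as np

# N x Cout x H x W x (Cin/groups x Kh x Kw + bias), as in flops_convNd
N, Cout, H, W = 1, 128, 56, 56        # output tensor shape
Cin, groups, bias = 64, 1, 1
kernel_size = (3, 3)
flops = N * H * W * (Cout // groups) * ((Cin // groups) * np.prod(kernel_size) + bias)
print(int(flops))                     # 231612416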
@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)
@register_flops(m.ConvTranspose2d)
def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
bias = 1 if module.bias is not None else 0
return np.prod(outputs[0].shape) * module.in_features
@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
return np.prod(output[0].shape) * module.in_features
@register_flops(m.BatchMatMulActivation)
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
bias = 1 if module.bias is not None else 0
x = inputs[0]
w = module.weight
batch_size = x.shape[0]
n, p = x.shape[1:]
_, m = w.shape[1:]
return n * (p + bias) * m * batch_size
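A quick numeric check of flops_batchmatmul above (illustration only); note it counts only the batched matmul, not the fused activation:

# x: (batch, n, p), w: (batch, p, m)  ->  output: (batch, n, m)
batch, n, p, m_out, bias = 8, 32, 64, 16, 1
flops = n * (p + bias) * m_out * batch
print(flops)  # 266240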
# no need to import qat and quantized modules since they inherit from the float module.
hook_modules = (
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
m.BatchNorm2d,
m.conv._ConvNd,
m.Linear,
m.BatchMatMulActivation,
)
@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"):
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions
pre_rf = (
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
max(i.owner._rf[1] for i in inputs),
)
pre_stride = (
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
max(i.owner._stride[1] for i in inputs),
)
return pre_rf, pre_stride
def get_flops_stats(module, inputs, outputs):
rst = {
"input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs],
}
valid_flag = False
for key, _dict, fallback in _iter_list:
for _type in _dict:
if isinstance(module, _type):
value = _dict[_type](module, inputs, outputs)
valid_flag = True
break
else:
if fallback is not None:
value = fallback(module, inputs, outputs)
continue
if isinstance(key, tuple):
assert isinstance(value, tuple)
for k, v in zip(key, value):
rst[k] = v
else:
rst[key] = value
if valid_flag:
return rst
else:
return None
return
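The for/else in get_flops_stats is easy to misread; here is a self-contained toy version of the same dispatch-with-fallback pattern (illustration only):

handlers = {int: lambda v: "int handler", str: lambda v: "str handler"}

def dispatch(value, fallback=None):
    result = None
    for _type in handlers:
        if isinstance(value, _type):
            result = handlers[_type](value)
            break
    else:
        # a for-loop's else runs only when no break fired,
        # which is exactly where get_flops_stats invokes its fallback
        if fallback is not None:
            result = fallback(value)
    return result

print(dispatch(3))          # int handler
print(dispatch(2.5, repr))  # 2.5  (fallback path)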
def print_flops_stats(flops, bar_length_max=20):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
max_flops_num = max([i["flops_num"] for i in flops] + [0])
total_flops_num = 0
for d in flops:
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")
for d in flops:
f = d["flops_num"]
d["flops"] = sizeof_fmt(f, suffix="OPs")
r = d["ratio"] = f / total_flops_num
d["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
ratio = d["ratio"] = d["flops_num"] / total_flops_num
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
d["bar"] = "#" * bar_length
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")
header = [
"name",
"class_name",
"input_shapes",
"output_shapes",
"receptive_field",
"stride",
"flops",
"flops_cum",
"percentage",
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray):
param_size = param_dim * nbits // 8
return {
"shape": shape,
"mean": param.mean(),
"std": param.std(),
"mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()),
"param_dim": param_dim,
"nbits": nbits,
"size": param_size,
@@ -163,21 +240,20 @@ def get_param_stats(param: np.ndarray):
def print_params_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0])
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
ratio = d["size"] / total_param_size
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
ratio = d["size"] / total_param_size
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["size"] / max_size * bar_length_max)
d["size_bar"] = "#" * bar_length
d["size"] = sizeof_fmt(d["size"])
param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
@@ -225,26 +301,14 @@ def module_stats(
:param log_flops: whether print and record op flops.
"""
def module_stats_hook(module, input, output, name=""):
def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_fun = CALC_FLOPS.get(type(module))
if callable(flops_fun):
flops_num = flops_fun(module, input, output)
if not isinstance(output, (list, tuple)):
output = [output]
flops.append(
dict(
name=name,
class_name=class_name,
input_shapes=[i.shape for i in input],
output_shapes=[o.shape for o in output],
flops_num=flops_num,
flops_cum=0,
)
)
flops_stats = get_flops_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
flops.append(flops_stats)
if hasattr(module, "weight") and module.weight is not None:
w = module.weight
@@ -278,19 +342,22 @@ def module_stats(
for h in hooks:
h.remove()
total_flops, total_params = 0, 0
extra_info = {
"#params": len(params),
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
extra_info = {
"#params": len(params),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)
return total_params, total_flops
return total_param_size, total_flops
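For illustration (not part of the patch), a hedged usage sketch matching the new return value (total_param_size, total_flops); the input_size keyword is an assumption, so check the actual signature:

import megengine.module as m
from megengine.utils.module_stats import module_stats

net = m.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
total_param_size, total_flops = module_stats(
    net, input_size=(1, 3, 32, 32),   # "input_size" is an assumed parameter name
    log_params=True, log_flops=True,
)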
@@ -18,6 +18,11 @@ from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
register_flops,
register_receptive_field,
)
class NetworkNode:
@@ -225,8 +230,21 @@ class Elemwise(OpNode):
type = "Elemwise"
opdef = builtin.Elemwise
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType
@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
@register_flops(Elemwise, ElemwiseMultiType)
def flops_elemwise(opnode: Elemwise, inputs, outputs):
return np.prod(outputs[0].shape)
class Reduce(OpNode):
@@ -255,20 +273,24 @@ class MatrixMul(OpNode):
type = "MatrixMul"
opdef = builtin.MatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
mid_shape = self.inputs[0].shape[1]
return np.prod(self.outputs[0].shape) * mid_shape
@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
mid_shape = inputs[0].shape[1]
return np.prod(outputs[0].shape) * mid_shape
class BatchedMatrixMul(OpNode):
type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
mid_shape = self.inputs[0].shape[2]
return np.prod(self.outputs[0].shape) * mid_shape
@register_flops(BatchedMatrixMul)
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
mid_shape = inputs[0].shape[2]
return np.prod(outputs[0].shape) * mid_shape
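A quick check of the two matmul rules above (illustration only): every output element costs mid_shape multiply-accumulates, and the batched case simply carries the leading batch dimension through prod(output.shape):

import numpy as np

n, k, m_out = 32, 64, 16
out_2d = np.zeros((n, m_out))
assert np.prod(out_2d.shape) * k == 32768          # MatrixMul: (32, 64) @ (64, 16)

batch = 8
out_3d = np.zeros((batch, n, m_out))
assert np.prod(out_3d.shape) * k == 262144         # BatchedMatrixMul adds the batch factor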
class Dot(OpNode):
@@ -285,18 +307,6 @@ class ConvolutionForward(OpNode):
type = "Convolution"
opdef = builtin.Convolution
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh)
class ConvolutionBackwardData(OpNode):
type = "ConvTranspose"
@@ -343,17 +353,41 @@ class ConvBiasForward(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + 1)
@register_flops(
ConvolutionForward, ConvBiasForward,
)
def flops_conv(opnode: ConvolutionForward, inputs, outputs):
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(outputs[0].shape)
bias = 1 if isinstance(opnode, ConvBiasForward) else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + bias)
@register_receptive_field(ConvolutionForward, ConvBiasForward)
def receptive_field(opnode: ConvolutionForward, inputs, outputs):
pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
rf = (
kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
)
stride = (
opnode.params["stride_h"] * pre_stride[0],
opnode.params["stride_w"] * pre_stride[1],
)
opnode._rf = rf
opnode._stride = stride
return rf, stride
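A worked example of the receptive-field composition above (illustration only): two stacked 3x3 convolutions, the first with stride 1 and the second with stride 2, starting from the (1, 1) fallback:

def compose(kh, kw, stride_h, stride_w, pre_rf, pre_stride):
    # same recurrence as receptive_field() above
    rf = (kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
          kw * pre_stride[1] + pre_rf[1] - pre_stride[1])
    stride = (stride_h * pre_stride[0], stride_w * pre_stride[1])
    return rf, stride

rf, stride = compose(3, 3, 1, 1, pre_rf=(1, 1), pre_stride=(1, 1))  # first conv
print(rf, stride)   # (3, 3) (1, 1)
rf, stride = compose(3, 3, 2, 2, pre_rf=rf, pre_stride=stride)      # second conv
print(rf, stride)   # (5, 5) (2, 2)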
class BatchConvBiasForward(OpNode):
@@ -652,20 +686,6 @@ class AssertEqual(OpNode):
opdef = builtin.AssertEqual
class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType
@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class CvtColorForward(OpNode):
type = "CvtColor"
opdef = builtin.CvtColor |