
feat(mge/tools): add support for receptive_field stats on NetworkNode

GitOrigin-RevId: 11ef335468
tags/v1.3.1
Megvii Engine Team, 4 years ago
parent commit 13481fd2ca
3 changed files with 238 additions and 148 deletions:
  1. imperative/python/megengine/tools/network_visualize.py (+28, -25)
  2. imperative/python/megengine/utils/module_stats.py (+143, -76)
  3. imperative/python/megengine/utils/network_node.py (+67, -47)

imperative/python/megengine/tools/network_visualize.py (+28, -25)

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import json
import logging

import numpy as np
@@ -14,6 +15,7 @@ import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
get_flops_stats,
get_param_stats,
print_flops_stats,
print_params_stats,
@@ -89,6 +91,7 @@ def visualize(

inp_list = [process_name(var.owner.name) for var in node.inputs]
if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
@@ -101,24 +104,20 @@ def visualize(
]
)
),
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
flops_stats = get_flops_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path:
if log_path and hasattr(flops_stats, "flops_num"):
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_num).encode(encoding="utf-8")
)
flops_list.append(
dict(
name=node.name,
class_name=node.type,
input_shapes=[i.shape for i in node.inputs],
output_shapes=[o.shape for o in node.outputs],
flops_num=flops_num,
flops_cum=0,
s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
)
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)

if node.type == "ImmutableTensor":
param_stats = get_param_stats(node.numpy())
# add tensor size attr
@@ -132,6 +131,7 @@ def visualize(
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue

if log_path:
node_list.append(
NodeDef(
@@ -141,14 +141,26 @@ def visualize(
attr=attr,
)
)
# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
}

total_flops, total_params = None, None
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)

if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -160,21 +172,12 @@ def visualize(
writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats))

# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)

# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)

return total_params, total_flops
return total_param_size, total_flops


def main():
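
(For orientation: a minimal sketch of driving this tool programmatically. The first positional argument and the exact keyword names are assumptions inferred from the parameters this diff touches, not confirmed by it.)

from megengine.tools.network_visualize import visualize

# assumed signature: a dumped-model path plus the logging switches used above;
# returns the totals computed at the end of visualize()
total_param_size, total_flops = visualize(
    "model.mge",        # hypothetical path to a dumped model
    log_path="./logs",  # also emit a TensorBoard-readable graph
    log_params=True,    # print the per-parameter table
    log_flops=True,     # print the per-op FLOPs table
)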


imperative/python/megengine/utils/module_stats.py (+143, -76)

@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__)
logger.setLevel("INFO")


CALC_FLOPS = {}


def _register_modules(*modules):
_calc_flops_dict = {}
_calc_receptive_field_dict = {}


def _receptive_field_fallback(module, inputs, outputs):
assert not hasattr(module, "_rf")
assert not hasattr(module, "_stride")
if len(inputs) == 0:
# TODO: support other dimension
module._rf = (1, 1)
module._stride = (1, 1)
return module._rf, module._stride
rf, stride = preprocess_receptive_field(module, inputs, outputs)
module._rf = rf
module._stride = stride
return rf, stride


# key tuple, impl_dict, fallback
_iter_list = [
("flops_num", _calc_flops_dict, None),
(
("receptive_field", "stride"),
_calc_receptive_field_dict,
_receptive_field_fallback,
),
]


def _register_dict(*modules, dict=None):
def callback(impl):
for module in modules:
CALC_FLOPS[module] = impl
dict[module] = impl
return impl

return callback


@_register_modules(
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
qm.Conv2d,
qm.ConvRelu2d,
qm.ConvBn2d,
qm.ConvBnRelu2d,
qatm.Conv2d,
qatm.ConvRelu2d,
qatm.ConvBn2d,
qatm.ConvBnRelu2d,
def register_flops(*modules):
return _register_dict(*modules, dict=_calc_flops_dict)


def register_receptive_field(*modules):
return _register_dict(*modules, dict=_calc_receptive_field_dict)


@register_flops(
m.Conv1d, m.Conv2d, m.Conv3d,
)
def count_convNd(module, input, output):
def flops_convNd(module: m.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
group = module.groups
ic = input[0].shape[1]
oc = output[0].shape[1]
ic = inputs[0].shape[1]
oc = outputs[0].shape[1]
goc = oc // group
gic = ic // group
N = output[0].shape[0]
HW = np.prod(output[0].shape[2:])
N = outputs[0].shape[0]
HW = np.prod(outputs[0].shape[2:])
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
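
# Worked instance of the count above (illustrative numbers, not from this
# diff): Conv2d(64, 128, 3) with bias, groups=1, output shape 1x128x56x56:
#   N * HW * goc * (gic * Kh * Kw + bias)
#   = 1 * (56 * 56) * 128 * (64 * 3 * 3 + 1)
#   = 231,612,416 accumulate ops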


@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)
@register_flops(m.ConvTranspose2d)
def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)


@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
bias = 1 if module.bias is not None else 0
return np.prod(outputs[0].shape) * module.in_features

@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
return np.prod(output[0].shape) * module.in_features

@register_flops(m.BatchMatMulActivation)
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
bias = 1 if module.bias is not None else 0
x = inputs[0]
w = module.weight
batch_size = x.shape[0]
n, p = x.shape[1:]
_, m = w.shape[1:]
return n * (p + bias) * m * batch_size


# does not need import qat and quantized module since they inherit from float module.
hook_modules = (
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
m.BatchNorm2d,
m.conv._ConvNd,
m.Linear,
m.BatchMatMulActivation,
)
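
# A hypothetical user-side extension of the registry introduced above;
# register_flops accepts module types plus a counter with the
# (module, inputs, outputs) signature. AvgPool2d is not covered by this
# diff, and the cost model below is an assumption for illustration only:
#
#   @register_flops(m.AvgPool2d)
#   def flops_avgpool(module, inputs, outputs):
#       # one accumulate per kernel element per output element
#       return np.prod(outputs[0].shape) * np.prod(module.kernel_size)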


@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"):
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)


def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions
pre_rf = (
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
max(i.owner._rf[1] for i in inputs),
)
pre_stride = (
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
max(i.owner._stride[1] for i in inputs),
)
return pre_rf, pre_stride


def get_flops_stats(module, inputs, outputs):
rst = {
"input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs],
}
valid_flag = False
for key, _dict, fallback in _iter_list:
for _type in _dict:
if isinstance(module, _type):
value = _dict[_type](module, inputs, outputs)
valid_flag = True
break
else:
if fallback is not None:
value = fallback(module, inputs, outputs)
continue

if isinstance(key, tuple):
assert isinstance(value, tuple)
for k, v in zip(key, value):
rst[k] = v
else:
rst[key] = value

if valid_flag:
return rst
else:
return None
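
# For a bias-free Conv2d node, the dict assembled above would look roughly
# like this (shapes and numbers illustrative):
#   {
#       "input_shapes": [(1, 3, 224, 224)],
#       "output_shapes": [(1, 16, 222, 222)],
#       "flops_num": 21290688,      # from flops_convNd
#       "receptive_field": (3, 3),  # from the registered RF counter
#       "stride": (1, 1),
#   }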


def print_flops_stats(flops, bar_length_max=20):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
max_flops_num = max([i["flops_num"] for i in flops] + [0])
total_flops_num = 0
for d in flops:
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

for d in flops:
f = d["flops_num"]
d["flops"] = sizeof_fmt(f, suffix="OPs")
r = d["ratio"] = f / total_flops_num
d["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
ratio = d["ratio"] = d["flops_num"] / total_flops_num
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
d["bar"] = "#" * bar_length
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

header = [
"name",
"class_name",
"input_shapes",
"output_shapes",
"receptive_field",
"stride",
"flops",
"flops_cum",
"percentage",
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray):
param_size = param_dim * nbits // 8
return {
"shape": shape,
"mean": param.mean(),
"std": param.std(),
"mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()),
"param_dim": param_dim,
"nbits": nbits,
"size": param_size,
@@ -163,21 +240,20 @@ def get_param_stats(param: np.ndarray):


def print_params_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0])
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
ratio = d["size"] / total_param_size
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)

# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
ratio = d["size"] / total_param_size
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["size"] / max_size * bar_length_max)
d["size_bar"] = "#" * bar_length
d["size"] = sizeof_fmt(d["size"])

param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
@@ -225,26 +301,14 @@ def module_stats(
:param log_flops: whether print and record op flops.
"""

def module_stats_hook(module, input, output, name=""):
def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]

flops_fun = CALC_FLOPS.get(type(module))
if callable(flops_fun):
flops_num = flops_fun(module, input, output)

if not isinstance(output, (list, tuple)):
output = [output]

flops.append(
dict(
name=name,
class_name=class_name,
input_shapes=[i.shape for i in input],
output_shapes=[o.shape for o in output],
flops_num=flops_num,
flops_cum=0,
)
)
flops_stats = get_flops_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
flops.append(flops_stats)

if hasattr(module, "weight") and module.weight is not None:
w = module.weight
@@ -278,19 +342,22 @@ def module_stats(
for h in hooks:
h.remove()

total_flops, total_params = 0, 0
extra_info = {
"#params": len(params),
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)

extra_info = {
"#params": len(params),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)

return total_params, total_flops
return total_param_size, total_flops
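
(A usage sketch for the reworked entry point; the positional arguments, a module and an input size, are assumptions based on contemporary MegEngine docs rather than shown in this hunk.)

import megengine.module as m
from megengine.utils.module_stats import module_stats

net = m.Sequential(m.Conv2d(3, 16, 3), m.Conv2d(16, 32, 3))
# prints the per-op FLOPs and per-parameter tables, then returns the totals
total_param_size, total_flops = module_stats(
    net, (1, 3, 224, 224), log_params=True, log_flops=True
)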

imperative/python/megengine/utils/network_node.py (+67, -47)

@@ -18,6 +18,11 @@ from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
register_flops,
register_receptive_field,
)


class NetworkNode:
@@ -225,8 +230,21 @@ class Elemwise(OpNode):
type = "Elemwise"
opdef = builtin.Elemwise

def calc_flops(self):
return np.prod(self.outputs[0].shape)

class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType

@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj


@register_flops(Elemwise, ElemwiseMultiType)
def flops_elemwise(opnode: Elemwise, inputs, outputs):
return np.prod(outputs[0].shape)


class Reduce(OpNode):
@@ -255,20 +273,24 @@ class MatrixMul(OpNode):
type = "MatrixMul"
opdef = builtin.MatrixMul

def calc_flops(self):
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
mid_shape = self.inputs[0].shape[1]
return np.prod(self.outputs[0].shape) * mid_shape

@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
mid_shape = inputs[0].shape[1]
return np.prod(outputs[0].shape) * mid_shape
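
# e.g. a (128, 256) x (256, 512) matmul: 128 * 512 outputs, each a 256-term
# dot product, so 128 * 512 * 256 = 16,777,216 accumulate ops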


class BatchedMatrixMul(OpNode):
type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul

def calc_flops(self):
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
mid_shape = self.inputs[0].shape[2]
return np.prod(self.outputs[0].shape) * mid_shape

@register_flops(BatchedMatrixMul)
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
mid_shape = inputs[0].shape[2]
return np.prod(outputs[0].shape) * mid_shape


class Dot(OpNode):
@@ -285,18 +307,6 @@ class ConvolutionForward(OpNode):
type = "Convolution"
opdef = builtin.Convolution

def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh)


class ConvolutionBackwardData(OpNode):
type = "ConvTranspose"
@@ -343,17 +353,41 @@ class ConvBiasForward(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype
return obj

def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + 1)

@register_flops(
ConvolutionForward, ConvBiasForward,
)
def flops_conv(opnode: ConvolutionForward, inputs, outputs):
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(outputs[0].shape)
bias = 1 if isinstance(opnode, ConvBiasForward) else 0
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh + bias)


@register_receptive_field(ConvolutionForward, ConvBiasForward)
def receptive_field(opnode: ConvolutionForward, inputs, outputs):
pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
rf = (
kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
)
stride = (
opnode.params["stride_h"] * pre_stride[0],
opnode.params["stride_w"] * pre_stride[1],
)
opnode._rf = rf
opnode._stride = stride
return rf, stride
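
# The recurrence above in closed form: rf' = pre_rf + (k - 1) * pre_stride,
# stride' = s * pre_stride. A standalone check (illustrative, not part of
# the commit):
#
#   def stack(layers):  # layers: iterable of (kh, kw, sh, sw)
#       rf, stride = (1, 1), (1, 1)
#       for kh, kw, sh, sw in layers:
#           rf = (rf[0] + (kh - 1) * stride[0], rf[1] + (kw - 1) * stride[1])
#           stride = (stride[0] * sh, stride[1] * sw)
#       return rf, stride
#
#   stack([(3, 3, 1, 1), (3, 3, 2, 2), (3, 3, 1, 1)])  # -> ((9, 9), (2, 2))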


class BatchConvBiasForward(OpNode):
@@ -652,20 +686,6 @@ class AssertEqual(OpNode):
opdef = builtin.AssertEqual


class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType

@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj

def calc_flops(self):
return np.prod(self.outputs[0].shape)


class CvtColorForward(OpNode):
type = "CvtColor"
opdef = builtin.CvtColor
