GitOrigin-RevId: 11ef335468
tags/v1.3.1
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import argparse
+import json
 import logging
 import numpy as np
@@ -14,6 +15,7 @@ import numpy as np
 from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
+    get_flops_stats,
     get_param_stats,
     print_flops_stats,
     print_params_stats,
@@ -89,6 +91,7 @@ def visualize(
         inp_list = [process_name(var.owner.name) for var in node.inputs]
         if log_path:
+            # for the detailed format, see tensorboard/compat/proto/attr_value.proto
             attr = {
                 "_output_shapes": AttrValue(
                     list=AttrValue.ListValue(
@@ -101,24 +104,20 @@ def visualize(
                         ]
                     )
                 ),
+                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
+                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
             }
-        if hasattr(node, "calc_flops"):
-            flops_num = node.calc_flops()
+        flops_stats = get_flops_stats(node, node.inputs, node.outputs)
+        if flops_stats is not None:
             # add op flops attr
-            if log_path:
+            if log_path and "flops_num" in flops_stats:
                 attr["flops"] = AttrValue(
-                    s=sizeof_fmt(flops_num).encode(encoding="utf-8")
+                    s=sizeof_fmt(flops_stats["flops_num"]).encode(encoding="utf-8")
                 )
-            flops_list.append(
-                dict(
-                    name=node.name,
-                    class_name=node.type,
-                    input_shapes=[i.shape for i in node.inputs],
-                    output_shapes=[o.shape for o in node.outputs],
-                    flops_num=flops_num,
-                    flops_cum=0,
-                )
-            )
+            flops_stats["name"] = node.name
+            flops_stats["class_name"] = node.type
+            flops_list.append(flops_stats)
         if node.type == "ImmutableTensor":
             param_stats = get_param_stats(node.numpy())
             # add tensor size attr
@@ -132,6 +131,7 @@ def visualize(
         # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
         if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
             continue
+
         if log_path:
             node_list.append(
                 NodeDef(
@@ -141,14 +141,26 @@ def visualize(
                     attr=attr,
                 )
             )

-    total_flops, total_params = None, None
+    # summary
+    extra_info = {
+        "#ops": len(graph.all_oprs),
+        "#params": len(params_list),
+    }
+
+    total_flops, total_param_dims, total_param_size = 0, 0, 0
     if log_params:
         total_param_dims, total_param_size = print_params_stats(
             params_list, bar_length_max
         )
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_flops:
         total_flops = print_flops_stats(flops_list, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+    if log_params and log_flops:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )
+
     if log_path:
         graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -160,21 +172,12 @@ def visualize(
         writer = SummaryWriter(log_path)
         writer._get_file_writer().add_graph((graph_def, stepstats))

-    # summary
-    extra_info = {
-        "#ops": len(graph.all_oprs),
-        "#params": len(params_list),
-        "total_param_dims": sizeof_fmt(total_param_dims),
-        "total_param_size": sizeof_fmt(total_param_size),
-        "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
-        "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
-    }
     print_summary(**extra_info)

     # FIXME: remove this after resolving "span dist too large" warning
     _imperative_rt_logger.set_log_level(old_level)

-    return total_params, total_flops
+    return total_param_size, total_flops


 def main():
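With this change, visualize builds extra_info incrementally and returns the parameter byte size in place of the old total_params. A rough usage sketch; the model path, log directory, and keyword names are assumptions, not verified against this exact revision:

# Hedged sketch: calling the updated visualize() from Python.
from megengine.tools.network_visualize import visualize

total_param_size, total_flops = visualize(
    "./resnet18.mge",     # hypothetical dumped-model path
    log_path="./logdir",  # tensorboard event directory
    log_params=True,
    log_flops=True,
)
# total_param_size is raw bytes and total_flops a raw op count;
# the printed summary renders both through sizeof_fmt.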
@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__)
 logger.setLevel("INFO")

-CALC_FLOPS = {}
+_calc_flops_dict = {}
+_calc_receptive_field_dict = {}


-def _register_modules(*modules):
+def _receptive_field_fallback(module, inputs, outputs):
+    assert not hasattr(module, "_rf")
+    assert not hasattr(module, "_stride")
+    if len(inputs) == 0:
+        # TODO: support other dimension
+        module._rf = (1, 1)
+        module._stride = (1, 1)
+        return module._rf, module._stride
+    rf, stride = preprocess_receptive_field(module, inputs, outputs)
+    module._rf = rf
+    module._stride = stride
+    return rf, stride
+
+
+# (key, impl_dict, fallback) triples; key may be a tuple of result names
+_iter_list = [
+    ("flops_num", _calc_flops_dict, None),
+    (
+        ("receptive_field", "stride"),
+        _calc_receptive_field_dict,
+        _receptive_field_fallback,
+    ),
+]
+
+
+def _register_dict(*modules, dict=None):
     def callback(impl):
         for module in modules:
-            CALC_FLOPS[module] = impl
+            dict[module] = impl
         return impl

     return callback
+def register_flops(*modules):
+    return _register_dict(*modules, dict=_calc_flops_dict)
+
+
+def register_receptive_field(*modules):
+    return _register_dict(*modules, dict=_calc_receptive_field_dict)
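register_flops/register_receptive_field replace the old private _register_modules and are importable by other modules (network_node.py below uses them). A minimal sketch of extending the registry; MyUpsample and its one-op-per-output cost model are invented for illustration:

import numpy as np
import megengine.module as m
from megengine.utils.module_stats import register_flops  # per this diff

class MyUpsample(m.Module):  # hypothetical user module
    def forward(self, x):
        return x  # placeholder

@register_flops(MyUpsample)
def flops_my_upsample(module, inputs, outputs):
    # assume one op per output element
    return np.prod(outputs[0].shape)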
-@_register_modules(
-    m.Conv2d,
-    m.ConvTranspose2d,
-    m.LocalConv2d,
-    qm.Conv2d,
-    qm.ConvRelu2d,
-    qm.ConvBn2d,
-    qm.ConvBnRelu2d,
-    qatm.Conv2d,
-    qatm.ConvRelu2d,
-    qatm.ConvBn2d,
-    qatm.ConvBnRelu2d,
+@register_flops(
+    m.Conv1d, m.Conv2d, m.Conv3d,
 )
-def count_convNd(module, input, output):
+def flops_convNd(module: m.Conv2d, inputs, outputs):
     bias = 1 if module.bias is not None else 0
     group = module.groups
-    ic = input[0].shape[1]
-    oc = output[0].shape[1]
+    ic = inputs[0].shape[1]
+    oc = outputs[0].shape[1]
     goc = oc // group
     gic = ic // group
-    N = output[0].shape[0]
-    HW = np.prod(output[0].shape[2:])
+    N = outputs[0].shape[0]
+    HW = np.prod(outputs[0].shape[2:])
     # N x Cout x H x W x (Cin x Kw x Kh + bias)
     return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
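A quick numeric check of the formula in flops_convNd, with assumed shapes: Conv2d(16, 32, 3) with bias on a 1x16x56x56 input and "same" padding, so the output is 1x32x56x56.

# N=1, HW=56*56, goc=32 (groups=1), gic=16, kernel 3x3, bias term 1
N, HW, goc, gic, k2, bias = 1, 56 * 56, 32, 16, 9, 1
assert N * HW * goc * (gic * k2 + bias) == 14_551_040  # ~14.6 MOPs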
-@_register_modules(m.ConvTranspose2d)
-def count_deconvNd(module, input, output):
-    return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)
+@register_flops(m.ConvTranspose2d)
+def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
+    return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
-@_register_modules(m.Linear, qatm.Linear, qm.Linear)
-def count_linear(module, input, output):
-    return np.prod(output[0].shape) * module.in_features
+@register_flops(m.Linear)
+def flops_linear(module: m.Linear, inputs, outputs):
+    bias = 1 if module.bias is not None else 0
+    # out_elems x (in_features + bias)
+    return np.prod(outputs[0].shape) * (module.in_features + bias)
+@register_flops(m.BatchMatMulActivation)
+def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
+    bias = 1 if module.bias is not None else 0
+    x = inputs[0]
+    w = module.weight
+    batch_size = x.shape[0]
+    n, p = x.shape[1:]
+    _, m = w.shape[1:]
+    return n * (p + bias) * m * batch_size
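Sanity check for the batched-matmul count under assumed sizes: x of shape (8, 64, 128) against a (8, 128, 256) weight, no bias; multiply-adds are counted once.

batch_size, n, p, m_out, bias = 8, 64, 128, 256, 0
assert n * (p + bias) * m_out * batch_size == 16_777_216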
 # no need to import qat and quantized modules since they inherit from the float modules.
 hook_modules = (
-    m.Conv2d,
-    m.ConvTranspose2d,
-    m.LocalConv2d,
-    m.BatchNorm2d,
+    m.conv._ConvNd,
     m.Linear,
+    m.BatchMatMulActivation,
 )
@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"):
     return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
+
+
+def preprocess_receptive_field(module, inputs, outputs):
+    # TODO: support other dimensions
+    pre_rf = (
+        max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
+        max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
+    )
+    pre_stride = (
+        max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
+        max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
+    )
+    return pre_rf, pre_stride
+
+
+def get_flops_stats(module, inputs, outputs):
+    rst = {
+        "input_shapes": [i.shape for i in inputs],
+        "output_shapes": [o.shape for o in outputs],
+    }
+    valid_flag = False
+    for key, _dict, fallback in _iter_list:
+        for _type in _dict:
+            if isinstance(module, _type):
+                value = _dict[_type](module, inputs, outputs)
+                valid_flag = True
+                break
+        else:
+            if fallback is not None:
+                # fallback runs for its side effects only; its result is not recorded
+                value = fallback(module, inputs, outputs)
+            continue
+
+        if isinstance(key, tuple):
+            assert isinstance(value, tuple)
+            for k, v in zip(key, value):
+                rst[k] = v
+        else:
+            rst[key] = value
+
+    if valid_flag:
+        return rst
+    else:
+        return None
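The returned dict carries the shape info plus whatever the matched hooks produced; callers then tag it with name/class_name before appending it to the stats list. An illustrative entry (all values made up):

entry = {
    "input_shapes": [(1, 16, 56, 56)],
    "output_shapes": [(1, 32, 56, 56)],
    "flops_num": 14_551_040,
    # graph nodes with a receptive-field hook registered also get:
    # "receptive_field": (3, 3), "stride": (1, 1),
    "name": "backbone.conv1",   # added by the caller
    "class_name": "Conv2d",
}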


 def print_flops_stats(flops, bar_length_max=20):
-    flops_list = [i["flops_num"] for i in flops]
-    max_flops_num = max(flops_list + [0])
+    # calc total flops and set flops_cum
+    max_flops_num = max([i["flops_num"] for i in flops] + [0])
     total_flops_num = 0
     for d in flops:
         total_flops_num += int(d["flops_num"])
         d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

     for d in flops:
-        f = d["flops_num"]
-        d["flops"] = sizeof_fmt(f, suffix="OPs")
-        r = d["ratio"] = f / total_flops_num
-        d["percentage"] = "{:.2f}%".format(r * 100)
-        bar_length = int(f / max_flops_num * bar_length_max)
+        ratio = d["ratio"] = d["flops_num"] / total_flops_num
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
         d["bar"] = "#" * bar_length
+        d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

     header = [
         "name",
         "class_name",
         "input_shapes",
         "output_shapes",
+        "receptive_field",
+        "stride",
         "flops",
         "flops_cum",
         "percentage",
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray):
     param_size = param_dim * nbits // 8
     return {
         "shape": shape,
-        "mean": param.mean(),
-        "std": param.std(),
+        "mean": "{:.3g}".format(param.mean()),
+        "std": "{:.3g}".format(param.std()),
         "param_dim": param_dim,
         "nbits": nbits,
         "size": param_size,
@@ -163,21 +240,20 @@ def get_param_stats(param: np.ndarray):


 def print_params_stats(params, bar_length_max=20):
+    max_size = max([d["size"] for d in params] + [0])
     total_param_dims, total_param_size = 0, 0
     for d in params:
         total_param_dims += int(d["param_dim"])
         total_param_size += int(d["size"])
-        ratio = d["size"] / total_param_size
-        d["size"] = sizeof_fmt(d["size"])
         d["size_cum"] = sizeof_fmt(total_param_size)
-        d["ratio"] = ratio
-        d["percentage"] = "{:.2f}%".format(ratio * 100)

-    # construct bar
-    max_ratio = max([d["ratio"] for d in params])
     for d in params:
-        bar_length = int(d["ratio"] / max_ratio * bar_length_max)
+        ratio = d["size"] / total_param_size
+        d["ratio"] = ratio
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["size"] / max_size * bar_length_max)
         d["size_bar"] = "#" * bar_length
+        d["size"] = sizeof_fmt(d["size"])

     param_size = sizeof_fmt(total_param_size)
     params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
@@ -225,26 +301,14 @@ def module_stats(
     :param log_flops: whether to print and record op FLOPs.
     """

-    def module_stats_hook(module, input, output, name=""):
+    def module_stats_hook(module, inputs, outputs, name=""):
         class_name = str(module.__class__).split(".")[-1].split("'")[0]
-        flops_fun = CALC_FLOPS.get(type(module))
-        if callable(flops_fun):
-            flops_num = flops_fun(module, input, output)
-
-            if not isinstance(output, (list, tuple)):
-                output = [output]
-
-            flops.append(
-                dict(
-                    name=name,
-                    class_name=class_name,
-                    input_shapes=[i.shape for i in input],
-                    output_shapes=[o.shape for o in output],
-                    flops_num=flops_num,
-                    flops_cum=0,
-                )
-            )
+        flops_stats = get_flops_stats(module, inputs, outputs)
+        if flops_stats is not None:
+            flops_stats["name"] = name
+            flops_stats["class_name"] = class_name
+            flops.append(flops_stats)

         if hasattr(module, "weight") and module.weight is not None:
             w = module.weight
@@ -278,19 +342,22 @@ def module_stats(
     for h in hooks:
         h.remove()

-    total_flops, total_params = 0, 0
+    extra_info = {
+        "#params": len(params),
+    }
+    total_flops, total_param_dims, total_param_size = 0, 0, 0
     if log_params:
         total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_flops:
         total_flops = print_flops_stats(flops, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+    if log_params and log_flops:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )
-    extra_info = {
-        "#params": len(params),
-        "total_param_dims": sizeof_fmt(total_param_dims),
-        "total_param_size": sizeof_fmt(total_param_size),
-        "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
-        "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
-    }

     print_summary(**extra_info)

-    return total_params, total_flops
+    return total_param_size, total_flops
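End to end, module_stats now routes every hooked module through get_flops_stats and returns the parameter byte size. A hedged usage sketch; the second positional argument is assumed to be an input shape, and some revisions take a tensor instead:

import megengine.module as m
from megengine.utils.module_stats import module_stats

class TinyNet(m.Module):  # illustrative model
    def __init__(self):
        super().__init__()
        self.conv = m.Conv2d(3, 8, 3, padding=1)
        self.fc = m.Linear(8 * 32 * 32, 10)

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x.reshape(x.shape[0], -1))

total_param_size, total_flops = module_stats(
    TinyNet(), (1, 3, 32, 32), log_params=True, log_flops=True
)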
@@ -18,6 +18,11 @@ from ..core.ops import builtin
 from ..core.tensor.megbrain_graph import InputNode
 from ..tensor import Tensor
 from .comp_graph_tools import replace_vars
+from .module_stats import (
+    preprocess_receptive_field,
+    register_flops,
+    register_receptive_field,
+)


 class NetworkNode:
@@ -225,8 +230,21 @@ class Elemwise(OpNode):
     type = "Elemwise"
     opdef = builtin.Elemwise

-    def calc_flops(self):
-        return np.prod(self.outputs[0].shape)
+
+class ElemwiseMultiType(OpNode):
+    type = "ElemwiseMultiType"
+    opdef = builtin.ElemwiseMultiType
+
+    @classmethod
+    def load(cls, opr):
+        obj = super(ElemwiseMultiType, cls).load(opr)
+        obj.params["dtype"] = opr.outputs[0].dtype
+        return obj
+
+
+@register_flops(Elemwise, ElemwiseMultiType)
+def flops_elemwise(opnode: Elemwise, inputs, outputs):
+    return np.prod(outputs[0].shape)


 class Reduce(OpNode):
@@ -255,20 +273,24 @@ class MatrixMul(OpNode):
     type = "MatrixMul"
     opdef = builtin.MatrixMul

-    def calc_flops(self):
-        assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
-        mid_shape = self.inputs[0].shape[1]
-        return np.prod(self.outputs[0].shape) * mid_shape
+
+@register_flops(MatrixMul)
+def flops_matmul(opnode: MatrixMul, inputs, outputs):
+    assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
+    mid_shape = inputs[0].shape[1]
+    return np.prod(outputs[0].shape) * mid_shape


 class BatchedMatrixMul(OpNode):
     type = "BatchedMatmul"
     opdef = builtin.BatchedMatrixMul

-    def calc_flops(self):
-        assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
-        mid_shape = self.inputs[0].shape[2]
-        return np.prod(self.outputs[0].shape) * mid_shape
+
+@register_flops(BatchedMatrixMul)
+def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
+    assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
+    mid_shape = inputs[0].shape[2]
+    return np.prod(outputs[0].shape) * mid_shape


 class Dot(OpNode):
@@ -285,18 +307,6 @@ class ConvolutionForward(OpNode):
     type = "Convolution"
     opdef = builtin.Convolution

-    def calc_flops(self):
-        param_W_shape = self.inputs[1].shape
-        kh = param_W_shape[-2]
-        kw = param_W_shape[-1]
-        if len(param_W_shape) == 5:
-            num_input = param_W_shape[2]
-        else:
-            num_input = param_W_shape[1]
-        NCHW = np.prod(self.outputs[0].shape)
-        # N x Cout x H x W x (Cin x Kw x Kh)
-        return NCHW * (num_input * kw * kh)


 class ConvolutionBackwardData(OpNode):
     type = "ConvTranspose"
@@ -343,17 +353,41 @@ class ConvBiasForward(OpNode):
         obj.params["dtype"] = opr.outputs[0].dtype
         return obj

-    def calc_flops(self):
-        param_W_shape = self.inputs[1].shape
-        kh = param_W_shape[-2]
-        kw = param_W_shape[-1]
-        if len(param_W_shape) == 5:
-            num_input = param_W_shape[2]
-        else:
-            num_input = param_W_shape[1]
-        NCHW = np.prod(self.outputs[0].shape)
-        # N x Cout x H x W x (Cin x Kw x Kh + bias)
-        return NCHW * (num_input * kw * kh + 1)
+
+@register_flops(
+    ConvolutionForward, ConvBiasForward,
+)
+def flops_conv(opnode: ConvolutionForward, inputs, outputs):
+    param_W_shape = inputs[1].shape
+    kh = param_W_shape[-2]
+    kw = param_W_shape[-1]
+    if len(param_W_shape) == 5:
+        num_input = param_W_shape[2]
+    else:
+        num_input = param_W_shape[1]
+    NCHW = np.prod(outputs[0].shape)
+    bias = 1 if isinstance(opnode, ConvBiasForward) else 0
+    # N x Cout x H x W x (Cin x Kw x Kh + bias)
+    return NCHW * (num_input * kw * kh + bias)
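The 5-D branch covers grouped convolution: a grouped weight is dumped as (groups, out_per_group, in_per_group, kh, kw) (layout assumed from MegEngine's group-conv convention), so index 2 is the per-group input channel count:

param_W_shape = (4, 8, 16, 3, 3)  # hypothetical: groups=4, 8 out/group, 16 in/group
num_input = param_W_shape[2] if len(param_W_shape) == 5 else param_W_shape[1]
assert num_input == 16  # FLOPs per output element use in-channels per group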
+
+
+@register_receptive_field(ConvolutionForward, ConvBiasForward)
+def receptive_field(opnode: ConvolutionForward, inputs, outputs):
+    pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
+    param_W_shape = inputs[1].shape
+    kh = param_W_shape[-2]
+    kw = param_W_shape[-1]
+    rf = (
+        kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
+        kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
+    )
+    stride = (
+        opnode.params["stride_h"] * pre_stride[0],
+        opnode.params["stride_w"] * pre_stride[1],
+    )
+    opnode._rf = rf
+    opnode._stride = stride
+    return rf, stride
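The recursion composes receptive fields front to back: rf grows by (k - 1) * accumulated_stride and strides multiply. A standalone check of the same arithmetic for two stacked 3x3 convs (stride 1, then stride 2):

def compose(k, s, pre_rf, pre_stride):
    rf = tuple(k * ps + pr - ps for ps, pr in zip(pre_stride, pre_rf))
    stride = tuple(s * ps for ps in pre_stride)
    return rf, stride

rf, stride = compose(3, 1, (1, 1), (1, 1))  # -> (3, 3), (1, 1)
rf, stride = compose(3, 2, rf, stride)      # -> (5, 5), (2, 2)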


 class BatchConvBiasForward(OpNode):
@@ -652,20 +686,6 @@ class AssertEqual(OpNode):
     opdef = builtin.AssertEqual


-class ElemwiseMultiType(OpNode):
-    type = "ElemwiseMultiType"
-    opdef = builtin.ElemwiseMultiType
-
-    @classmethod
-    def load(cls, opr):
-        obj = super(ElemwiseMultiType, cls).load(opr)
-        obj.params["dtype"] = opr.outputs[0].dtype
-        return obj
-
-    def calc_flops(self):
-        return np.prod(self.outputs[0].shape)


 class CvtColorForward(OpNode):
     type = "CvtColor"
     opdef = builtin.CvtColor