GitOrigin-RevId: a1ab77c20a
tags/v1.3.0
@@ -0,0 +1,8 @@ | |||
# MegEngine Tools | |||
This directory contains executable python files. | |||
Run them as modules in the following way (replace `xxx` with a specific file name, such as `network_visualize`): | |||
``` | |||
python -m megengine.tools.xxx | |||
``` |
@@ -1,3 +1,4 @@ | |||
#! /usr/bin/env python3 | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -7,12 +8,55 @@ | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import argparse | |||
import os | |||
import struct | |||
import textwrap | |||
from pathlib import Path | |||
import numpy as np | |||
from megengine.utils import plugin | |||
def load_tensor_binary(fobj):
    """
    Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
    tensor value dump is implemented by ``mgb::debug::dump_tensor``.

    The data is read in this order: a header of three unsigned ints
    ``(name_len, dtype_enum, max_ndim)`` in native byte order, ``max_ndim``
    unsigned ints holding the shape (trailing zero entries are dropped),
    ``name_len`` bytes of ASCII name, and finally the raw tensor values up to
    the end of the file.

    :param fobj: file object opened in binary mode, or a string that contains
        the file name.
    :return: tuple ``(tensor_value, tensor_name)``.
    :raises AssertionError: if the dumped dtype is the unsupported Byte dtype.
    :raises KeyError: if the dtype enum value is not known to this loader.
    """
    if isinstance(fobj, str):
        with open(fobj, "rb") as fin:
            return load_tensor_binary(fin)

    DTYPE_LIST = {
        0: np.float32,
        1: np.uint8,
        2: np.int8,
        3: np.int16,
        4: np.int32,
        # 5: _mgb.intb1,
        # 6: _mgb.intb2,
        # 7: _mgb.intb4,
        8: None,
        9: np.float16,
        # quantized dtype start from 100000
        # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
        # dnn/include/megdnn/dtype.h
        100000: np.uint8,
        100001: np.int32,
        100002: np.int8,
    }

    header_fmt = struct.Struct("III")
    name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))

    # Fail with an explicit message instead of a bare ``KeyError: 5``.
    if dtype not in DTYPE_LIST:
        raise KeyError(
            "Cannot load this tensor: unknown dtype enum {}.".format(dtype)
        )
    np_dtype = DTYPE_LIST[dtype]
    # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
    if np_dtype is None:
        raise AssertionError("Cannot load this tensor: dtype Byte is unsupported.")

    shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
    # Drop unused trailing dimensions; the emptiness guard avoids an
    # IndexError when every stored dimension is zero.
    while shape and shape[-1] == 0:
        shape.pop()

    name = fobj.read(name_len).decode("ascii")
    return np.fromfile(fobj, dtype=np_dtype).reshape(shape), name
def check(v0, v1, name, max_err): | |||
@@ -26,9 +70,9 @@ def check(v0, v1, name, max_err): | |||
) | |||
vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0) | |||
err = np.abs(v0 - v1) / vdiv | |||
check = err > max_err | |||
if check.sum(): | |||
idx = tuple(i[0] for i in np.nonzero(check)) | |||
rst = err > max_err | |||
if rst.sum(): | |||
idx = tuple(i[0] for i in np.nonzero(rst)) | |||
raise AssertionError( | |||
"{} not equal: " | |||
"shape={} nonequal_idx={} v0={} v1={} err={}".format( | |||
@@ -79,8 +123,8 @@ def main(): | |||
files1 = sorted(files1) | |||
for i, j in zip(files0, files1): | |||
val0, name0 = plugin.load_tensor_binary(i) | |||
val1, name1 = plugin.load_tensor_binary(j) | |||
val0, name0 = load_tensor_binary(i) | |||
val1, name1 = load_tensor_binary(j) | |||
name = "{}: \n{}\n{}\n".format( | |||
i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1)) | |||
) |
@@ -0,0 +1,176 @@ | |||
#! /usr/bin/env python3 | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import argparse | |||
import numpy as np | |||
from megengine.core.tensor.dtype import is_quantize | |||
from megengine.logger import get_logger | |||
from megengine.utils.module_stats import ( | |||
print_flops_stats, | |||
print_params_stats, | |||
sizeof_fmt, | |||
) | |||
from megengine.utils.network import Network | |||
logger = get_logger(__name__) | |||
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.net_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    :return: tuple ``(total_params, total_flops)``; a total is 0 when the
        matching ``log_*`` switch is off. Returns ``None`` (after logging an
        error) when TensorBoard / TensorboardX cannot be imported.
    """
    try:
        from tensorboard.compat.proto.attr_value_pb2 import AttrValue
        from tensorboard.compat.proto.config_pb2 import RunMetadata
        from tensorboard.compat.proto.graph_pb2 import GraphDef
        from tensorboard.compat.proto.node_def_pb2 import NodeDef
        from tensorboard.compat.proto.step_stats_pb2 import (
            AllocatorMemoryUsed,
            DeviceStepStats,
            NodeExecStats,
            StepStats,
        )
        from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
        from tensorboard.compat.proto.versions_pb2 import VersionDef
        from tensorboardX import SummaryWriter
    except ImportError:
        logger.error(
            "TensorBoard and TensorboardX are required for visualize.", exc_info=True
        )
        return

    graph = Network.load(model_path)
    writer = SummaryWriter(log_path)

    def process_name(name):
        # Dots become slashes in the node name, which is then UTF-8 encoded.
        return name.replace(".", "/").encode(encoding="utf-8")

    # The writer is closed in ``finally`` so its file handle is released and
    # pending events are flushed even if graph processing fails midway.
    try:
        node_list = []
        flops_list = []
        params_list = []
        for node in graph.all_oprs:
            # Pick the output that represents this op in the graph view.
            if hasattr(node, "output_idx"):
                node_oup = node.outputs[node.output_idx]
            else:
                if len(node.outputs) != 1:
                    logger.warning(
                        "OpNode {} has more than one output and not has 'output_idx' attr.".format(
                            node
                        )
                    )
                node_oup = node.outputs[0]

            inp_list = [process_name(var.owner.name) for var in node.inputs]
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d)
                                    for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
            }

            if hasattr(node, "calc_flops"):
                flops_num = node.calc_flops()
                # add op flops attr
                attr["flops"] = AttrValue(
                    s=sizeof_fmt(flops_num).encode(encoding="utf-8")
                )
                flops_list.append(
                    dict(
                        name=node.name,
                        class_name=node.type,
                        input_shapes=[i.shape for i in node.inputs],
                        output_shapes=[o.shape for o in node.outputs],
                        flops_num=flops_num,
                        flops_cum=0,
                    )
                )

            if node.type == "ImmutableTensor":
                param_dim = np.prod(node_oup.shape)
                # TODO: consider other quantize dtypes
                param_bytes = 1 if is_quantize(node_oup.dtype) else 4
                # add tensor size attr
                attr["size"] = AttrValue(
                    s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
                )
                # Fetch the value once and reuse it for both statistics.
                value = node.numpy()
                params_list.append(
                    dict(
                        name=node.name,
                        shape=node_oup.shape,
                        param_dim=param_dim,
                        bits=param_bytes * 8,
                        size=param_dim * param_bytes,
                        size_cum=0,
                        mean="{:.2g}".format(value.mean()),
                        std="{:.2g}".format(value.std()),
                    )
                )

            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )

        total_flops, total_params = 0, 0
        if log_params:
            total_params = print_params_stats(params_list, bar_length_max)
        if log_flops:
            total_flops = print_flops_stats(flops_list, bar_length_max)

        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer._get_file_writer().add_graph((graph_def, stepstats))
    finally:
        writer.close()

    return total_params, total_flops
def main():
    """Parse the command line and forward every option to :func:`visualize`."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="load a megengine dumped model and export log file for tensorboard visualization.",
    )

    # Positional arguments.
    cli.add_argument("model_path", help="dumped model path.")
    cli.add_argument("log_path", help="tensorboard log path.")

    # Optional tuning of the printed statistics tables.
    cli.add_argument(
        "--bar_length_max",
        default=20,
        type=int,
        help="size of bar indicating max flops or parameter size in net stats.",
    )

    # Both switches are plain on/off flags that are off unless given.
    switches = (
        ("--log_params", "whether print and record params size."),
        ("--log_flops", "whether print and record op flops."),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    options = vars(cli.parse_args())
    visualize(**options)
if __name__ == "__main__": | |||
main() |
@@ -1,4 +1,4 @@ | |||
# -*- coding: utf-8 -*- | |||
#! /usr/bin/env python3 | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. |
@@ -84,26 +84,125 @@ hook_modules = ( | |||
) | |||
def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True): | |||
def dict2table(list_of_dict, header): | |||
table_data = [header] | |||
for d in list_of_dict: | |||
row = [] | |||
for h in header: | |||
v = "" | |||
if h in d: | |||
v = d[h] | |||
row.append(v) | |||
table_data.append(row) | |||
return table_data | |||
def sizeof_fmt(num, suffix="B"): | |||
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: | |||
if abs(num) < 1024.0: | |||
return "{:3.3f} {}{}".format(num, unit, suffix) | |||
num /= 1024.0 | |||
sign_str = "-" if num < 0 else "" | |||
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) | |||
def dict2table(list_of_dict, header):
    """
    Convert a list of dicts into a list of rows suitable for tabulation.

    :param list_of_dict: iterable of mappings, one per table row.
    :param header: list of column keys; it becomes the first row as-is.
    :return: list whose first item is ``header`` and whose remaining items
        hold each mapping's values in header order, with ``""`` for any key
        the mapping does not contain.
    """
    rows = [header]
    rows.extend(
        [entry[col] if col in entry else "" for col in header]
        for entry in list_of_dict
    )
    return rows
def sizeof_fmt(num, suffix="B"):
    """
    Format a number with binary (power-of-1024) unit prefixes.

    :param num: value to format.
    :param suffix: unit name appended after the prefix, e.g. ``"B"`` or ``"OPs"``.
    :return: string such as ``"1.500 KiB"``; values of magnitude 1024**8 or
        more are expressed in ``Yi`` units with one decimal place.
    """
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return "{:3.3f} {}{}".format(num, unit, suffix)
        num /= 1024.0
    # ``num`` still carries its own sign here, so no extra "-" is prepended
    # (doing so would render negative values as "--1.0 YiB").
    return "{:.1f} {}{}".format(num, "Yi", suffix)
def print_flops_stats(flops, bar_length_max=20):
    """
    Log a table of per-op flops statistics and return the total.

    Each dict in ``flops`` is updated in place with the keys ``flops_cum``,
    ``flops``, ``ratio``, ``percentage`` and ``bar``, and a final ``"total"``
    row is appended to the list.

    :param flops: list of dicts; each must provide ``flops_num`` and
        ``output_shapes``.
    :param bar_length_max: bar length given to the op with the largest flops.
    :return: sum of all ``flops_num`` values (each converted with ``int``).
    """
    flops_list = [i["flops_num"] for i in flops]
    max_flops_num = max(flops_list + [0])
    # calc total flops and set flops_cum
    total_flops_num = 0
    for d in flops:
        total_flops_num += int(d["flops_num"])
        d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

    for i in flops:
        f = i["flops_num"]
        i["flops"] = sizeof_fmt(f, suffix="OPs")
        # Guard the divisions: when every op reports 0 flops both the total
        # and the maximum are 0.
        r = i["ratio"] = f / total_flops_num if total_flops_num else 0
        i["percentage"] = "{:.2f}%".format(r * 100)
        bar_length = int(f / max_flops_num * bar_length_max) if max_flops_num else 0
        i["bar"] = "#" * bar_length

    header = [
        "name",
        "class_name",
        "input_shapes",
        "output_shapes",
        "flops",
        "flops_cum",
        "percentage",
        "bar",
    ]

    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    # NOTE(review): this sums element [1] of every output shape, which needs
    # each shape to have at least two entries -- confirm the intended meaning.
    total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )

    logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))

    return total_flops_num
def print_params_stats(params, bar_length_max=20):
    """
    Log a table of per-parameter statistics and return the total size.

    Each dict in ``params`` is updated in place: ``size`` is replaced by its
    formatted string, and ``size_cum``, ``ratio``, ``percentage`` and
    ``size_bar`` are added. A final ``"total"`` row is appended to the list.

    :param params: list of dicts; each must provide ``param_dim`` and ``size``.
    :param bar_length_max: bar length given to the largest parameter.
    :return: sum of all ``size`` values (each converted with ``int``).
    """
    total_param_dims, total_param_size = 0, 0
    for d in params:
        total_param_dims += int(d["param_dim"])
        total_param_size += int(d["size"])
        d["size"] = sizeof_fmt(d["size"])
        d["size_cum"] = sizeof_fmt(total_param_size)

    for d in params:
        # Guard against a zero total (e.g. only zero-sized parameters).
        ratio = d["param_dim"] / total_param_dims if total_param_dims else 0
        d["ratio"] = ratio
        d["percentage"] = "{:.2f}%".format(ratio * 100)

    # construct bar
    # ``+ [0]`` keeps ``max`` from raising on an empty ``params`` list.
    max_ratio = max([d["ratio"] for d in params] + [0])
    for d in params:
        bar_length = int(d["ratio"] / max_ratio * bar_length_max) if max_ratio else 0
        d["size_bar"] = "#" * bar_length

    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

    header = [
        "name",
        "shape",
        "mean",
        "std",
        "param_dim",
        "bits",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    ]

    logger.info(
        "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
    )

    return total_param_size
def net_stats( | |||
model: m.Module, | |||
input_size: int, | |||
bar_length_max: int = 20, | |||
log_params: bool = True, | |||
log_flops: bool = True, | |||
): | |||
r""" | |||
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | |||
:param model: model that need to get stats info. | |||
:param input_size: size of input for running model and calculating stats. | |||
:param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||
:param log_params: whether print and record params size. | |||
:param log_flops: whether print and record op flops. | |||
""" | |||
def get_byteswidth(tensor): | |||
if dtype.is_quantize(tensor.dtype): | |||
@@ -113,87 +212,6 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T | |||
else: | |||
return 4 | |||
def print_flops_stats(flops): | |||
flops_list = [i["flops_num"] for i in flops] | |||
max_flops_num = max(flops_list + [0]) | |||
# calc total flops and set flops_cum | |||
total_flops_num = 0 | |||
for d in flops: | |||
total_flops_num += int(d["flops_num"]) | |||
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") | |||
for i in flops: | |||
f = i["flops_num"] | |||
i["flops"] = sizeof_fmt(f, suffix="OPs") | |||
r = i["ratio"] = f / total_flops_num | |||
i["percentage"] = "{:.2f}%".format(r * 100) | |||
bar_length = int(f / max_flops_num * bar_length_max) | |||
i["bar"] = "#" * bar_length | |||
header = [ | |||
"name", | |||
"class_name", | |||
"input_shapes", | |||
"output_shapes", | |||
"flops", | |||
"flops_cum", | |||
"percentage", | |||
"bar", | |||
] | |||
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | |||
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops) | |||
flops.append( | |||
dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | |||
) | |||
logger.info( | |||
"flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)) | |||
) | |||
return total_flops_num | |||
def print_params_stats(params): | |||
total_param_dims, total_param_size = 0, 0 | |||
for d in params: | |||
total_param_dims += int(d["param_dim"]) | |||
total_param_size += int(d["size"]) | |||
d["size"] = sizeof_fmt(d["size"]) | |||
d["size_cum"] = sizeof_fmt(total_param_size) | |||
for d in params: | |||
ratio = d["param_dim"] / total_param_dims | |||
d["ratio"] = ratio | |||
d["percentage"] = "{:.2f}%".format(ratio * 100) | |||
# construct bar | |||
max_ratio = max([d["ratio"] for d in params]) | |||
for d in params: | |||
bar_length = int(d["ratio"] / max_ratio * bar_length_max) | |||
d["size_bar"] = "#" * bar_length | |||
param_size = sizeof_fmt(total_param_size) | |||
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) | |||
header = [ | |||
"name", | |||
"shape", | |||
"mean", | |||
"std", | |||
"param_dim", | |||
"bits", | |||
"size", | |||
"size_cum", | |||
"percentage", | |||
"size_bar", | |||
] | |||
logger.info( | |||
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) | |||
) | |||
return total_param_size | |||
def net_stats_hook(module, input, output, name=""): | |||
class_name = str(module.__class__).split(".")[-1].split("'")[0] | |||
@@ -273,8 +291,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T | |||
total_flops, total_params = 0, 0 | |||
if log_params: | |||
total_params = print_params_stats(params) | |||
total_params = print_params_stats(params, bar_length_max) | |||
if log_flops: | |||
total_flops = print_flops_stats(flops) | |||
total_flops = print_flops_stats(flops, bar_length_max) | |||
return total_params, total_flops |
@@ -19,9 +19,9 @@ from ..core._imperative_rt import ComputingGraph | |||
from ..core.tensor import megbrain_graph as G | |||
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | |||
from .network_node import ( | |||
NetworkNode, | |||
Host2DeviceCopy, | |||
ImmutableTensor, | |||
NetworkNode, | |||
OpNode, | |||
VarNode, | |||
str_to_mge_class, | |||
@@ -606,9 +606,7 @@ class NodeFilterType(NodeFilter): | |||
_node_type = None | |||
def __init__(self, node_iter, node_type): | |||
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format( | |||
node_type | |||
) | |||
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(node_type) | |||
super().__init__(node_iter) | |||
self._node_type = node_type | |||
@@ -10,6 +10,8 @@ import json | |||
import sys | |||
from typing import Callable | |||
import numpy as np | |||
from ..core import _imperative_rt as rt | |||
from ..core._wrap import Device | |||
from ..core.ops import builtin | |||
@@ -52,7 +54,7 @@ class VarNode(NetworkNode): | |||
return self.var.dtype if self.var else None | |||
def set_owner_opr(self, owner_opr): | |||
self.owner_opr = owner_opr | |||
self.owner = owner_opr | |||
class OpNode(NetworkNode): | |||
@@ -223,6 +225,9 @@ class Elemwise(OpNode): | |||
type = "Elemwise" | |||
opdef = builtin.Elemwise | |||
def calc_flops(self):
    """Return the flops estimate: the element count of the first output."""
    out_shape = self.outputs[0].shape
    return np.prod(out_shape)
class Reduce(OpNode): | |||
type = "Reduce" | |||
@@ -250,11 +255,21 @@ class MatrixMul(OpNode): | |||
type = "MatrixMul" | |||
opdef = builtin.MatrixMul | |||
def calc_flops(self):
    """
    Return the flops estimate: output element count times the second
    dimension of the first input.

    Both the first input and the first output must be 2-D.
    NOTE(review): the reduction size is always read from ``inputs[0].shape[1]``;
    confirm this is still right if the first operand can be transposed.
    """
    lhs_shape = self.inputs[0].shape
    out_shape = self.outputs[0].shape
    assert len(lhs_shape) == 2 and len(out_shape) == 2
    return np.prod(out_shape) * lhs_shape[1]
class BatchedMatrixMul(OpNode): | |||
type = "BatchedMatmul" | |||
opdef = builtin.BatchedMatrixMul | |||
def calc_flops(self):
    """
    Return the flops estimate: output element count times the last
    dimension of the first input.

    Both the first input and the first output must be 3-D.
    NOTE(review): the reduction size is always read from ``inputs[0].shape[2]``;
    confirm this is still right if the first operand can be transposed.
    """
    lhs_shape = self.inputs[0].shape
    out_shape = self.outputs[0].shape
    assert len(lhs_shape) == 3 and len(out_shape) == 3
    return np.prod(out_shape) * lhs_shape[2]
class Dot(OpNode): | |||
type = "Dot" | |||
@@ -270,6 +285,18 @@ class ConvolutionForward(OpNode): | |||
type = "Convolution" | |||
opdef = builtin.Convolution | |||
def calc_flops(self):
    """
    Return the flops estimate for the convolution:
    ``prod(output shape) * (Cin * Kh * Kw)``.

    The kernel size is taken from the last two entries of the weight shape
    (``inputs[1]``). The input-channel count comes from index 2 of a 5-D
    weight and from index 1 otherwise (presumably the 5-D case is a grouped
    convolution layout -- confirm).
    """
    weight_shape = self.inputs[1].shape
    kernel_h, kernel_w = weight_shape[-2], weight_shape[-1]
    in_channels = weight_shape[2] if len(weight_shape) == 5 else weight_shape[1]
    out_elems = np.prod(self.outputs[0].shape)
    # N x Cout x H x W x (Cin x Kw x Kh)
    return out_elems * (in_channels * kernel_w * kernel_h)
class ConvolutionBackwardData(OpNode): | |||
type = "ConvTranspose" | |||
@@ -316,6 +343,18 @@ class ConvBiasForward(OpNode): | |||
obj.params["dtype"] = opr.outputs[0].dtype | |||
return obj | |||
def calc_flops(self):
    """
    Return the flops estimate for the convolution with bias:
    ``prod(output shape) * (Cin * Kh * Kw + 1)``.

    The kernel size is taken from the last two entries of the weight shape
    (``inputs[1]``). The input-channel count comes from index 2 of a 5-D
    weight and from index 1 otherwise (presumably the 5-D case is a grouped
    convolution layout -- confirm).
    """
    weight_shape = self.inputs[1].shape
    kernel_h, kernel_w = weight_shape[-2], weight_shape[-1]
    in_channels = weight_shape[2] if len(weight_shape) == 5 else weight_shape[1]
    out_elems = np.prod(self.outputs[0].shape)
    # N x Cout x H x W x (Cin x Kw x Kh + bias)
    per_output = in_channels * kernel_w * kernel_h + 1
    return out_elems * per_output
class BatchConvBiasForward(OpNode): | |||
type = "BatchConvBias" | |||
@@ -331,6 +370,7 @@ class BatchConvBiasForward(OpNode): | |||
class BatchNormForward(OpNode): | |||
type = "BatchNorm" | |||
opdef = builtin.BatchNorm | |||
output_idx = -1 | |||
class ROIAlignForward(OpNode): | |||
@@ -622,6 +662,9 @@ class ElemwiseMultiType(OpNode): | |||
obj.params["dtype"] = opr.outputs[0].dtype | |||
return obj | |||
def calc_flops(self):
    """Return the flops estimate: one op per element of the first output."""
    first_output = self.outputs[0]
    return np.prod(first_output.shape)
class CvtColorForward(OpNode): | |||
type = "CvtColor" | |||
@@ -1,57 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import struct | |||
import numpy as np | |||
def load_tensor_binary(fobj):
    """
    Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
    tensor value dump is implemented by ``mgb::debug::dump_tensor``.

    Multiple values can be compared by ``tools/compare_binary_iodump.py``.

    The data is read in this order: a header of three unsigned ints
    ``(name_len, dtype_enum, max_ndim)`` in native byte order, ``max_ndim``
    unsigned ints holding the shape (trailing zero entries are dropped),
    ``name_len`` bytes of ASCII name, and finally the raw tensor values up to
    the end of the file.

    :param fobj: file object opened in binary mode, or a string that contains
        the file name.
    :return: tuple ``(tensor_value, tensor_name)``.
    :raises AssertionError: if the dumped dtype is the unsupported Byte dtype.
    :raises KeyError: if the dtype enum value is not known to this loader.
    """
    if isinstance(fobj, str):
        with open(fobj, "rb") as fin:
            return load_tensor_binary(fin)

    DTYPE_LIST = {
        0: np.float32,
        1: np.uint8,
        2: np.int8,
        3: np.int16,
        4: np.int32,
        # 5: _mgb.intb1,
        # 6: _mgb.intb2,
        # 7: _mgb.intb4,
        8: None,
        9: np.float16,
        # quantized dtype start from 100000
        # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
        # dnn/include/megdnn/dtype.h
        100000: np.uint8,
        100001: np.int32,
        100002: np.int8,
    }

    header_fmt = struct.Struct("III")
    name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))

    # Fail with an explicit message instead of a bare ``KeyError: 5``.
    if dtype not in DTYPE_LIST:
        raise KeyError(
            "Cannot load this tensor: unknown dtype enum {}.".format(dtype)
        )
    np_dtype = DTYPE_LIST[dtype]
    # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
    if np_dtype is None:
        raise AssertionError("Cannot load this tensor: dtype Byte is unsupported.")

    shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
    # Drop unused trailing dimensions; the emptiness guard avoids an
    # IndexError when every stored dimension is zero.
    while shape and shape[-1] == 0:
        shape.pop()

    name = fobj.read(name_len).decode("ascii")
    return np.fromfile(fobj, dtype=np_dtype).reshape(shape), name