Browse Source

feat(mge/experimental): add visualization and net stats for python graph

GitOrigin-RevId: a1ab77c20a
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
53075cd3da
9 changed files with 402 additions and 172 deletions
  1. +8
    -0
      imperative/python/megengine/tools/README.md
  2. +0
    -0
      imperative/python/megengine/tools/__init__.py
  3. +50
    -6
      imperative/python/megengine/tools/compare_binary_iodump.py
  4. +176
    -0
      imperative/python/megengine/tools/network_visualize.py
  5. +1
    -1
      imperative/python/megengine/tools/profile_analyze.py
  6. +121
    -103
      imperative/python/megengine/utils/module_stats.py
  7. +2
    -4
      imperative/python/megengine/utils/network.py
  8. +44
    -1
      imperative/python/megengine/utils/network_node.py
  9. +0
    -57
      imperative/python/megengine/utils/plugin.py

+ 8
- 0
imperative/python/megengine/tools/README.md View File

@@ -0,0 +1,8 @@
# MegEngine Tools

This directory contains executable python files.
Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`):

```
python -m megengine.tools.xxx
```

+ 0
- 0
imperative/python/megengine/tools/__init__.py View File


imperative/python/megengine/utils/compare_binary_iodump.py → imperative/python/megengine/tools/compare_binary_iodump.py View File

@@ -1,3 +1,4 @@
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -7,12 +8,55 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse
import os import os
import struct
import textwrap import textwrap
from pathlib import Path from pathlib import Path


import numpy as np import numpy as np


from megengine.utils import plugin

def load_tensor_binary(fobj):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.

:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if isinstance(fobj, str):
with open(fobj, "rb") as fin:
return load_tensor_binary(fin)

DTYPE_LIST = {
0: np.float32,
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8: None,
9: np.float16,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000: np.uint8,
100001: np.int32,
100002: np.int8,
}

header_fmt = struct.Struct("III")
name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
assert (
DTYPE_LIST[dtype] is not None
), "Cannot load this tensor: dtype Byte is unsupported."

shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
while shape[-1] == 0:
shape.pop(-1)
name = fobj.read(name_len).decode("ascii")
return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name




def check(v0, v1, name, max_err): def check(v0, v1, name, max_err):
@@ -26,9 +70,9 @@ def check(v0, v1, name, max_err):
) )
vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0) vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0)
err = np.abs(v0 - v1) / vdiv err = np.abs(v0 - v1) / vdiv
check = err > max_err
if check.sum():
idx = tuple(i[0] for i in np.nonzero(check))
rst = err > max_err
if rst.sum():
idx = tuple(i[0] for i in np.nonzero(rst))
raise AssertionError( raise AssertionError(
"{} not equal: " "{} not equal: "
"shape={} nonequal_idx={} v0={} v1={} err={}".format( "shape={} nonequal_idx={} v0={} v1={} err={}".format(
@@ -79,8 +123,8 @@ def main():
files1 = sorted(files1) files1 = sorted(files1)


for i, j in zip(files0, files1): for i, j in zip(files0, files1):
val0, name0 = plugin.load_tensor_binary(i)
val1, name1 = plugin.load_tensor_binary(j)
val0, name0 = load_tensor_binary(i)
val1, name1 = load_tensor_binary(j)
name = "{}: \n{}\n{}\n".format( name = "{}: \n{}\n{}\n".format(
i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1)) i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1))
) )

+ 176
- 0
imperative/python/megengine/tools/network_visualize.py View File

@@ -0,0 +1,176 @@
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse

import numpy as np

from megengine.core.tensor.dtype import is_quantize
from megengine.logger import get_logger
from megengine.utils.module_stats import (
print_flops_stats,
print_params_stats,
sizeof_fmt,
)
from megengine.utils.network import Network

logger = get_logger(__name__)


def visualize(
model_path: str,
log_path: str,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
):
r"""
Load megengine dumped model and visualize graph structure with tensorboard log files.
Can also record and print model's statistics like :func:`~.net_stats`

:param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
try:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import (
AllocatorMemoryUsed,
DeviceStepStats,
NodeExecStats,
StepStats,
)
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboardX import SummaryWriter
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True
)
return

graph = Network.load(model_path)
writer = SummaryWriter(log_path)

def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")

node_list = []
flops_list = []
params_list = []
for node in graph.all_oprs:
if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx]
else:
if len(node.outputs) != 1:
logger.warning(
"OpNode {} has more than one output and not has 'output_idx' attr.".format(
node
)
)
node_oup = node.outputs[0]

inp_list = [process_name(var.owner.name) for var in node.inputs]
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape]
)
]
)
),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
# add op flops attr
attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8"))
flops_list.append(
dict(
name=node.name,
class_name=node.type,
input_shapes=[i.shape for i in node.inputs],
output_shapes=[o.shape for o in node.outputs],
flops_num=flops_num,
flops_cum=0,
)
)
if node.type == "ImmutableTensor":
param_dim = np.prod(node_oup.shape)
# TODO: consider other quantize dtypes
param_bytes = 1 if is_quantize(node_oup.dtype) else 4
# add tensor size attr
attr["size"] = AttrValue(
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
)
params_list.append(
dict(
name=node.name,
shape=node_oup.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(node.numpy().mean()),
std="{:.2g}".format(node.numpy().std()),
)
)
node_list.append(
NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
)
)

total_flops, total_params = 0, 0
if log_params:
total_params = print_params_stats(params_list, bar_length_max)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)

graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

device = "/device:CPU:0"
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer._get_file_writer().add_graph((graph_def, stepstats))
return total_params, total_flops


def main():
parser = argparse.ArgumentParser(
description="load a megengine dumped model and export log file for tensorboard visualization.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("model_path", help="dumped model path.")
parser.add_argument("log_path", help="tensorboard log path.")
parser.add_argument(
"--bar_length_max",
type=int,
default=20,
help="size of bar indicating max flops or parameter size in net stats.",
)
parser.add_argument(
"--log_params",
action="store_true",
help="whether print and record params size.",
)
parser.add_argument(
"--log_flops", action="store_true", help="whether print and record op flops.",
)
visualize(**vars(parser.parse_args()))


if __name__ == "__main__":
main()

imperative/python/megengine/utils/profile_analyze.py → imperative/python/megengine/tools/profile_analyze.py View File

@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

imperative/python/megengine/utils/net_stats.py → imperative/python/megengine/utils/module_stats.py View File

@@ -84,26 +84,125 @@ hook_modules = (
) )




def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True):
def dict2table(list_of_dict, header):
table_data = [header]
for d in list_of_dict:
row = []
for h in header:
v = ""
if h in d:
v = d[h]
row.append(v)
table_data.append(row)
return table_data

def sizeof_fmt(num, suffix="B"):
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
if abs(num) < 1024.0:
return "{:3.3f} {}{}".format(num, unit, suffix)
num /= 1024.0
sign_str = "-" if num < 0 else ""
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
def dict2table(list_of_dict, header):
table_data = [header]
for d in list_of_dict:
row = []
for h in header:
v = ""
if h in d:
v = d[h]
row.append(v)
table_data.append(row)
return table_data


def sizeof_fmt(num, suffix="B"):
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
if abs(num) < 1024.0:
return "{:3.3f} {}{}".format(num, unit, suffix)
num /= 1024.0
sign_str = "-" if num < 0 else ""
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)


def print_flops_stats(flops, bar_length_max=20):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
total_flops_num = 0
for d in flops:
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

for i in flops:
f = i["flops_num"]
i["flops"] = sizeof_fmt(f, suffix="OPs")
r = i["ratio"] = f / total_flops_num
i["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
i["bar"] = "#" * bar_length

header = [
"name",
"class_name",
"input_shapes",
"output_shapes",
"flops",
"flops_cum",
"percentage",
"bar",
]

total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)

logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))

return total_flops_num


def print_params_stats(params, bar_length_max=20):
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)

for d in params:
ratio = d["param_dim"] / total_param_dims
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)

# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
d["size_bar"] = "#" * bar_length

param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

header = [
"name",
"shape",
"mean",
"std",
"param_dim",
"bits",
"size",
"size_cum",
"percentage",
"size_bar",
]

logger.info(
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
)

return total_param_size


def net_stats(
model: m.Module,
input_size: int,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
):
r"""
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.

:param model: model that need to get stats info.
:param input_size: size of input for running model and calculating stats.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""


def get_byteswidth(tensor): def get_byteswidth(tensor):
if dtype.is_quantize(tensor.dtype): if dtype.is_quantize(tensor.dtype):
@@ -113,87 +212,6 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
else: else:
return 4 return 4


def print_flops_stats(flops):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
total_flops_num = 0
for d in flops:
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

for i in flops:
f = i["flops_num"]
i["flops"] = sizeof_fmt(f, suffix="OPs")
r = i["ratio"] = f / total_flops_num
i["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
i["bar"] = "#" * bar_length

header = [
"name",
"class_name",
"input_shapes",
"output_shapes",
"flops",
"flops_cum",
"percentage",
"bar",
]

total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)

logger.info(
"flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))
)

return total_flops_num

def print_params_stats(params):
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)

for d in params:
ratio = d["param_dim"] / total_param_dims
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)

# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
d["size_bar"] = "#" * bar_length

param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

header = [
"name",
"shape",
"mean",
"std",
"param_dim",
"bits",
"size",
"size_cum",
"percentage",
"size_bar",
]

logger.info(
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
)

return total_param_size

def net_stats_hook(module, input, output, name=""): def net_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0] class_name = str(module.__class__).split(".")[-1].split("'")[0]


@@ -273,8 +291,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T


total_flops, total_params = 0, 0 total_flops, total_params = 0, 0
if log_params: if log_params:
total_params = print_params_stats(params)
total_params = print_params_stats(params, bar_length_max)
if log_flops: if log_flops:
total_flops = print_flops_stats(flops)
total_flops = print_flops_stats(flops, bar_length_max)


return total_params, total_flops return total_params, total_flops

+ 2
- 4
imperative/python/megengine/utils/network.py View File

@@ -19,9 +19,9 @@ from ..core._imperative_rt import ComputingGraph
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import ( from .network_node import (
NetworkNode,
Host2DeviceCopy, Host2DeviceCopy,
ImmutableTensor, ImmutableTensor,
NetworkNode,
OpNode, OpNode,
VarNode, VarNode,
str_to_mge_class, str_to_mge_class,
@@ -606,9 +606,7 @@ class NodeFilterType(NodeFilter):
_node_type = None _node_type = None


def __init__(self, node_iter, node_type): def __init__(self, node_iter, node_type):
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(
node_type
)
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(node_type)
super().__init__(node_iter) super().__init__(node_iter)
self._node_type = node_type self._node_type = node_type




+ 44
- 1
imperative/python/megengine/utils/network_node.py View File

@@ -10,6 +10,8 @@ import json
import sys import sys
from typing import Callable from typing import Callable


import numpy as np

from ..core import _imperative_rt as rt from ..core import _imperative_rt as rt
from ..core._wrap import Device from ..core._wrap import Device
from ..core.ops import builtin from ..core.ops import builtin
@@ -52,7 +54,7 @@ class VarNode(NetworkNode):
return self.var.dtype if self.var else None return self.var.dtype if self.var else None


def set_owner_opr(self, owner_opr): def set_owner_opr(self, owner_opr):
self.owner_opr = owner_opr
self.owner = owner_opr




class OpNode(NetworkNode): class OpNode(NetworkNode):
@@ -223,6 +225,9 @@ class Elemwise(OpNode):
type = "Elemwise" type = "Elemwise"
opdef = builtin.Elemwise opdef = builtin.Elemwise


def calc_flops(self):
return np.prod(self.outputs[0].shape)



class Reduce(OpNode): class Reduce(OpNode):
type = "Reduce" type = "Reduce"
@@ -250,11 +255,21 @@ class MatrixMul(OpNode):
type = "MatrixMul" type = "MatrixMul"
opdef = builtin.MatrixMul opdef = builtin.MatrixMul


def calc_flops(self):
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
mid_shape = self.inputs[0].shape[1]
return np.prod(self.outputs[0].shape) * mid_shape



class BatchedMatrixMul(OpNode): class BatchedMatrixMul(OpNode):
type = "BatchedMatmul" type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul opdef = builtin.BatchedMatrixMul


def calc_flops(self):
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
mid_shape = self.inputs[0].shape[2]
return np.prod(self.outputs[0].shape) * mid_shape



class Dot(OpNode): class Dot(OpNode):
type = "Dot" type = "Dot"
@@ -270,6 +285,18 @@ class ConvolutionForward(OpNode):
type = "Convolution" type = "Convolution"
opdef = builtin.Convolution opdef = builtin.Convolution


def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh)



class ConvolutionBackwardData(OpNode): class ConvolutionBackwardData(OpNode):
type = "ConvTranspose" type = "ConvTranspose"
@@ -316,6 +343,18 @@ class ConvBiasForward(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype obj.params["dtype"] = opr.outputs[0].dtype
return obj return obj


def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + 1)



class BatchConvBiasForward(OpNode): class BatchConvBiasForward(OpNode):
type = "BatchConvBias" type = "BatchConvBias"
@@ -331,6 +370,7 @@ class BatchConvBiasForward(OpNode):
class BatchNormForward(OpNode): class BatchNormForward(OpNode):
type = "BatchNorm" type = "BatchNorm"
opdef = builtin.BatchNorm opdef = builtin.BatchNorm
output_idx = -1




class ROIAlignForward(OpNode): class ROIAlignForward(OpNode):
@@ -622,6 +662,9 @@ class ElemwiseMultiType(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype obj.params["dtype"] = opr.outputs[0].dtype
return obj return obj


def calc_flops(self):
return np.prod(self.outputs[0].shape)



class CvtColorForward(OpNode): class CvtColorForward(OpNode):
type = "CvtColor" type = "CvtColor"


+ 0
- 57
imperative/python/megengine/utils/plugin.py View File

@@ -1,57 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import struct

import numpy as np


def load_tensor_binary(fobj):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.

Multiple values can be compared by ``tools/compare_binary_iodump.py``.

:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if isinstance(fobj, str):
with open(fobj, "rb") as fin:
return load_tensor_binary(fin)

DTYPE_LIST = {
0: np.float32,
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8: None,
9: np.float16,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000: np.uint8,
100001: np.int32,
100002: np.int8,
}

header_fmt = struct.Struct("III")
name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
assert (
DTYPE_LIST[dtype] is not None
), "Cannot load this tensor: dtype Byte is unsupported."

shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
while shape[-1] == 0:
shape.pop(-1)
name = fobj.read(name_len).decode("ascii")
return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name

Loading…
Cancel
Save