feat(imperative): add more tools for megenginerevert-411-add-tools
@@ -1,8 +1,156 @@ | |||
# MegEngine Tools | |||
This directory contains executable python files. | |||
Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`): | |||
MegEngine 相关的工具汇总。使用方法如下(可将 `xxx` 替换成任一脚本文件,如 `network_visualize`): | |||
``` | |||
```bash | |||
python -m megengine.tools.xxx | |||
``` | |||
``` | |||
工具列表: | |||
### accuracy_shake_var_tree | |||
将精度抖动分析结果构造成树结构,方便锁定引起抖动的根节点,以及查找依赖关系。 | |||
输入: compare_binary_iodump 的输出存入到的一个文件 | |||
输出: 第一个出现结果不一致的输出结点 | |||
执行命令: accuracy_shake_var_tree 中定义了一些函数组件,可按需集成到实际代码中。下面有一个测试代码: | |||
```python | |||
import megengine.tools.accuracy_shake_var_tree as st | |||
r = st.parse('diff.txt') | |||
for key, value in r.items(): | |||
n = st.varNode.get_varNode(key) | |||
n.show_src_info() | |||
print("reference nodes:") | |||
for i in n.get_reference_list(): | |||
print(i.id) | |||
``` | |||
### benchmark_op | |||
逐个运行 functional op(并不是所有的 functional op),对比 MegEngine 与 PyTorch 的性能,通过量化结果来指导如何进行下一步的优化。 | |||
输入: 无 | |||
输出: 打印一个列表,对比在小输入和大输入的情况下 MegEngine 和 Pytorch 执行一些 functional op 的速度对比 | |||
执行命令: `python3 -m megengine.tools.benchmark_op` | |||
### compare_binary_iodump | |||
分析同一模型在不同平台下给定相同输入之后的输出是否完全一致。 | |||
输入: 两个目录(假设分别为 expect/ 和 actual/),分别存有不同平台下运行的 tensor 结果 | |||
输出: 打印所有的输出 tensor 信息,如果某个 tensor 在两个平台上的值不一致,那么会打印出第一个不一致的值 | |||
执行命令: `python3 -m megengine.tools.compare_binary_iodump expect/ actual/` | |||
### draw_graph | |||
用来查看静态图的 op 序列,有助于理解 MegEngine 的静态图在动态图的基础上做了哪些优化。 | |||
输入: `megengine.core.tensor.megbrain_graph.Graph._to_json` 得出的静态图描述文件,为 json 格式 | |||
输出: 一个 dot 文件,可通过 dot 命令绘制出图片 | |||
执行命令: | |||
```bash | |||
python3 -m megengine.tools.draw_graph -i dump.json -o dump.dot | |||
dot -Tpng dump.dot -o dump.png | |||
``` | |||
### dump_with_testcase_mge | |||
将待测数据提前注入模型文件,并在本地运行得到期望结果,可与实际运行的结果进行比对以检查是否出错。 | |||
输入: 一个 MegEngine 模型文件,可选一些 npy 文件作为模型输入(也可以随机生成输入,如下面的命令示例) | |||
输出: 一个带输入的 MegEngine 模型文件 | |||
执行命令: `python3 -m megengine.tools.dump_with_testcase_mge model.mge -d "#rand(0,255,14,2)"` | |||
### graph_info_analyze | |||
将图和内存信息的 json 文件的文件夹 logs 转换为 TensorBoard 的输入文件夹 logs_p。以便 TensorBoard 对图结构以及内存信息进行可视化。 | |||
输入: 图和内存信息的 json 文件的文件夹 | |||
输出: TensorBoard 的输入文件夹 | |||
执行命令: `python3 -m megengine.tools.graph_info_analyze -i logs -o logs_p` | |||
### load_network_and_run | |||
python 版本的 load_and_run。 | |||
输入: MegEngine 的模型文件,可选一些 npy 文件作为模型输入 | |||
输出: 模型执行并打印一些测速信息 | |||
执行命令: `python3 -m megengine.tools.load_network_and_run model.mge --iter 10` | |||
### network_visualize | |||
1. 分析给定的 MegEngine 模型中参数量信息,包括 shape、dtype、mean、std 以及 size 占比等。 | |||
2. 分析给定的 MegEngine 模型中算子 FLOPs 计算量以及占比,还有算子的 inputs\outputs shape、感受野、stride 等。 | |||
输入: MegEngine 的模型文件 | |||
输出: 模型中的参数量信息或计算量信息 | |||
执行命令: | |||
```bash | |||
# 分析参数量 | |||
python3 -m megengine.tools.network_visualize model.mge --cal_params --logging_to_stdout | |||
# 分析计算量 | |||
python3 -m megengine.tools.network_visualize model.mge --cal_flops --logging_to_stdout | |||
``` | |||
### profile_analyze | |||
对于 load_and_run --profile 运行模型生成的 profile.json 文件或者 trace 模式下开启 profiling 功能并通过 trace.get_profile() 得到的 json 文件进行分析,得到静态图中算子的时间和显存占比等信息,以表格形式呈现。 | |||
输入: load_and_run 生成的 profile 文件 | |||
输出: 一个按照参数在输入文件中筛选得出的数据表格 | |||
执行命令: | |||
```bash | |||
# 生成供分析的 json 文件 | |||
python3 -m megengine.tools.load_network_and_run model.mge --warm-up --iter 10 --profile profile.json | |||
#分析耗时前 3 的单个算子 | |||
python3 -m megengine.tools.profile_analyze profile.json -t 3 | |||
#筛选用时超过 10us 的 conv 按 flops 排序 | |||
python3 -m megengine.tools.profile_analyze profile.json -t 3 --order-by +flops --min-time 1e-5 --type ConvolutionForward | |||
``` | |||
### profiler | |||
对给定的训练程序,记录训练过程并以通用格式存储,可在浏览器上可视化。 | |||
输入: 需要一个 MegEngine 的训练程序(称之为 train.py,其中包含一个典型的 MegEngine 训练过程) | |||
输出: 一些记录 profile 过程的 json 文件,默认在 profile 子目录下,可用 https://ui.perfetto.dev/ 进行加载并且可视化 | |||
执行命令: `python3 -m megengine.tools.profiler train.py` | |||
### svg_viewer | |||
查看 MegEngine 生成的显存占用图,可以帮助用户了解显存使用情况. | |||
输入: 显存占用的 svg 图片 | |||
输出: 网页展示的可视化 | |||
执行命令: `python3 -m megengine.tools.svg_viewer` |
@@ -0,0 +1,151 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import time | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as MM | |||
import megengine.functional as MF | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as TF | |||
from tabulate import tabulate | |||
module_cache = { | |||
"conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()), | |||
"dw_conv2d": (MM.Conv2d(32, 32, 3, 1, 0, groups=32), nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda()), | |||
"conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()), | |||
"ConvTranspose2d": (MM.ConvTranspose2d(32, 32, 3, 1, 0), nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda()), | |||
"BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()), | |||
"Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()), | |||
} | |||
test_cases = [ | |||
# (mge op, torch op, small inps, large inps, unpack_inps, rep) | |||
("adaptive_avg_pool2d", lambda x: MF.adaptive_avg_pool2d(x, (7,7)), lambda x: TF.adaptive_avg_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
("adaptive_max_pool2d", lambda x: MF.adaptive_max_pool2d(x, (7,7)), lambda x: TF.adaptive_max_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000),], True, 1000), | |||
("avg_pool2d", lambda x: MF.avg_pool2d(x, 2), lambda x: TF.avg_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
("broadcast", lambda x: MF.broadcast_to(x, (5,) + x.shape), lambda x: torch.broadcast_to(x, (5,)+x.shape), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("batchedmatmul", MF.matmul, torch.matmul, [(8, 64, 32), (8, 32, 64)], [(8, 2048, 512), (8, 512, 2048)], True, 1000), | |||
("batchnrom2d", lambda x: module_cache["BatchNorm2d"][0](x), lambda x: module_cache["BatchNorm2d"][1](x), [(2, 64, 16, 16)], [(64, 64, 128, 128)], True, 1000), | |||
("concat", MF.concat, torch.cat, [(20, 100), (50, 100), (30, 100)], [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)], False, 1000), | |||
("conv2d", lambda x: module_cache["conv2d"][0](x), lambda x: module_cache["conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||
("conv3d", lambda x: module_cache["conv3d"][0](x), lambda x: module_cache["conv3d"][1](x), [(2, 32, 8, 8, 8)], [(32, 32, 16, 16, 16)], True, 1000), | |||
("convTranspose2d", lambda x: module_cache["ConvTranspose2d"][0](x), lambda x: module_cache["ConvTranspose2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||
("dropout", lambda x: MF.dropout(x, 0.5), TF.dropout, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("dw_conv2d", lambda x: module_cache["dw_conv2d"][0](x), lambda x: module_cache["dw_conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||
("elemwise.unary", MF.log, torch.log, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("elemwise.binary", MF.add, torch.add, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], True, 1000), | |||
("expand_dims", lambda x: MF.expand_dims(x, 0), lambda x: torch.unsqueeze(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("gelu", MF.gelu, TF.gelu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("hswish", MF.hswish, TF.hardswish, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("hsigmoid", MF.hsigmoid, TF.hardsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("isinf", MF.isinf, torch.isinf, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("indeixngMultiAxisVec", lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], [(10,10,10,10)], [(64, 512, 16, 16)], True, 1000), | |||
("logsigmoid", MF.logsigmoid, TF.logsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("leaky_relu", lambda x: MF.leaky_relu(x, 0.5), lambda x: TF.leaky_relu(x, 0.5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("linear", lambda x: module_cache["Linear"][0](x), lambda x: module_cache["Linear"][1](x), [(10, 1000)], [(64, 128, 1000)], True, 1000), | |||
("matinv", MF.matinv, torch.inverse, [(10,10)], [(30, 30)], True, 1000), | |||
("matmul", MF.matmul, torch.matmul, [(64,32), (32, 64)], [(2048, 1024), (1024, 2048)], True, 1000), | |||
("max_pool2d", lambda x: MF.max_pool2d(x, 2), lambda x: TF.max_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
("normal", lambda x: mge.random.normal(0,1, x.shape), lambda x: torch.randn(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("prelu", MF.prelu, TF.prelu, [(100,100), (1,)], [(64, 512, 16, 16), (1,)], True, 1000), | |||
("reduce.max", lambda x: MF.max(x, 0), lambda x: torch.max(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("relu", MF.relu, TF.relu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("relu6", MF.relu6, TF.relu6, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("repeat", lambda x: MF.repeat(x, 5), lambda x: torch.repeat_interleave(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("silu", MF.silu, TF.silu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("split", lambda x: MF.split(x, 5), lambda x: torch.split(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("sigmoid", MF.sigmoid, TF.sigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("softmax", lambda x: MF.softmax(x, axis=1), lambda x: TF.softmax(x, dim=1), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("softplus", MF.softplus, TF.softplus, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("squeeze", lambda x: MF.squeeze(x, 0), lambda x: torch.squeeze(x, 0), [(1, 100,100)], [(1, 64, 512, 16, 16)], True, 1000), | |||
("stack", MF.stack, torch.stack, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], False, 10000), | |||
("subtensor", lambda x: x[0:20, 10:60], lambda x: x[0:20, 10:60], [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("topk", lambda x: MF.topk(x, 10), lambda x: torch.topk(x, 10), [(100,100)], [(1000, 1000)], True, 1000), | |||
("tile", lambda x: MF.tile(x, (2,)*len(x.shape)), lambda x: torch.tile(x, (2,)*len(x.shape)), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("transpose", lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]), lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("where", lambda x: MF.where(x > 0.5, x, x), lambda x: torch.where(x > 0.5, x, x), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
("uniform", lambda x: mge.random.uniform(0,1, x.shape), lambda x: torch.rand(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
] | |||
def perf_func(func, inps, reps, unpack_inps, is_mge): | |||
if is_mge: | |||
mge._full_sync() | |||
tik = time.time() | |||
for _ in range(reps): | |||
if unpack_inps: | |||
out = func(*inps) | |||
else: | |||
out = func(inps) | |||
mge._full_sync() | |||
else: | |||
torch.cuda.synchronize() | |||
with torch.no_grad(): | |||
tik = time.time() | |||
for _ in range(reps): | |||
if unpack_inps: | |||
out = func(*inps) | |||
else: | |||
out = func(inps) | |||
torch.cuda.synchronize() | |||
return time.time() - tik | |||
def get_avg_time(func, inps, reps, unpack_inps, is_mge): | |||
# warm up | |||
for _ in range(2): | |||
t = perf_func(func, inps, reps, unpack_inps, is_mge) | |||
times = [] | |||
for _ in range(5): | |||
t = perf_func(func, inps, reps, unpack_inps, is_mge) | |||
times.append(t) | |||
return np.mean(times) | |||
def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps): | |||
inps = [ | |||
np.random.randn(*shape) for shape in shapes | |||
] | |||
inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps] | |||
avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True) | |||
inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps] | |||
avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False) | |||
return avg_time_mge, avg_time_torch | |||
if __name__ == "__main__": | |||
header = ["opr_name", "time(mge/pytorch; small input)", "time(mge/pytorch; large input)"] | |||
table = [] | |||
for case in test_cases: | |||
assert len(case) == 7 | |||
name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case | |||
data = [] | |||
data.append(name) | |||
print("========== op: {}".format(name)) | |||
avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, small_shapes, unpack_inps, reps) | |||
print("mge time: {}".format(avg_time_mge)) | |||
print("torch time: {}".format(avg_time_torch)) | |||
data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) | |||
avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, large_shapes, unpack_inps, reps) | |||
print("mge time: {}".format(avg_time_mge)) | |||
print("torch time: {}".format(avg_time_torch)) | |||
data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) | |||
table.append(data) | |||
print(tabulate(table, header, tablefmt="github")) |
@@ -0,0 +1,535 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import argparse | |||
import os | |||
import re | |||
import struct | |||
import cv2 | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.core._imperative_rt as rt | |||
import megengine.core.tensor.megbrain_graph as G | |||
from megengine import tensor | |||
from megengine.core.ops import builtin | |||
from megengine.utils import comp_graph_tools as cgtools | |||
logger = mge.get_logger(__name__) | |||
def auto_reformat_image(args, path, data, dst_shape): | |||
"""reformat image to target shape | |||
:param data: image data as numpy array | |||
:param dst_shape: target shape | |||
""" | |||
dim3_format = False # required input format does not contain batch | |||
hwc_format = False # required input format is NHWC | |||
if not dst_shape: # input tensor shape is not predefined | |||
if len(data.shape) == 2: | |||
chl = 1 | |||
h = data.shape[0] | |||
w = data.shape[1] | |||
else: | |||
assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" | |||
h, w, chl = data.shape | |||
dst_shape = (1, chl, h, w) | |||
if len(dst_shape) == 3: | |||
dst_shape = (1,) + dst_shape | |||
dim3_format = True | |||
assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||
chl = dst_shape[1] | |||
if chl in [1, 3]: | |||
n, c, h, w = dst_shape | |||
dst_shape = (n, h, w, c) | |||
else: | |||
chl = dst_shape[3] | |||
assert chl in [1, 3], "can not infer input format from shape: {}".format( | |||
dst_shape | |||
) | |||
hwc_format = True | |||
# dst_shape has now been normalized to NHWC format | |||
if args.resize_input: | |||
h, w = dst_shape[1:3] | |||
data = cv2.resize(data, (w, h)) | |||
logger.info("input {} resized to {}".format(path, data.shape)) | |||
if chl == 1: | |||
data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||
data = data[:, :, np.newaxis] | |||
assert data.ndim == 3 | |||
data = data[np.newaxis] | |||
# data normalized to NHWC format | |||
if not hwc_format: | |||
data = np.transpose(data, (0, 3, 1, 2)) | |||
if dim3_format: | |||
data = np.squeeze(data, 0) | |||
return data | |||
def read_input_data(args, dst_shape, dtype, path, repeat): | |||
def check_shape_equal(dst_shape, data_shape): | |||
if len(dst_shape): | |||
assert len(data_shape) == len( | |||
dst_shape | |||
), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) | |||
if data_shape[1:] != dst_shape[1:]: | |||
logger.warning( | |||
"dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) | |||
) | |||
if path.startswith("#"): | |||
assert not args.resize_input | |||
assert not args.input_transform | |||
spec = path | |||
m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) | |||
assert m, "bad spec {}".format(spec) | |||
rng_min = float(m.group(1)) | |||
rng_max = float(m.group(2)) | |||
if m.group(3): | |||
shape_str = m.group(3) | |||
try: | |||
shape = shape_str[1:].split(",") | |||
if shape[-1].strip() == "...": | |||
shape = shape[:-1] | |||
shape.extend(list(dst_shape[len(shape) :])) | |||
data_shape = tuple(map(int, shape)) | |||
except ValueError as e: | |||
raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||
else: | |||
data_shape = dst_shape | |||
check_shape_equal(dst_shape, data_shape) | |||
return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||
# try to load image | |||
data = cv2.imread(path, cv2.IMREAD_COLOR) | |||
if data is None: | |||
assert not args.resize_input | |||
data = np.load(path) | |||
assert isinstance(data, np.ndarray) | |||
else: | |||
# load image succeeds, so we expect input format is image format | |||
data = auto_reformat_image(args, path, data, dst_shape) | |||
data = np.repeat(data, repeat, axis=0) | |||
if repeat > 1: | |||
logger.info( | |||
"repeat input for {} times, data shape is {}".format(repeat, data.shape) | |||
) | |||
check_shape_equal(dst_shape, data.shape) | |||
if args.input_transform: | |||
data = eval(args.input_transform, {"data": data, "np": np}) | |||
return data | |||
def gen_one_testcase(args, inputs, spec): | |||
paths = spec.split(";") | |||
if len(paths) != len(inputs): | |||
if len(paths) == 1 and paths[0].startswith("#"): | |||
paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||
assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( | |||
inputs.keys(), paths | |||
) | |||
if len(paths) == 1 and ":" not in paths[0]: | |||
paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||
ret = {} | |||
for path in paths: | |||
var, path = path.split(":") | |||
if args.repeat: | |||
repeat = args.repeat | |||
else: | |||
repeat = 1 | |||
ret[var] = read_input_data( | |||
args, inputs[var].shape, inputs[var].dtype, path, repeat | |||
) | |||
return ret | |||
def make_feeds(args): | |||
ret = G.load_graph(args.input) | |||
cg_rt, outputs = ret.graph, ret.output_vars_list | |||
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||
inputs = {i.name: i for i in inputs} | |||
if not args.no_assert: | |||
replace_varmap = {} | |||
inp_map = {} | |||
# replace var use InputNode | |||
for name, var in inputs.items(): | |||
inp = G.InputNode( | |||
device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt | |||
) | |||
replace_varmap[var] = inp.outputs[0] | |||
inp_map[name] = inp | |||
new = cgtools.replace_vars(outputs, replace_varmap) | |||
if isinstance(new, rt.VarNode): | |||
new = list(new) | |||
output_nodes = [G.OutputNode(var) for var in new] | |||
func = cg_rt.compile([node.outputs[0] for node in output_nodes]) | |||
def make_dev_tensor(value, dtype=None, device=None): | |||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||
def calculate(*args, **kwargs): | |||
output_val = [] | |||
# set inputs value | |||
for name, var in inputs.items(): | |||
val = kwargs.pop(name, None) | |||
assert val is not None, "miss input name{}".format(name) | |||
dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||
inp_map[name].set_value(dev_tensor) | |||
func.execute() | |||
for res in output_nodes: | |||
output_val.append(res.get_value().numpy()) | |||
return output_val | |||
def expect_name(var): | |||
return "{}:expect".format(var.name) | |||
testcases = [] | |||
np.set_printoptions(precision=2, threshold=4, suppress=True) | |||
data_list = [] | |||
for item in args.data: | |||
if item.startswith("@"): | |||
with open(item[1:], "r") as f: | |||
data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) | |||
else: | |||
data_list.append(item) | |||
for inp_spec in data_list: | |||
cur_testcase = gen_one_testcase(args, inputs, inp_spec) | |||
assert len(cur_testcase) == len( | |||
inputs | |||
), "required inputs: {}; given data: {}".format( | |||
inputs.keys(), cur_testcase.keys() | |||
) | |||
if not args.no_assert: | |||
outputs_get = calculate(**cur_testcase) | |||
for var, val in zip(outputs, outputs_get): | |||
cur_testcase[expect_name(var)] = val | |||
logger.info( | |||
"generate test groundtruth: var={} shape={} range=({}, {})" | |||
" mean={} var={}".format( | |||
var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) | |||
) | |||
) | |||
testcases.append(cur_testcase) | |||
logger.info( | |||
"add testcase: \n {}".format( | |||
"\n ".join( | |||
"{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||
"mean={:.2f} sd={:.2f}".format( | |||
k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||
) | |||
for k, v in sorted(cur_testcase.items()) | |||
) | |||
) | |||
) | |||
if not args.no_assert: | |||
def expect_shp(var): | |||
ret = var.shape | |||
if ret: | |||
return ret | |||
return testcases[0][expect_name(var)].shape | |||
def assert_equal(expect, real, **kwargs): | |||
op = builtin.AssertEqual(**kwargs) | |||
(res,) = G.apply_normal_varnode(op, expect, real) | |||
return res | |||
verbose = not args.silent | |||
outputs_new = [] | |||
for i in outputs: | |||
device = rt.CompNode("xpux") | |||
dtype = i.dtype | |||
name = expect_name(i) | |||
shape = expect_shp(i) | |||
# make expect output as one input of model. | |||
expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) | |||
# insert assert opr to check expect and real. | |||
outputs_new.append( | |||
assert_equal( | |||
expect_get, | |||
i, | |||
verbose=verbose, | |||
maxerr=args.maxerr, | |||
) | |||
) | |||
inputs[expect_name(i)] = expect_get | |||
outputs = outputs_new | |||
return {"outputs": outputs, "testcases": testcases} | |||
def optimize_for_inference(args, outputs): | |||
args_list = [ | |||
"enable_io16xc32", | |||
"enable_ioc16", | |||
"enable_hwcd4", | |||
"enable_nchw4", | |||
"enable_nchw88", | |||
"enable_nchw44", | |||
"enable_nchw44_dot", | |||
"enable_nchw32", | |||
"enable_chwn4", | |||
"enable_fuse_conv_bias_nonlinearity", | |||
"enable_fuse_conv_bias_with_z", | |||
"enable_fuse_preprocess", | |||
] | |||
kwargs = {} | |||
for k in args_list: | |||
if getattr(args, k): | |||
kwargs[k] = True | |||
if args.optimize_for_inference: | |||
outputs = G.optimize_for_inference(outputs, **kwargs) | |||
return outputs | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description="Pack computing graph, input values and expected output " | |||
"values into one file for checking correctness. README.md gives more " | |||
"details on the usage", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument("input", help="MegEngine dumped model file") | |||
parser.add_argument("-o", "--output", help="output file", required=True) | |||
parser.add_argument( | |||
"-d", | |||
"--data", | |||
default=[], | |||
action="append", | |||
required=True, | |||
help="Given input test data when input file is a network, " | |||
"and current network output would be used as groundtruth. " | |||
"The format is var0:file0;var1:file1... to specify data files for " | |||
"input vars. It can also be #rand(min,max,shape...) for generating " | |||
"random input data, for example, #rand(0,255), " | |||
"#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " | |||
"the remaining part of the original shape. " | |||
"If the shape is not specified, the shape of " | |||
"corresponding input tensors in the network will be used. " | |||
"If there is only one input var, its name can be omitted. " | |||
"Each data file can either be an image which can be loaded by opencv, " | |||
"or a pickled numpy.ndarray. " | |||
"This option can be given multiple times to add multiple testcases. " | |||
" *NOTE* " | |||
"If you start the data with the letter @, the rest should be a " | |||
"filename, and each line in the file should be a single datum in " | |||
"the format described above. ", | |||
) | |||
parser.add_argument( | |||
"--repeat", | |||
type=int, | |||
default=1, | |||
help="Specify how many times the input image is repeated. " | |||
"Useful when running benchmark for batch size other than one. " | |||
"Have no effect on randomly generated input data.", | |||
) | |||
parser.add_argument( | |||
"--silent", | |||
action="store_true", | |||
help="set verbose to False in asserti_equal opr", | |||
) | |||
parser.add_argument( | |||
"--optimize-for-inference", | |||
action="store_true", | |||
help="enable optimization for inference", | |||
) | |||
parser.add_argument( | |||
"--no-assert", | |||
action="store_true", | |||
help="do not insert assert_equal opr to check result; " | |||
"this option is useful for benchmarking", | |||
) | |||
parser.add_argument( | |||
"--maxerr", | |||
type=float, | |||
default=1e-4, | |||
help="max error for assert_equal check during runtime", | |||
) | |||
parser.add_argument( | |||
"--resize-input", | |||
action="store_true", | |||
help="resize input image to fit input var shape", | |||
) | |||
parser.add_argument( | |||
"--input-transform", | |||
help="a python expression to transform the input data. " | |||
"Example: data / np.std(data)", | |||
) | |||
parser.add_argument( | |||
"--discard-var-name", | |||
action="store_true", | |||
help="discard variable and param names in the " "generated output", | |||
) | |||
parser.add_argument( | |||
"--output-strip-info", action="store_true", help="output code strip information" | |||
) | |||
parser.add_argument( | |||
"--enable-io16xc32", | |||
action="store_true", | |||
help="transform the mode to float16 io float32 compute", | |||
) | |||
parser.add_argument( | |||
"--enable-ioc16", | |||
action="store_true", | |||
help="transform the dtype of the model to float16 io " "and compute", | |||
) | |||
parser.add_argument( | |||
"--enable-fuse-conv-bias-nonlinearity", | |||
action="store_true", | |||
help="fuse convolution bias and nonlinearity opr to a " | |||
"conv_bias opr and compute", | |||
) | |||
parser.add_argument( | |||
"--enable-hwcd4", | |||
action="store_true", | |||
help="transform the model format from NCHW to NHWCD4 " | |||
"for inference; you may need to disable CUDA and set " | |||
"MGB_USE_MEGDNN_DBG=2", | |||
) | |||
parser.add_argument( | |||
"--enable-nchw4", | |||
action="store_true", | |||
help="transform the model format from NCHW to NCHW4 " "for inference", | |||
) | |||
parser.add_argument( | |||
"--enable-nchw88", | |||
action="store_true", | |||
help="transform the model format from NCHW to NCHW88 " "for inference", | |||
) | |||
parser.add_argument( | |||
"--enable-nchw44", | |||
action="store_true", | |||
help="transform the model format from NCHW to NCHW44 " "for inference", | |||
) | |||
parser.add_argument( | |||
"--enable-nchw44-dot", | |||
action="store_true", | |||
help="transform the model format from NCHW to NCHW44_DOT " | |||
"for optimizing armv8.2 dot in inference", | |||
) | |||
parser.add_argument( | |||
"--enable-nchw32", | |||
action="store_true", | |||
help="transform the model format from NCHW4 to NCHW32 " | |||
"for inference on nvidia TensoCore", | |||
) | |||
parser.add_argument( | |||
"--enable-chwn4", | |||
action="store_true", | |||
help="transform the model format to CHWN4 " | |||
"for inference, mainly used for nvidia tensorcore", | |||
) | |||
parser.add_argument( | |||
"--enable-fuse-conv-bias-with-z", | |||
action="store_true", | |||
help="fuse conv_bias with z input for inference on " | |||
"nvidia GPU (this optimization pass will result in mismatch " | |||
"of the precision of output of training and inference)", | |||
) | |||
parser.add_argument( | |||
"--enable-fuse-preprocess", | |||
action="store_true", | |||
help="fuse astype\pad_channel\dimshuffle and etc opr " | |||
"from h2d opr", | |||
) | |||
args = parser.parse_args() | |||
feeds = make_feeds(args) | |||
assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" | |||
output_mgbvars = feeds["outputs"] | |||
output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||
inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||
inputs = sorted((i.name, i.dtype) for i in inputs) | |||
if args.discard_var_name: | |||
sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) | |||
else: | |||
sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) | |||
strip_info_file = args.output + ".json" if args.output_strip_info else None | |||
with open(args.output, "wb") as fout: | |||
fout.write(b"mgbtest0") | |||
fout.write(struct.pack("I", len(feeds["testcases"]))) | |||
dump_content, stat = G.dump_graph( | |||
output_mgbvars, | |||
append_json=True, | |||
strip_info_file=strip_info_file, | |||
**sereg_kwargs, | |||
) | |||
fout.write(dump_content) | |||
logger.info( | |||
"graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( | |||
stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 | |||
) | |||
) | |||
def make_dev_tensor(value, dtype=None, device=None): | |||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||
for testcase in feeds["testcases"]: | |||
assert isinstance(testcase, dict) | |||
cg = G.Graph() | |||
output_mgbvars = [] | |||
for name, dtype in inputs: | |||
output_mgbvars.append( | |||
cg.make_const( | |||
make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") | |||
) | |||
) | |||
assert not testcase, "extra inputs provided in testcase: {}".format( | |||
testcase.keys() | |||
) | |||
with open(args.output, "ab") as fout: | |||
dump_content, _ = G.dump_graph( | |||
output_mgbvars, strip_info_file=strip_info_file, append_json=True | |||
) | |||
fout.write(dump_content) | |||
if __name__ == "__main__": | |||
main() | |||