@@ -14,6 +14,8 @@ | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
#include "src/common/conv_bias.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/common/postprocess.h" | #include "src/common/postprocess.h" | ||||
namespace { | namespace { | ||||
@@ -31,17 +33,41 @@ namespace { | |||||
MEGDNN_MARK_USED_VAR(OW); \ | MEGDNN_MARK_USED_VAR(OW); \ | ||||
MEGDNN_MARK_USED_VAR(pack_oc_size) | MEGDNN_MARK_USED_VAR(pack_oc_size) | ||||
void to_handle_bias_and_nonlinear( | |||||
void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, | |||||
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW) { | |||||
auto handle = megdnn::inplace_cpu_handle(); | |||||
auto conv_dst_tensor_layout = megdnn::TensorLayout({N, OC, OH, OW}, dst_type); | |||||
auto conv_dst_tensor = megdnn::TensorND{conv_dst_ptr, conv_dst_tensor_layout}; | |||||
auto dst_tensor = megdnn::TensorND{dst_ptr, conv_dst_tensor_layout}; | |||||
auto bias_tensor_layout = conv_dst_tensor_layout; | |||||
if (megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS == bias_mode) { | |||||
bias_tensor_layout = megdnn::TensorLayout({1, OC, 1, 1}, bias_type); | |||||
} else if (megdnn::ConvBiasForward::BiasMode::NO_BIAS == bias_mode) { | |||||
bias_tensor_layout = megdnn::TensorLayout({}, bias_type); | |||||
} | |||||
auto bias_tensor = | |||||
megdnn::TensorND{const_cast<void*>(bias_ptr), bias_tensor_layout}; | |||||
handle_bias_and_nonlinear( | |||||
handle.get(), nonlineMode, &conv_dst_tensor, &dst_tensor, &bias_tensor); | |||||
} | |||||
template < | template < | ||||
typename ctype, typename dtype = ctype, | typename ctype, typename dtype = ctype, | ||||
megdnn::PostprocessMode postprocess_mode = megdnn::PostprocessMode::FLOAT> | megdnn::PostprocessMode postprocess_mode = megdnn::PostprocessMode::FLOAT> | ||||
struct PostProcess { | struct PostProcess { | ||||
static void run( | static void run( | ||||
void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, | void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, | ||||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, | |||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
POST_PROCESS_UNUSED_VAR(); | |||||
megdnn_throw("not impl PostProcess"); | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, | |||||
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, | |||||
size_t pack_oc_size = 1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
to_handle_bias_and_nonlinear( | |||||
conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type, | |||||
dst_type, N, OC, OH, OW); | |||||
} | } | ||||
}; | }; | ||||
@@ -49,11 +75,11 @@ template <typename ctype, typename dtype> | |||||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | ||||
static void run( | static void run( | ||||
void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | ||||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, | |||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, | |||||
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, | |||||
size_t pack_oc_size = 1) { | |||||
POST_PROCESS_UNUSED_VAR(); | POST_PROCESS_UNUSED_VAR(); | ||||
megdnn_throw("not impl PostProcess"); | |||||
} | } | ||||
}; | }; | ||||
@@ -61,11 +87,14 @@ template <typename opctype, typename opdtype> | |||||
struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | ||||
static void run( | static void run( | ||||
void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, | void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, | ||||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, | |||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
POST_PROCESS_UNUSED_VAR(); | |||||
megdnn_throw("not impl PostProcess"); | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, | |||||
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, | |||||
size_t pack_oc_size = 1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
to_handle_bias_and_nonlinear( | |||||
conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type, | |||||
dst_type, N, OC, OH, OW); | |||||
} | } | ||||
}; | }; | ||||
@@ -73,11 +102,17 @@ template <typename ctype, typename dtype> | |||||
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | ||||
static void run( | static void run( | ||||
void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | ||||
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||||
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, | |||||
size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
POST_PROCESS_UNUSED_VAR(); | |||||
megdnn_throw("not impl PostProcess"); | |||||
megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, | |||||
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, | |||||
size_t pack_oc_size = 1) { | |||||
MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||||
return; | |||||
} | |||||
to_handle_bias_and_nonlinear( | |||||
conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type, | |||||
dst_type, N, OC, OH, OW); | |||||
} | } | ||||
}; | }; | ||||
@@ -50,7 +50,8 @@ constexpr size_t BLOCK_LINE_SIZE_BYTES = 32; | |||||
//! ref U54-MC arch | //! ref U54-MC arch | ||||
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; | constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; | ||||
#else | #else | ||||
#error "unknown megdnn arch" | |||||
//! for fallback, need keep same with MEGDNN_NAIVE | |||||
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; | |||||
#endif | #endif | ||||
/** | /** | ||||
@@ -1,156 +1,8 @@ | |||||
# MegEngine Tools | # MegEngine Tools | ||||
MegEngine 相关的工具汇总。使用方法如下(可将 `xxx` 替换成任一脚本文件,如 `network_visualize`): | |||||
This directory contains executable python files. | |||||
Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`): | |||||
```bash | |||||
python -m megengine.tools.xxx | |||||
``` | |||||
工具列表: | |||||
### accuracy_shake_var_tree | |||||
将精度抖动分析结果构造成树结构,方便锁定引起抖动的根节点,以及查找依赖关系。 | |||||
输入: compare_binary_iodump 的输出存入到的一个文件 | |||||
输出: 第一个出现结果不一致的输出结点 | |||||
执行命令: accuracy_shake_var_tree 中定义了一些函数组件,可按需集成到实际代码中。下面有一个测试代码: | |||||
```python | |||||
import megengine.tools.accuracy_shake_var_tree as st | |||||
r = st.parse('diff.txt') | |||||
for key, value in r.items(): | |||||
n = st.varNode.get_varNode(key) | |||||
n.show_src_info() | |||||
print("reference nodes:") | |||||
for i in n.get_reference_list(): | |||||
print(i.id) | |||||
``` | |||||
### benchmark_op | |||||
逐个运行 functional op(并不是所有的 functional op),对比 MegEngine 与 PyTorch 的性能,通过量化结果来指导如何进行下一步的优化。 | |||||
输入: 无 | |||||
输出: 打印一个列表,对比在小输入和大输入的情况下 MegEngine 和 Pytorch 执行一些 functional op 的速度对比 | |||||
执行命令: `python3 -m megengine.tools.benchmark_op` | |||||
### compare_binary_iodump | |||||
分析同一模型在不同平台下给定相同输入之后的输出是否完全一致。 | |||||
输入: 两个目录(假设分别为 expect/ 和 actual/),分别存有不同平台下运行的 tensor 结果 | |||||
输出: 打印所有的输出 tensor 信息,如果某个 tensor 在两个平台上的值不一致,那么会打印出第一个不一致的值 | |||||
执行命令: `python3 -m megengine.tools.compare_binary_iodump expect/ actual/` | |||||
### draw_graph | |||||
用来查看静态图的 op 序列,有助于理解 MegEngine 的静态图在动态图的基础上做了哪些优化。 | |||||
输入: `megengine.core.tensor.megbrain_graph.Graph._to_json` 得出的静态图描述文件,为 json 格式 | |||||
输出: 一个 dot 文件,可通过 dot 命令绘制出图片 | |||||
执行命令: | |||||
```bash | |||||
python3 -m megengine.tools.draw_graph -i dump.json -o dump.dot | |||||
dot -Tpng dump.dot -o dump.png | |||||
``` | |||||
### dump_with_testcase_mge | |||||
将待测数据提前注入模型文件,并在本地运行得到期望结果,可与实际运行的结果进行比对以检查是否出错。 | |||||
输入: 一个 MegEngine 模型文件,可选一些 npy 文件作为模型输入(也可以随机生成输入,如下面的命令示例) | |||||
输出: 一个带输入的 MegEngine 模型文件 | |||||
执行命令: `python3 -m megengine.tools.dump_with_testcase_mge model.mge -d "#rand(0,255,14,2)"` | |||||
### graph_info_analyze | |||||
将图和内存信息的 json 文件的文件夹 logs 转换为 TensorBoard 的输入文件夹 logs_p。以便 TensorBoard 对图结构以及内存信息进行可视化。 | |||||
输入: 图和内存信息的 json 文件的文件夹 | |||||
输出: TensorBoard 的输入文件夹 | |||||
执行命令: `python3 -m megengine.tools.graph_info_analyze -i logs -o logs_p` | |||||
### load_network_and_run | |||||
python 版本的 load_and_run。 | |||||
输入: MegEngine 的模型文件,可选一些 npy 文件作为模型输入 | |||||
输出: 模型执行并打印一些测速信息 | |||||
执行命令: `python3 -m megengine.tools.load_network_and_run model.mge --iter 10` | |||||
### network_visualize | |||||
1. 分析给定的 MegEngine 模型中参数量信息,包括 shape、dtype、mean、std 以及 size 占比等。 | |||||
2. 分析给定的 MegEngine 模型中算子 FLOPs 计算量以及占比,还有算子的 inputs\outputs shape、感受野、stride 等。 | |||||
输入: MegEngine 的模型文件 | |||||
输出: 模型中的参数量信息或计算量信息 | |||||
执行命令: | |||||
```bash | |||||
# 分析参数量 | |||||
python3 -m megengine.tools.network_visualize model.mge --cal_params --logging_to_stdout | |||||
# 分析计算量 | |||||
python3 -m megengine.tools.network_visualize model.mge --cal_flops --logging_to_stdout | |||||
``` | ``` | ||||
### profile_analyze | |||||
对于 load_and_run --profile 运行模型生成的 profile.json 文件或者 trace 模式下开启 profiling 功能并通过 trace.get_profile() 得到的 json 文件进行分析,得到静态图中算子的时间和显存占比等信息,以表格形式呈现。 | |||||
输入: load_and_run 生成的 profile 文件 | |||||
输出: 一个按照参数在输入文件中筛选得出的数据表格 | |||||
执行命令: | |||||
```bash | |||||
# 生成供分析的 json 文件 | |||||
python3 -m megengine.tools.load_network_and_run model.mge --warm-up --iter 10 --profile profile.json | |||||
#分析耗时前 3 的单个算子 | |||||
python3 -m megengine.tools.profile_analyze profile.json -t 3 | |||||
#筛选用时超过 10us 的 conv 按 flops 排序 | |||||
python3 -m megengine.tools.profile_analyze profile.json -t 3 --order-by +flops --min-time 1e-5 --type ConvolutionForward | |||||
``` | |||||
### profiler | |||||
对给定的训练程序,记录训练过程并以通用格式存储,可在浏览器上可视化。 | |||||
输入: 需要一个 MegEngine 的训练程序(称之为 train.py,其中包含一个典型的 MegEngine 训练过程) | |||||
输出: 一些记录 profile 过程的 json 文件,默认在 profile 子目录下,可用 https://ui.perfetto.dev/ 进行加载并且可视化 | |||||
执行命令: `python3 -m megengine.tools.profiler train.py` | |||||
### svg_viewer | |||||
查看 MegEngine 生成的显存占用图,可以帮助用户了解显存使用情况. | |||||
输入: 显存占用的 svg 图片 | |||||
输出: 网页展示的可视化 | |||||
执行命令: `python3 -m megengine.tools.svg_viewer` | |||||
python -m megengine.tools.xxx | |||||
``` |
@@ -1,151 +0,0 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import time | |||||
import numpy as np | |||||
import megengine as mge | |||||
import megengine.module as MM | |||||
import megengine.functional as MF | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as TF | |||||
from tabulate import tabulate | |||||
module_cache = { | |||||
"conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()), | |||||
"dw_conv2d": (MM.Conv2d(32, 32, 3, 1, 0, groups=32), nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda()), | |||||
"conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()), | |||||
"ConvTranspose2d": (MM.ConvTranspose2d(32, 32, 3, 1, 0), nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda()), | |||||
"BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()), | |||||
"Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()), | |||||
} | |||||
test_cases = [ | |||||
# (mge op, torch op, small inps, large inps, unpack_inps, rep) | |||||
("adaptive_avg_pool2d", lambda x: MF.adaptive_avg_pool2d(x, (7,7)), lambda x: TF.adaptive_avg_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||||
("adaptive_max_pool2d", lambda x: MF.adaptive_max_pool2d(x, (7,7)), lambda x: TF.adaptive_max_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||||
("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000),], True, 1000), | |||||
("avg_pool2d", lambda x: MF.avg_pool2d(x, 2), lambda x: TF.avg_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||||
("broadcast", lambda x: MF.broadcast_to(x, (5,) + x.shape), lambda x: torch.broadcast_to(x, (5,)+x.shape), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("batchedmatmul", MF.matmul, torch.matmul, [(8, 64, 32), (8, 32, 64)], [(8, 2048, 512), (8, 512, 2048)], True, 1000), | |||||
("batchnrom2d", lambda x: module_cache["BatchNorm2d"][0](x), lambda x: module_cache["BatchNorm2d"][1](x), [(2, 64, 16, 16)], [(64, 64, 128, 128)], True, 1000), | |||||
("concat", MF.concat, torch.cat, [(20, 100), (50, 100), (30, 100)], [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)], False, 1000), | |||||
("conv2d", lambda x: module_cache["conv2d"][0](x), lambda x: module_cache["conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||||
("conv3d", lambda x: module_cache["conv3d"][0](x), lambda x: module_cache["conv3d"][1](x), [(2, 32, 8, 8, 8)], [(32, 32, 16, 16, 16)], True, 1000), | |||||
("convTranspose2d", lambda x: module_cache["ConvTranspose2d"][0](x), lambda x: module_cache["ConvTranspose2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||||
("dropout", lambda x: MF.dropout(x, 0.5), TF.dropout, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("dw_conv2d", lambda x: module_cache["dw_conv2d"][0](x), lambda x: module_cache["dw_conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||||
("elemwise.unary", MF.log, torch.log, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("elemwise.binary", MF.add, torch.add, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], True, 1000), | |||||
("expand_dims", lambda x: MF.expand_dims(x, 0), lambda x: torch.unsqueeze(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("gelu", MF.gelu, TF.gelu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("hswish", MF.hswish, TF.hardswish, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("hsigmoid", MF.hsigmoid, TF.hardsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("isinf", MF.isinf, torch.isinf, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("indeixngMultiAxisVec", lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], [(10,10,10,10)], [(64, 512, 16, 16)], True, 1000), | |||||
("logsigmoid", MF.logsigmoid, TF.logsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("leaky_relu", lambda x: MF.leaky_relu(x, 0.5), lambda x: TF.leaky_relu(x, 0.5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("linear", lambda x: module_cache["Linear"][0](x), lambda x: module_cache["Linear"][1](x), [(10, 1000)], [(64, 128, 1000)], True, 1000), | |||||
("matinv", MF.matinv, torch.inverse, [(10,10)], [(30, 30)], True, 1000), | |||||
("matmul", MF.matmul, torch.matmul, [(64,32), (32, 64)], [(2048, 1024), (1024, 2048)], True, 1000), | |||||
("max_pool2d", lambda x: MF.max_pool2d(x, 2), lambda x: TF.max_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||||
("normal", lambda x: mge.random.normal(0,1, x.shape), lambda x: torch.randn(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("prelu", MF.prelu, TF.prelu, [(100,100), (1,)], [(64, 512, 16, 16), (1,)], True, 1000), | |||||
("reduce.max", lambda x: MF.max(x, 0), lambda x: torch.max(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("relu", MF.relu, TF.relu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("relu6", MF.relu6, TF.relu6, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("repeat", lambda x: MF.repeat(x, 5), lambda x: torch.repeat_interleave(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("silu", MF.silu, TF.silu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("split", lambda x: MF.split(x, 5), lambda x: torch.split(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("sigmoid", MF.sigmoid, TF.sigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("softmax", lambda x: MF.softmax(x, axis=1), lambda x: TF.softmax(x, dim=1), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("softplus", MF.softplus, TF.softplus, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("squeeze", lambda x: MF.squeeze(x, 0), lambda x: torch.squeeze(x, 0), [(1, 100,100)], [(1, 64, 512, 16, 16)], True, 1000), | |||||
("stack", MF.stack, torch.stack, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], False, 10000), | |||||
("subtensor", lambda x: x[0:20, 10:60], lambda x: x[0:20, 10:60], [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("topk", lambda x: MF.topk(x, 10), lambda x: torch.topk(x, 10), [(100,100)], [(1000, 1000)], True, 1000), | |||||
("tile", lambda x: MF.tile(x, (2,)*len(x.shape)), lambda x: torch.tile(x, (2,)*len(x.shape)), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("transpose", lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]), lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("where", lambda x: MF.where(x > 0.5, x, x), lambda x: torch.where(x > 0.5, x, x), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
("uniform", lambda x: mge.random.uniform(0,1, x.shape), lambda x: torch.rand(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||||
] | |||||
def perf_func(func, inps, reps, unpack_inps, is_mge): | |||||
if is_mge: | |||||
mge._full_sync() | |||||
tik = time.time() | |||||
for _ in range(reps): | |||||
if unpack_inps: | |||||
out = func(*inps) | |||||
else: | |||||
out = func(inps) | |||||
mge._full_sync() | |||||
else: | |||||
torch.cuda.synchronize() | |||||
with torch.no_grad(): | |||||
tik = time.time() | |||||
for _ in range(reps): | |||||
if unpack_inps: | |||||
out = func(*inps) | |||||
else: | |||||
out = func(inps) | |||||
torch.cuda.synchronize() | |||||
return time.time() - tik | |||||
def get_avg_time(func, inps, reps, unpack_inps, is_mge): | |||||
# warm up | |||||
for _ in range(2): | |||||
t = perf_func(func, inps, reps, unpack_inps, is_mge) | |||||
times = [] | |||||
for _ in range(5): | |||||
t = perf_func(func, inps, reps, unpack_inps, is_mge) | |||||
times.append(t) | |||||
return np.mean(times) | |||||
def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps): | |||||
inps = [ | |||||
np.random.randn(*shape) for shape in shapes | |||||
] | |||||
inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps] | |||||
avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True) | |||||
inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps] | |||||
avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False) | |||||
return avg_time_mge, avg_time_torch | |||||
if __name__ == "__main__": | |||||
header = ["opr_name", "time(mge/pytorch; small input)", "time(mge/pytorch; large input)"] | |||||
table = [] | |||||
for case in test_cases: | |||||
assert len(case) == 7 | |||||
name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case | |||||
data = [] | |||||
data.append(name) | |||||
print("========== op: {}".format(name)) | |||||
avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, small_shapes, unpack_inps, reps) | |||||
print("mge time: {}".format(avg_time_mge)) | |||||
print("torch time: {}".format(avg_time_torch)) | |||||
data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) | |||||
avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, large_shapes, unpack_inps, reps) | |||||
print("mge time: {}".format(avg_time_mge)) | |||||
print("torch time: {}".format(avg_time_torch)) | |||||
data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) | |||||
table.append(data) | |||||
print(tabulate(table, header, tablefmt="github")) |
@@ -1,535 +0,0 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import argparse | |||||
import os | |||||
import re | |||||
import struct | |||||
import cv2 | |||||
import numpy as np | |||||
import megengine as mge | |||||
import megengine.core._imperative_rt as rt | |||||
import megengine.core.tensor.megbrain_graph as G | |||||
from megengine import tensor | |||||
from megengine.core.ops import builtin | |||||
from megengine.utils import comp_graph_tools as cgtools | |||||
logger = mge.get_logger(__name__) | |||||
def auto_reformat_image(args, path, data, dst_shape): | |||||
"""reformat image to target shape | |||||
:param data: image data as numpy array | |||||
:param dst_shape: target shape | |||||
""" | |||||
dim3_format = False # required input format does not contain batch | |||||
hwc_format = False # required input format is NHWC | |||||
if not dst_shape: # input tensor shape is not predefined | |||||
if len(data.shape) == 2: | |||||
chl = 1 | |||||
h = data.shape[0] | |||||
w = data.shape[1] | |||||
else: | |||||
assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" | |||||
h, w, chl = data.shape | |||||
dst_shape = (1, chl, h, w) | |||||
if len(dst_shape) == 3: | |||||
dst_shape = (1,) + dst_shape | |||||
dim3_format = True | |||||
assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||||
chl = dst_shape[1] | |||||
if chl in [1, 3]: | |||||
n, c, h, w = dst_shape | |||||
dst_shape = (n, h, w, c) | |||||
else: | |||||
chl = dst_shape[3] | |||||
assert chl in [1, 3], "can not infer input format from shape: {}".format( | |||||
dst_shape | |||||
) | |||||
hwc_format = True | |||||
# dst_shape has now been normalized to NHWC format | |||||
if args.resize_input: | |||||
h, w = dst_shape[1:3] | |||||
data = cv2.resize(data, (w, h)) | |||||
logger.info("input {} resized to {}".format(path, data.shape)) | |||||
if chl == 1: | |||||
data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||||
data = data[:, :, np.newaxis] | |||||
assert data.ndim == 3 | |||||
data = data[np.newaxis] | |||||
# data normalized to NHWC format | |||||
if not hwc_format: | |||||
data = np.transpose(data, (0, 3, 1, 2)) | |||||
if dim3_format: | |||||
data = np.squeeze(data, 0) | |||||
return data | |||||
def read_input_data(args, dst_shape, dtype, path, repeat): | |||||
def check_shape_equal(dst_shape, data_shape): | |||||
if len(dst_shape): | |||||
assert len(data_shape) == len( | |||||
dst_shape | |||||
), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) | |||||
if data_shape[1:] != dst_shape[1:]: | |||||
logger.warning( | |||||
"dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) | |||||
) | |||||
if path.startswith("#"): | |||||
assert not args.resize_input | |||||
assert not args.input_transform | |||||
spec = path | |||||
m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) | |||||
assert m, "bad spec {}".format(spec) | |||||
rng_min = float(m.group(1)) | |||||
rng_max = float(m.group(2)) | |||||
if m.group(3): | |||||
shape_str = m.group(3) | |||||
try: | |||||
shape = shape_str[1:].split(",") | |||||
if shape[-1].strip() == "...": | |||||
shape = shape[:-1] | |||||
shape.extend(list(dst_shape[len(shape) :])) | |||||
data_shape = tuple(map(int, shape)) | |||||
except ValueError as e: | |||||
raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||||
else: | |||||
data_shape = dst_shape | |||||
check_shape_equal(dst_shape, data_shape) | |||||
return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||||
# try to load image | |||||
data = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
if data is None: | |||||
assert not args.resize_input | |||||
data = np.load(path) | |||||
assert isinstance(data, np.ndarray) | |||||
else: | |||||
# load image succeeds, so we expect input format is image format | |||||
data = auto_reformat_image(args, path, data, dst_shape) | |||||
data = np.repeat(data, repeat, axis=0) | |||||
if repeat > 1: | |||||
logger.info( | |||||
"repeat input for {} times, data shape is {}".format(repeat, data.shape) | |||||
) | |||||
check_shape_equal(dst_shape, data.shape) | |||||
if args.input_transform: | |||||
data = eval(args.input_transform, {"data": data, "np": np}) | |||||
return data | |||||
def gen_one_testcase(args, inputs, spec): | |||||
paths = spec.split(";") | |||||
if len(paths) != len(inputs): | |||||
if len(paths) == 1 and paths[0].startswith("#"): | |||||
paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||||
assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( | |||||
inputs.keys(), paths | |||||
) | |||||
if len(paths) == 1 and ":" not in paths[0]: | |||||
paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||||
ret = {} | |||||
for path in paths: | |||||
var, path = path.split(":") | |||||
if args.repeat: | |||||
repeat = args.repeat | |||||
else: | |||||
repeat = 1 | |||||
ret[var] = read_input_data( | |||||
args, inputs[var].shape, inputs[var].dtype, path, repeat | |||||
) | |||||
return ret | |||||
def make_feeds(args): | |||||
ret = G.load_graph(args.input) | |||||
cg_rt, outputs = ret.graph, ret.output_vars_list | |||||
inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||||
inputs = {i.name: i for i in inputs} | |||||
if not args.no_assert: | |||||
replace_varmap = {} | |||||
inp_map = {} | |||||
# replace var use InputNode | |||||
for name, var in inputs.items(): | |||||
inp = G.InputNode( | |||||
device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt | |||||
) | |||||
replace_varmap[var] = inp.outputs[0] | |||||
inp_map[name] = inp | |||||
new = cgtools.replace_vars(outputs, replace_varmap) | |||||
if isinstance(new, rt.VarNode): | |||||
new = list(new) | |||||
output_nodes = [G.OutputNode(var) for var in new] | |||||
func = cg_rt.compile([node.outputs[0] for node in output_nodes]) | |||||
def make_dev_tensor(value, dtype=None, device=None): | |||||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
def calculate(*args, **kwargs): | |||||
output_val = [] | |||||
# set inputs value | |||||
for name, var in inputs.items(): | |||||
val = kwargs.pop(name, None) | |||||
assert val is not None, "miss input name{}".format(name) | |||||
dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||||
inp_map[name].set_value(dev_tensor) | |||||
func.execute() | |||||
for res in output_nodes: | |||||
output_val.append(res.get_value().numpy()) | |||||
return output_val | |||||
def expect_name(var): | |||||
return "{}:expect".format(var.name) | |||||
testcases = [] | |||||
np.set_printoptions(precision=2, threshold=4, suppress=True) | |||||
data_list = [] | |||||
for item in args.data: | |||||
if item.startswith("@"): | |||||
with open(item[1:], "r") as f: | |||||
data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) | |||||
else: | |||||
data_list.append(item) | |||||
for inp_spec in data_list: | |||||
cur_testcase = gen_one_testcase(args, inputs, inp_spec) | |||||
assert len(cur_testcase) == len( | |||||
inputs | |||||
), "required inputs: {}; given data: {}".format( | |||||
inputs.keys(), cur_testcase.keys() | |||||
) | |||||
if not args.no_assert: | |||||
outputs_get = calculate(**cur_testcase) | |||||
for var, val in zip(outputs, outputs_get): | |||||
cur_testcase[expect_name(var)] = val | |||||
logger.info( | |||||
"generate test groundtruth: var={} shape={} range=({}, {})" | |||||
" mean={} var={}".format( | |||||
var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) | |||||
) | |||||
) | |||||
testcases.append(cur_testcase) | |||||
logger.info( | |||||
"add testcase: \n {}".format( | |||||
"\n ".join( | |||||
"{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||||
"mean={:.2f} sd={:.2f}".format( | |||||
k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||||
) | |||||
for k, v in sorted(cur_testcase.items()) | |||||
) | |||||
) | |||||
) | |||||
if not args.no_assert: | |||||
def expect_shp(var): | |||||
ret = var.shape | |||||
if ret: | |||||
return ret | |||||
return testcases[0][expect_name(var)].shape | |||||
def assert_equal(expect, real, **kwargs): | |||||
op = builtin.AssertEqual(**kwargs) | |||||
(res,) = G.apply_normal_varnode(op, expect, real) | |||||
return res | |||||
verbose = not args.silent | |||||
outputs_new = [] | |||||
for i in outputs: | |||||
device = rt.CompNode("xpux") | |||||
dtype = i.dtype | |||||
name = expect_name(i) | |||||
shape = expect_shp(i) | |||||
# make expect output as one input of model. | |||||
expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) | |||||
# insert assert opr to check expect and real. | |||||
outputs_new.append( | |||||
assert_equal( | |||||
expect_get, | |||||
i, | |||||
verbose=verbose, | |||||
maxerr=args.maxerr, | |||||
) | |||||
) | |||||
inputs[expect_name(i)] = expect_get | |||||
outputs = outputs_new | |||||
return {"outputs": outputs, "testcases": testcases} | |||||
def optimize_for_inference(args, outputs): | |||||
args_list = [ | |||||
"enable_io16xc32", | |||||
"enable_ioc16", | |||||
"enable_hwcd4", | |||||
"enable_nchw4", | |||||
"enable_nchw88", | |||||
"enable_nchw44", | |||||
"enable_nchw44_dot", | |||||
"enable_nchw32", | |||||
"enable_chwn4", | |||||
"enable_fuse_conv_bias_nonlinearity", | |||||
"enable_fuse_conv_bias_with_z", | |||||
"enable_fuse_preprocess", | |||||
] | |||||
kwargs = {} | |||||
for k in args_list: | |||||
if getattr(args, k): | |||||
kwargs[k] = True | |||||
if args.optimize_for_inference: | |||||
outputs = G.optimize_for_inference(outputs, **kwargs) | |||||
return outputs | |||||
def main(): | |||||
parser = argparse.ArgumentParser( | |||||
description="Pack computing graph, input values and expected output " | |||||
"values into one file for checking correctness. README.md gives more " | |||||
"details on the usage", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument("input", help="MegEngine dumped model file") | |||||
parser.add_argument("-o", "--output", help="output file", required=True) | |||||
parser.add_argument( | |||||
"-d", | |||||
"--data", | |||||
default=[], | |||||
action="append", | |||||
required=True, | |||||
help="Given input test data when input file is a network, " | |||||
"and current network output would be used as groundtruth. " | |||||
"The format is var0:file0;var1:file1... to specify data files for " | |||||
"input vars. It can also be #rand(min,max,shape...) for generating " | |||||
"random input data, for example, #rand(0,255), " | |||||
"#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " | |||||
"the remaining part of the original shape. " | |||||
"If the shape is not specified, the shape of " | |||||
"corresponding input tensors in the network will be used. " | |||||
"If there is only one input var, its name can be omitted. " | |||||
"Each data file can either be an image which can be loaded by opencv, " | |||||
"or a pickled numpy.ndarray. " | |||||
"This option can be given multiple times to add multiple testcases. " | |||||
" *NOTE* " | |||||
"If you start the data with the letter @, the rest should be a " | |||||
"filename, and each line in the file should be a single datum in " | |||||
"the format described above. ", | |||||
) | |||||
parser.add_argument( | |||||
"--repeat", | |||||
type=int, | |||||
default=1, | |||||
help="Specify how many times the input image is repeated. " | |||||
"Useful when running benchmark for batch size other than one. " | |||||
"Have no effect on randomly generated input data.", | |||||
) | |||||
parser.add_argument( | |||||
"--silent", | |||||
action="store_true", | |||||
help="set verbose to False in asserti_equal opr", | |||||
) | |||||
parser.add_argument( | |||||
"--optimize-for-inference", | |||||
action="store_true", | |||||
help="enable optimization for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--no-assert", | |||||
action="store_true", | |||||
help="do not insert assert_equal opr to check result; " | |||||
"this option is useful for benchmarking", | |||||
) | |||||
parser.add_argument( | |||||
"--maxerr", | |||||
type=float, | |||||
default=1e-4, | |||||
help="max error for assert_equal check during runtime", | |||||
) | |||||
parser.add_argument( | |||||
"--resize-input", | |||||
action="store_true", | |||||
help="resize input image to fit input var shape", | |||||
) | |||||
parser.add_argument( | |||||
"--input-transform", | |||||
help="a python expression to transform the input data. " | |||||
"Example: data / np.std(data)", | |||||
) | |||||
parser.add_argument( | |||||
"--discard-var-name", | |||||
action="store_true", | |||||
help="discard variable and param names in the " "generated output", | |||||
) | |||||
parser.add_argument( | |||||
"--output-strip-info", action="store_true", help="output code strip information" | |||||
) | |||||
parser.add_argument( | |||||
"--enable-io16xc32", | |||||
action="store_true", | |||||
help="transform the mode to float16 io float32 compute", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-ioc16", | |||||
action="store_true", | |||||
help="transform the dtype of the model to float16 io " "and compute", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-fuse-conv-bias-nonlinearity", | |||||
action="store_true", | |||||
help="fuse convolution bias and nonlinearity opr to a " | |||||
"conv_bias opr and compute", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-hwcd4", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NHWCD4 " | |||||
"for inference; you may need to disable CUDA and set " | |||||
"MGB_USE_MEGDNN_DBG=2", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw4", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW4 " "for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw88", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW88 " "for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw44", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW44 " "for inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw44-dot", | |||||
action="store_true", | |||||
help="transform the model format from NCHW to NCHW44_DOT " | |||||
"for optimizing armv8.2 dot in inference", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-nchw32", | |||||
action="store_true", | |||||
help="transform the model format from NCHW4 to NCHW32 " | |||||
"for inference on nvidia TensoCore", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-chwn4", | |||||
action="store_true", | |||||
help="transform the model format to CHWN4 " | |||||
"for inference, mainly used for nvidia tensorcore", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-fuse-conv-bias-with-z", | |||||
action="store_true", | |||||
help="fuse conv_bias with z input for inference on " | |||||
"nvidia GPU (this optimization pass will result in mismatch " | |||||
"of the precision of output of training and inference)", | |||||
) | |||||
parser.add_argument( | |||||
"--enable-fuse-preprocess", | |||||
action="store_true", | |||||
help="fuse astype\pad_channel\dimshuffle and etc opr " | |||||
"from h2d opr", | |||||
) | |||||
args = parser.parse_args() | |||||
feeds = make_feeds(args) | |||||
assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" | |||||
output_mgbvars = feeds["outputs"] | |||||
output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||||
inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||||
inputs = sorted((i.name, i.dtype) for i in inputs) | |||||
if args.discard_var_name: | |||||
sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) | |||||
else: | |||||
sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) | |||||
strip_info_file = args.output + ".json" if args.output_strip_info else None | |||||
with open(args.output, "wb") as fout: | |||||
fout.write(b"mgbtest0") | |||||
fout.write(struct.pack("I", len(feeds["testcases"]))) | |||||
dump_content, stat = G.dump_graph( | |||||
output_mgbvars, | |||||
append_json=True, | |||||
strip_info_file=strip_info_file, | |||||
**sereg_kwargs, | |||||
) | |||||
fout.write(dump_content) | |||||
logger.info( | |||||
"graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( | |||||
stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 | |||||
) | |||||
) | |||||
def make_dev_tensor(value, dtype=None, device=None): | |||||
return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
for testcase in feeds["testcases"]: | |||||
assert isinstance(testcase, dict) | |||||
cg = G.Graph() | |||||
output_mgbvars = [] | |||||
for name, dtype in inputs: | |||||
output_mgbvars.append( | |||||
cg.make_const( | |||||
make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") | |||||
) | |||||
) | |||||
assert not testcase, "extra inputs provided in testcase: {}".format( | |||||
testcase.keys() | |||||
) | |||||
with open(args.output, "ab") as fout: | |||||
dump_content, _ = G.dump_graph( | |||||
output_mgbvars, strip_info_file=strip_info_file, append_json=True | |||||
) | |||||
fout.write(dump_content) | |||||
if __name__ == "__main__": | |||||
main() | |||||
@@ -0,0 +1,152 @@ | |||||
#!/usr/bin/env python3 | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
""" | |||||
purpose: Used to simply measure CPU performance by running several basic models. | |||||
how to use: python3 cpu_evaluation_tools.py --help for more details, now need to args: | |||||
--load_and_run_file: path of load_and_run binary, please refs to ../scripts/cmake-build/BUILD_README.md to build it. | |||||
--models_dir: path of model directory. | |||||
how to config test device info: config device[name/login_name/ip/port/thread_number]. | |||||
""" | |||||
import argparse | |||||
import logging | |||||
import os | |||||
import re | |||||
import subprocess | |||||
# test device | |||||
device = { | |||||
"name": "hwmt40p", | |||||
"login_name": "hwmt40p-K9000-maliG78", | |||||
"ip": "box86.br.megvii-inc.com", | |||||
"port": 2200, | |||||
"thread_number": 3, | |||||
} | |||||
# test models | |||||
test_cpu_models = [ | |||||
"inceptionv2", | |||||
"mobilenetv1", | |||||
"mobilenetv2", | |||||
"resnet18", | |||||
"resnet50", | |||||
"shufflenetv2", | |||||
"vgg16", | |||||
] | |||||
class SshConnector: | |||||
"""imp ssh control master connector""" | |||||
ip = None | |||||
port = None | |||||
login_name = None | |||||
def setup(self, login_name, ip, port): | |||||
self.ip = ip | |||||
self.login_name = login_name | |||||
self.port = port | |||||
def copy(self, src_list, dst_dir): | |||||
assert isinstance(src_list, list), "code issue happened!!" | |||||
assert isinstance(dst_dir, str), "code issue happened!!" | |||||
for src in src_list: | |||||
cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format( | |||||
self.port, src, self.login_name, self.ip, dst_dir | |||||
) | |||||
logging.debug("ssh run cmd: {}".format(cmd)) | |||||
subprocess.check_call(cmd, shell=True) | |||||
def cmd(self, cmd): | |||||
output = "" | |||||
assert isinstance(cmd, list), "code issue happened!!" | |||||
for sub_cmd in cmd: | |||||
p_cmd = 'ssh -p {} {}@{} "{}" '.format( | |||||
self.port, self.login_name, self.ip, sub_cmd | |||||
) | |||||
logging.debug("ssh run cmd: {}".format(p_cmd)) | |||||
output = output + subprocess.check_output(p_cmd, shell=True).decode("utf-8") | |||||
return output | |||||
def get_finally_bench_resulut_from_log(raw_log) -> float: | |||||
# raw_log --> avg_time=23.331ms -->23.331ms | |||||
h = re.findall(r"avg_time=.*ms ", raw_log)[-1][9:] | |||||
# to 23.331 | |||||
h = h[: h.find("ms")] | |||||
# to float | |||||
h = float(h) | |||||
return h | |||||
def main(): | |||||
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) | |||||
parser.add_argument("--models_dir", help="models dir", required=True) | |||||
parser.add_argument( | |||||
"--load_and_run_file", help="path for load_and_run", required=True | |||||
) | |||||
args = parser.parse_args() | |||||
assert os.path.isdir( | |||||
args.models_dir | |||||
), "invalid args for models_dir, need a dir for models" | |||||
assert os.path.isfile(args.load_and_run_file), "invalid args for load_and_run_file" | |||||
for m in test_cpu_models: | |||||
assert os.path.isfile( | |||||
os.path.join(args.models_dir, m) | |||||
), "invalid args for models_dir, need put model: {} to args.models_dir".format( | |||||
test_cpu_models | |||||
) | |||||
# init device | |||||
ssh = SshConnector() | |||||
ssh.setup(device["login_name"], device["ip"], device["port"]) | |||||
# create test dir | |||||
workspace = "cpu_evaluation_workspace" | |||||
ssh.cmd(["mkdir -p {}".format(workspace)]) | |||||
# copy load_and_run_file | |||||
ssh.copy([args.load_and_run_file], workspace) | |||||
# call test | |||||
result = [] | |||||
for m in test_cpu_models: | |||||
m_path = os.path.join(args.models_dir, m) | |||||
# copy model file | |||||
ssh.copy([m_path], workspace) | |||||
# run single thread | |||||
sub_b = ["-cpu", "-multithread {}".format(device["thread_number"])] | |||||
for b in sub_b: | |||||
cmd = [] | |||||
cmd0 = "cd {} && rm -rf fastrun.cache".format(workspace) | |||||
cmd1 = "cd {} && ./load_and_run {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( | |||||
workspace, m, b | |||||
) | |||||
cmd2 = "cd {} && ./load_and_run {} {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess --record-comp-seq".format( | |||||
workspace, m, b | |||||
) | |||||
cmd.append(cmd0) | |||||
cmd.append(cmd1) | |||||
cmd.append(cmd2) | |||||
raw_log = ssh.cmd(cmd) | |||||
# logging.debug(raw_log) | |||||
ret = get_finally_bench_resulut_from_log(raw_log) | |||||
logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) | |||||
result.append(ret) | |||||
total_time = 0.0 | |||||
for r in result: | |||||
total_time += r | |||||
logging.debug("total time is: {}".format(total_time)) | |||||
score = 100000.0 / total_time * 1000 | |||||
logging.debug("device: {} score is: {}".format(device["name"], score)) | |||||
if __name__ == "__main__": | |||||
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" | |||||
DATE_FORMAT = "%Y/%m/%d %H:%M:%S" | |||||
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT) | |||||
main() |