From 6ce4a34403b9ba457a4b0f97145e0867c2ed6a97 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 14 Dec 2021 10:57:53 +0800
Subject: [PATCH] feat(dnn): add fallback postprocess

GitOrigin-RevId: 4201a0f158b1439cfa9d900c498aff1e242992ed
---
 dnn/src/common/postprocess_helper.h                |  73 ++-
 dnn/src/common/relayout_helper.h                   |   3 +-
 imperative/python/megengine/tools/README.md        | 156 +-----
 imperative/python/megengine/tools/benchmark_op.py  | 151 ------
 .../megengine/tools/dump_with_testcase_mge.py      | 535 ---------------------
 tools/cpu_evaluation_tools.py                      | 152 ++++++
 6 files changed, 212 insertions(+), 858 deletions(-)
 delete mode 100644 imperative/python/megengine/tools/benchmark_op.py
 delete mode 100755 imperative/python/megengine/tools/dump_with_testcase_mge.py
 create mode 100644 tools/cpu_evaluation_tools.py

diff --git a/dnn/src/common/postprocess_helper.h b/dnn/src/common/postprocess_helper.h
index 05de2368..5deccd2f 100644
--- a/dnn/src/common/postprocess_helper.h
+++ b/dnn/src/common/postprocess_helper.h
@@ -14,6 +14,8 @@
 #include "megdnn/basic_types.h"
 #include "midout.h"
 
+#include "src/common/conv_bias.h"
+#include "src/common/opr_delegate.h"
 #include "src/common/postprocess.h"
 
 namespace {
@@ -31,17 +33,41 @@
     MEGDNN_MARK_USED_VAR(OW);      \
     MEGDNN_MARK_USED_VAR(pack_oc_size)
 
+void to_handle_bias_and_nonlinear(
+        void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
+        megdnn::ConvBiasForward::BiasMode bias_mode,
+        megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
+        megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW) {
+    auto handle = megdnn::inplace_cpu_handle();
+    auto conv_dst_tensor_layout = megdnn::TensorLayout({N, OC, OH, OW}, dst_type);
+    auto conv_dst_tensor = megdnn::TensorND{conv_dst_ptr, conv_dst_tensor_layout};
+    auto dst_tensor = megdnn::TensorND{dst_ptr, conv_dst_tensor_layout};
+    auto bias_tensor_layout = conv_dst_tensor_layout;
+    if (megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS == bias_mode) {
+        bias_tensor_layout = megdnn::TensorLayout({1, OC, 1, 1}, bias_type);
+    } else if (megdnn::ConvBiasForward::BiasMode::NO_BIAS == bias_mode) {
+        bias_tensor_layout = megdnn::TensorLayout({}, bias_type);
+    }
+    auto bias_tensor =
+            megdnn::TensorND{const_cast<void*>(bias_ptr), bias_tensor_layout};
+    handle_bias_and_nonlinear(
+            handle.get(), nonlineMode, &conv_dst_tensor, &dst_tensor, &bias_tensor);
+}
+
 template <
         typename ctype, typename dtype = ctype,
         megdnn::PostprocessMode postprocess_mode = megdnn::PostprocessMode::FLOAT>
 struct PostProcess {
     static void run(
             void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
-            megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
-            megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC,
-            size_t OH, size_t OW, size_t pack_oc_size = 1) {
-        POST_PROCESS_UNUSED_VAR();
-        megdnn_throw("not impl PostProcess");
+            megdnn::ConvBiasForward::BiasMode bias_mode,
+            megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
+            megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
+            size_t pack_oc_size = 1) {
+        MEGDNN_MARK_USED_VAR(pack_oc_size);
+        to_handle_bias_and_nonlinear(
+                conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
+                dst_type, N, OC, OH, OW);
     }
 };
 
@@ -49,11 +75,11 @@ template <typename ctype, typename dtype>
 struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
     static void run(
             void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
-            megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
-            megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC,
-            size_t OH, size_t OW, size_t pack_oc_size = 1) {
+            megdnn::ConvBiasForward::BiasMode bias_mode,
+            megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
+            megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
+            size_t pack_oc_size = 1) {
         POST_PROCESS_UNUSED_VAR();
-        megdnn_throw("not impl PostProcess");
     }
 };
 
@@ -61,11 +87,14 @@ template <typename opctype, typename opdtype>
 struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
     static void run(
             void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
-            megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
-            megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC,
-            size_t OH, size_t OW, size_t pack_oc_size = 1) {
-        POST_PROCESS_UNUSED_VAR();
-        megdnn_throw("not impl PostProcess");
+            megdnn::ConvBiasForward::BiasMode bias_mode,
+            megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
+            megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
+            size_t pack_oc_size = 1) {
+        MEGDNN_MARK_USED_VAR(pack_oc_size);
+        to_handle_bias_and_nonlinear(
+                conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
+                dst_type, N, OC, OH, OW);
     }
 };
 
@@ -73,11 +102,17 @@ template <typename ctype, typename dtype>
 struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
     static void run(
            void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
-            megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
-            megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC,
-            size_t OH, size_t OW, size_t pack_oc_size = 1) {
-        POST_PROCESS_UNUSED_VAR();
-        megdnn_throw("not impl PostProcess");
+            megdnn::ConvBiasForward::BiasMode bias_mode,
+            megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
+            megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
+            size_t pack_oc_size = 1) {
+        MEGDNN_MARK_USED_VAR(pack_oc_size);
+        if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
+            return;
+        }
+        to_handle_bias_and_nonlinear(
+                conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
+                dst_type, N, OC, OH, OW);
     }
 };
 
diff --git a/dnn/src/common/relayout_helper.h b/dnn/src/common/relayout_helper.h
index c03f9e53..0c52f725 100644
--- a/dnn/src/common/relayout_helper.h
+++ b/dnn/src/common/relayout_helper.h
@@ -50,7 +50,8 @@ constexpr size_t BLOCK_LINE_SIZE_BYTES = 32;
 //! ref U54-MC arch
 constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
 #else
-#error "unknown megdnn arch"
+//! for fallback, keep the same value as MEGDNN_NAIVE
+constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
 #endif
 
 /**
diff --git a/imperative/python/megengine/tools/README.md b/imperative/python/megengine/tools/README.md
index 56d443f9..a2ce3de9 100644
--- a/imperative/python/megengine/tools/README.md
+++ b/imperative/python/megengine/tools/README.md
@@ -1,156 +1,8 @@
 # MegEngine Tools
 
-MegEngine 相关的工具汇总。使用方法如下（可将 `xxx` 替换成任一脚本文件，如 `network_visualize`）：
+This directory contains executable Python files.
+Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`): -```bash -python -m megengine.tools.xxx -``` - -工具列表: - -### accuracy_shake_var_tree - -将精度抖动分析结果构造成树结构,方便锁定引起抖动的根节点,以及查找依赖关系。 - -输入: compare_binary_iodump 的输出存入到的一个文件 - -输出: 第一个出现结果不一致的输出结点 - -执行命令: accuracy_shake_var_tree 中定义了一些函数组件,可按需集成到实际代码中。下面有一个测试代码: - -```python -import megengine.tools.accuracy_shake_var_tree as st - -r = st.parse('diff.txt') -for key, value in r.items(): - n = st.varNode.get_varNode(key) - n.show_src_info() - print("reference nodes:") - for i in n.get_reference_list(): - print(i.id) -``` - -### benchmark_op - -逐个运行 functional op(并不是所有的 functional op),对比 MegEngine 与 PyTorch 的性能,通过量化结果来指导如何进行下一步的优化。 - -输入: 无 - -输出: 打印一个列表,对比在小输入和大输入的情况下 MegEngine 和 Pytorch 执行一些 functional op 的速度对比 - -执行命令: `python3 -m megengine.tools.benchmark_op` - -### compare_binary_iodump - -分析同一模型在不同平台下给定相同输入之后的输出是否完全一致。 - -输入: 两个目录(假设分别为 expect/ 和 actual/),分别存有不同平台下运行的 tensor 结果 - -输出: 打印所有的输出 tensor 信息,如果某个 tensor 在两个平台上的值不一致,那么会打印出第一个不一致的值 - -执行命令: `python3 -m megengine.tools.compare_binary_iodump expect/ actual/` - -### draw_graph - -用来查看静态图的 op 序列,有助于理解 MegEngine 的静态图在动态图的基础上做了哪些优化。 - -输入: `megengine.core.tensor.megbrain_graph.Graph._to_json` 得出的静态图描述文件,为 json 格式 - -输出: 一个 dot 文件,可通过 dot 命令绘制出图片 - -执行命令: - -```bash -python3 -m megengine.tools.draw_graph -i dump.json -o dump.dot -dot -Tpng dump.dot -o dump.png -``` - -### dump_with_testcase_mge - -将待测数据提前注入模型文件,并在本地运行得到期望结果,可与实际运行的结果进行比对以检查是否出错。 - -输入: 一个 MegEngine 模型文件,可选一些 npy 文件作为模型输入(也可以随机生成输入,如下面的命令示例) - -输出: 一个带输入的 MegEngine 模型文件 - -执行命令: `python3 -m megengine.tools.dump_with_testcase_mge model.mge -d "#rand(0,255,14,2)"` - -### graph_info_analyze - -将图和内存信息的 json 文件的文件夹 logs 转换为 TensorBoard 的输入文件夹 logs_p。以便 TensorBoard 对图结构以及内存信息进行可视化。 - -输入: 图和内存信息的 json 文件的文件夹 - -输出: TensorBoard 的输入文件夹 - -执行命令: `python3 -m megengine.tools.graph_info_analyze -i logs -o logs_p` - -### load_network_and_run - -python 版本的 load_and_run。 - -输入: MegEngine 的模型文件,可选一些 npy 文件作为模型输入 - -输出: 模型执行并打印一些测速信息 - -执行命令: `python3 -m megengine.tools.load_network_and_run model.mge --iter 10` - -### network_visualize - -1. 分析给定的 MegEngine 模型中参数量信息,包括 shape、dtype、mean、std 以及 size 占比等。 -2. 分析给定的 MegEngine 模型中算子 FLOPs 计算量以及占比,还有算子的 inputs\outputs shape、感受野、stride 等。 - -输入: MegEngine 的模型文件 - -输出: 模型中的参数量信息或计算量信息 - -执行命令: - -```bash -# 分析参数量 -python3 -m megengine.tools.network_visualize model.mge --cal_params --logging_to_stdout - -# 分析计算量 -python3 -m megengine.tools.network_visualize model.mge --cal_flops --logging_to_stdout ``` - -### profile_analyze - -对于 load_and_run --profile 运行模型生成的 profile.json 文件或者 trace 模式下开启 profiling 功能并通过 trace.get_profile() 得到的 json 文件进行分析,得到静态图中算子的时间和显存占比等信息,以表格形式呈现。 - -输入: load_and_run 生成的 profile 文件 - -输出: 一个按照参数在输入文件中筛选得出的数据表格 - -执行命令: - -```bash -# 生成供分析的 json 文件 -python3 -m megengine.tools.load_network_and_run model.mge --warm-up --iter 10 --profile profile.json - -#分析耗时前 3 的单个算子 -python3 -m megengine.tools.profile_analyze profile.json -t 3 - -#筛选用时超过 10us 的 conv 按 flops 排序 -python3 -m megengine.tools.profile_analyze profile.json -t 3 --order-by +flops --min-time 1e-5 --type ConvolutionForward -``` - -### profiler - -对给定的训练程序,记录训练过程并以通用格式存储,可在浏览器上可视化。 - -输入: 需要一个 MegEngine 的训练程序(称之为 train.py,其中包含一个典型的 MegEngine 训练过程) - -输出: 一些记录 profile 过程的 json 文件,默认在 profile 子目录下,可用 https://ui.perfetto.dev/ 进行加载并且可视化 - -执行命令: `python3 -m megengine.tools.profiler train.py` - -### svg_viewer - -查看 MegEngine 生成的显存占用图,可以帮助用户了解显存使用情况. 
- -输入: 显存占用的 svg 图片 - -输出: 网页展示的可视化 - -执行命令: `python3 -m megengine.tools.svg_viewer` +python -m megengine.tools.xxx +``` \ No newline at end of file diff --git a/imperative/python/megengine/tools/benchmark_op.py b/imperative/python/megengine/tools/benchmark_op.py deleted file mode 100644 index a296198c..00000000 --- a/imperative/python/megengine/tools/benchmark_op.py +++ /dev/null @@ -1,151 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import time -import numpy as np -import megengine as mge -import megengine.module as MM -import megengine.functional as MF -import torch -import torch.nn as nn -import torch.nn.functional as TF - -from tabulate import tabulate - -module_cache = { - "conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()), - "dw_conv2d": (MM.Conv2d(32, 32, 3, 1, 0, groups=32), nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda()), - "conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()), - "ConvTranspose2d": (MM.ConvTranspose2d(32, 32, 3, 1, 0), nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda()), - "BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()), - "Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()), -} - -test_cases = [ - # (mge op, torch op, small inps, large inps, unpack_inps, rep) - ("adaptive_avg_pool2d", lambda x: MF.adaptive_avg_pool2d(x, (7,7)), lambda x: TF.adaptive_avg_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), - ("adaptive_max_pool2d", lambda x: MF.adaptive_max_pool2d(x, (7,7)), lambda x: TF.adaptive_max_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), - ("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000),], True, 1000), - ("avg_pool2d", lambda x: MF.avg_pool2d(x, 2), lambda x: TF.avg_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), - ("broadcast", lambda x: MF.broadcast_to(x, (5,) + x.shape), lambda x: torch.broadcast_to(x, (5,)+x.shape), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("batchedmatmul", MF.matmul, torch.matmul, [(8, 64, 32), (8, 32, 64)], [(8, 2048, 512), (8, 512, 2048)], True, 1000), - ("batchnrom2d", lambda x: module_cache["BatchNorm2d"][0](x), lambda x: module_cache["BatchNorm2d"][1](x), [(2, 64, 16, 16)], [(64, 64, 128, 128)], True, 1000), - ("concat", MF.concat, torch.cat, [(20, 100), (50, 100), (30, 100)], [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)], False, 1000), - ("conv2d", lambda x: module_cache["conv2d"][0](x), lambda x: module_cache["conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), - ("conv3d", lambda x: module_cache["conv3d"][0](x), lambda x: module_cache["conv3d"][1](x), [(2, 32, 8, 8, 8)], [(32, 32, 16, 16, 16)], True, 1000), - ("convTranspose2d", lambda x: module_cache["ConvTranspose2d"][0](x), lambda x: module_cache["ConvTranspose2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), - ("dropout", lambda x: MF.dropout(x, 0.5), TF.dropout, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("dw_conv2d", lambda x: module_cache["dw_conv2d"][0](x), lambda x: module_cache["dw_conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), - ("elemwise.unary", MF.log, torch.log, [(100,100)], [(64, 512, 16, 16)], True, 
1000), - ("elemwise.binary", MF.add, torch.add, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], True, 1000), - ("expand_dims", lambda x: MF.expand_dims(x, 0), lambda x: torch.unsqueeze(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("gelu", MF.gelu, TF.gelu, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("hswish", MF.hswish, TF.hardswish, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("hsigmoid", MF.hsigmoid, TF.hardsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("isinf", MF.isinf, torch.isinf, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("indeixngMultiAxisVec", lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], [(10,10,10,10)], [(64, 512, 16, 16)], True, 1000), - ("logsigmoid", MF.logsigmoid, TF.logsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("leaky_relu", lambda x: MF.leaky_relu(x, 0.5), lambda x: TF.leaky_relu(x, 0.5), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("linear", lambda x: module_cache["Linear"][0](x), lambda x: module_cache["Linear"][1](x), [(10, 1000)], [(64, 128, 1000)], True, 1000), - ("matinv", MF.matinv, torch.inverse, [(10,10)], [(30, 30)], True, 1000), - ("matmul", MF.matmul, torch.matmul, [(64,32), (32, 64)], [(2048, 1024), (1024, 2048)], True, 1000), - ("max_pool2d", lambda x: MF.max_pool2d(x, 2), lambda x: TF.max_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), - ("normal", lambda x: mge.random.normal(0,1, x.shape), lambda x: torch.randn(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("prelu", MF.prelu, TF.prelu, [(100,100), (1,)], [(64, 512, 16, 16), (1,)], True, 1000), - ("reduce.max", lambda x: MF.max(x, 0), lambda x: torch.max(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("relu", MF.relu, TF.relu, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("relu6", MF.relu6, TF.relu6, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("repeat", lambda x: MF.repeat(x, 5), lambda x: torch.repeat_interleave(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("silu", MF.silu, TF.silu, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("split", lambda x: MF.split(x, 5), lambda x: torch.split(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("sigmoid", MF.sigmoid, TF.sigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("softmax", lambda x: MF.softmax(x, axis=1), lambda x: TF.softmax(x, dim=1), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("softplus", MF.softplus, TF.softplus, [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("squeeze", lambda x: MF.squeeze(x, 0), lambda x: torch.squeeze(x, 0), [(1, 100,100)], [(1, 64, 512, 16, 16)], True, 1000), - ("stack", MF.stack, torch.stack, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], False, 10000), - ("subtensor", lambda x: x[0:20, 10:60], lambda x: x[0:20, 10:60], [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("topk", lambda x: MF.topk(x, 10), lambda x: torch.topk(x, 10), [(100,100)], [(1000, 1000)], True, 1000), - ("tile", lambda x: MF.tile(x, (2,)*len(x.shape)), lambda x: torch.tile(x, (2,)*len(x.shape)), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("transpose", lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]), lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]), 
[(100,100)], [(64, 512, 16, 16)], True, 1000), - ("where", lambda x: MF.where(x > 0.5, x, x), lambda x: torch.where(x > 0.5, x, x), [(100,100)], [(64, 512, 16, 16)], True, 1000), - ("uniform", lambda x: mge.random.uniform(0,1, x.shape), lambda x: torch.rand(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), -] - - -def perf_func(func, inps, reps, unpack_inps, is_mge): - if is_mge: - mge._full_sync() - tik = time.time() - for _ in range(reps): - if unpack_inps: - out = func(*inps) - else: - out = func(inps) - mge._full_sync() - else: - torch.cuda.synchronize() - with torch.no_grad(): - tik = time.time() - for _ in range(reps): - if unpack_inps: - out = func(*inps) - else: - out = func(inps) - torch.cuda.synchronize() - return time.time() - tik - - -def get_avg_time(func, inps, reps, unpack_inps, is_mge): - # warm up - for _ in range(2): - t = perf_func(func, inps, reps, unpack_inps, is_mge) - - times = [] - for _ in range(5): - t = perf_func(func, inps, reps, unpack_inps, is_mge) - times.append(t) - return np.mean(times) - - -def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps): - inps = [ - np.random.randn(*shape) for shape in shapes - ] - - inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps] - avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True) - - inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps] - avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False) - - return avg_time_mge, avg_time_torch - - -if __name__ == "__main__": - header = ["opr_name", "time(mge/pytorch; small input)", "time(mge/pytorch; large input)"] - table = [] - for case in test_cases: - assert len(case) == 7 - name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case - data = [] - data.append(name) - print("========== op: {}".format(name)) - - avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, small_shapes, unpack_inps, reps) - print("mge time: {}".format(avg_time_mge)) - print("torch time: {}".format(avg_time_torch)) - data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) - - avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, large_shapes, unpack_inps, reps) - print("mge time: {}".format(avg_time_mge)) - print("torch time: {}".format(avg_time_torch)) - data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) - table.append(data) - print(tabulate(table, header, tablefmt="github")) \ No newline at end of file diff --git a/imperative/python/megengine/tools/dump_with_testcase_mge.py b/imperative/python/megengine/tools/dump_with_testcase_mge.py deleted file mode 100755 index 6f325edb..00000000 --- a/imperative/python/megengine/tools/dump_with_testcase_mge.py +++ /dev/null @@ -1,535 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-import argparse -import os -import re -import struct - -import cv2 -import numpy as np - -import megengine as mge -import megengine.core._imperative_rt as rt -import megengine.core.tensor.megbrain_graph as G -from megengine import tensor -from megengine.core.ops import builtin -from megengine.utils import comp_graph_tools as cgtools - -logger = mge.get_logger(__name__) - - -def auto_reformat_image(args, path, data, dst_shape): - """reformat image to target shape - - :param data: image data as numpy array - :param dst_shape: target shape - """ - dim3_format = False # required input format does not contain batch - hwc_format = False # required input format is NHWC - - if not dst_shape: # input tensor shape is not predefined - if len(data.shape) == 2: - chl = 1 - h = data.shape[0] - w = data.shape[1] - else: - assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" - h, w, chl = data.shape - dst_shape = (1, chl, h, w) - - if len(dst_shape) == 3: - dst_shape = (1,) + dst_shape - dim3_format = True - - assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) - chl = dst_shape[1] - if chl in [1, 3]: - n, c, h, w = dst_shape - dst_shape = (n, h, w, c) - else: - chl = dst_shape[3] - assert chl in [1, 3], "can not infer input format from shape: {}".format( - dst_shape - ) - hwc_format = True - - # dst_shape has now been normalized to NHWC format - - if args.resize_input: - h, w = dst_shape[1:3] - data = cv2.resize(data, (w, h)) - logger.info("input {} resized to {}".format(path, data.shape)) - - if chl == 1: - data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) - data = data[:, :, np.newaxis] - - assert data.ndim == 3 - data = data[np.newaxis] - # data normalized to NHWC format - - if not hwc_format: - data = np.transpose(data, (0, 3, 1, 2)) - - if dim3_format: - data = np.squeeze(data, 0) - - return data - - -def read_input_data(args, dst_shape, dtype, path, repeat): - def check_shape_equal(dst_shape, data_shape): - if len(dst_shape): - assert len(data_shape) == len( - dst_shape - ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) - - if data_shape[1:] != dst_shape[1:]: - logger.warning( - "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) - ) - - if path.startswith("#"): - assert not args.resize_input - assert not args.input_transform - spec = path - m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) - assert m, "bad spec {}".format(spec) - - rng_min = float(m.group(1)) - rng_max = float(m.group(2)) - if m.group(3): - shape_str = m.group(3) - try: - shape = shape_str[1:].split(",") - if shape[-1].strip() == "...": - shape = shape[:-1] - shape.extend(list(dst_shape[len(shape) :])) - data_shape = tuple(map(int, shape)) - except ValueError as e: - raise ValueError("bad spec {}: {}".format(spec, e.args)) - else: - data_shape = dst_shape - - check_shape_equal(dst_shape, data_shape) - return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) - - # try to load image - data = cv2.imread(path, cv2.IMREAD_COLOR) - if data is None: - assert not args.resize_input - data = np.load(path) - assert isinstance(data, np.ndarray) - else: - # load image succeeds, so we expect input format is image format - data = auto_reformat_image(args, path, data, dst_shape) - - data = np.repeat(data, repeat, axis=0) - if repeat > 1: - logger.info( - "repeat input for {} times, data shape is {}".format(repeat, data.shape) - ) - - check_shape_equal(dst_shape, data.shape) - - if args.input_transform: - data = eval(args.input_transform, 
{"data": data, "np": np}) - - return data - - -def gen_one_testcase(args, inputs, spec): - paths = spec.split(";") - if len(paths) != len(inputs): - if len(paths) == 1 and paths[0].startswith("#"): - paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] - assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( - inputs.keys(), paths - ) - if len(paths) == 1 and ":" not in paths[0]: - paths[0] = next(iter(inputs.keys())) + ":" + paths[0] - - ret = {} - for path in paths: - var, path = path.split(":") - if args.repeat: - repeat = args.repeat - else: - repeat = 1 - ret[var] = read_input_data( - args, inputs[var].shape, inputs[var].dtype, path, repeat - ) - return ret - - -def make_feeds(args): - ret = G.load_graph(args.input) - cg_rt, outputs = ret.graph, ret.output_vars_list - inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") - - inputs = {i.name: i for i in inputs} - if not args.no_assert: - - replace_varmap = {} - inp_map = {} - # replace var use InputNode - for name, var in inputs.items(): - inp = G.InputNode( - device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt - ) - replace_varmap[var] = inp.outputs[0] - inp_map[name] = inp - - new = cgtools.replace_vars(outputs, replace_varmap) - if isinstance(new, rt.VarNode): - new = list(new) - - output_nodes = [G.OutputNode(var) for var in new] - func = cg_rt.compile([node.outputs[0] for node in output_nodes]) - - def make_dev_tensor(value, dtype=None, device=None): - return tensor(value, dtype=dtype, device=device)._dev_tensor() - - def calculate(*args, **kwargs): - output_val = [] - # set inputs value - for name, var in inputs.items(): - val = kwargs.pop(name, None) - assert val is not None, "miss input name{}".format(name) - dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") - inp_map[name].set_value(dev_tensor) - - func.execute() - - for res in output_nodes: - output_val.append(res.get_value().numpy()) - return output_val - - def expect_name(var): - return "{}:expect".format(var.name) - - testcases = [] - - np.set_printoptions(precision=2, threshold=4, suppress=True) - - data_list = [] - for item in args.data: - if item.startswith("@"): - with open(item[1:], "r") as f: - data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) - else: - data_list.append(item) - - for inp_spec in data_list: - cur_testcase = gen_one_testcase(args, inputs, inp_spec) - assert len(cur_testcase) == len( - inputs - ), "required inputs: {}; given data: {}".format( - inputs.keys(), cur_testcase.keys() - ) - - if not args.no_assert: - outputs_get = calculate(**cur_testcase) - for var, val in zip(outputs, outputs_get): - cur_testcase[expect_name(var)] = val - logger.info( - "generate test groundtruth: var={} shape={} range=({}, {})" - " mean={} var={}".format( - var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) - ) - ) - testcases.append(cur_testcase) - logger.info( - "add testcase: \n {}".format( - "\n ".join( - "{}: shape={} dtype={} range=({:.2f},{:.2f}) " - "mean={:.2f} sd={:.2f}".format( - k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) - ) - for k, v in sorted(cur_testcase.items()) - ) - ) - ) - - if not args.no_assert: - - def expect_shp(var): - ret = var.shape - if ret: - return ret - return testcases[0][expect_name(var)].shape - - def assert_equal(expect, real, **kwargs): - op = builtin.AssertEqual(**kwargs) - (res,) = G.apply_normal_varnode(op, expect, real) - return res - - verbose = not args.silent - - outputs_new = [] - for i in outputs: - device = 
rt.CompNode("xpux") - dtype = i.dtype - name = expect_name(i) - shape = expect_shp(i) - # make expect output as one input of model. - expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) - # insert assert opr to check expect and real. - outputs_new.append( - assert_equal( - expect_get, - i, - verbose=verbose, - maxerr=args.maxerr, - ) - ) - inputs[expect_name(i)] = expect_get - outputs = outputs_new - - return {"outputs": outputs, "testcases": testcases} - - -def optimize_for_inference(args, outputs): - args_list = [ - "enable_io16xc32", - "enable_ioc16", - "enable_hwcd4", - "enable_nchw4", - "enable_nchw88", - "enable_nchw44", - "enable_nchw44_dot", - "enable_nchw32", - "enable_chwn4", - "enable_fuse_conv_bias_nonlinearity", - "enable_fuse_conv_bias_with_z", - "enable_fuse_preprocess", - ] - kwargs = {} - for k in args_list: - if getattr(args, k): - kwargs[k] = True - - if args.optimize_for_inference: - outputs = G.optimize_for_inference(outputs, **kwargs) - - return outputs - - -def main(): - parser = argparse.ArgumentParser( - description="Pack computing graph, input values and expected output " - "values into one file for checking correctness. README.md gives more " - "details on the usage", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("input", help="MegEngine dumped model file") - parser.add_argument("-o", "--output", help="output file", required=True) - parser.add_argument( - "-d", - "--data", - default=[], - action="append", - required=True, - help="Given input test data when input file is a network, " - "and current network output would be used as groundtruth. " - "The format is var0:file0;var1:file1... to specify data files for " - "input vars. It can also be #rand(min,max,shape...) for generating " - "random input data, for example, #rand(0,255), " - "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " - "the remaining part of the original shape. " - "If the shape is not specified, the shape of " - "corresponding input tensors in the network will be used. " - "If there is only one input var, its name can be omitted. " - "Each data file can either be an image which can be loaded by opencv, " - "or a pickled numpy.ndarray. " - "This option can be given multiple times to add multiple testcases. " - " *NOTE* " - "If you start the data with the letter @, the rest should be a " - "filename, and each line in the file should be a single datum in " - "the format described above. ", - ) - parser.add_argument( - "--repeat", - type=int, - default=1, - help="Specify how many times the input image is repeated. " - "Useful when running benchmark for batch size other than one. " - "Have no effect on randomly generated input data.", - ) - parser.add_argument( - "--silent", - action="store_true", - help="set verbose to False in asserti_equal opr", - ) - parser.add_argument( - "--optimize-for-inference", - action="store_true", - help="enable optimization for inference", - ) - parser.add_argument( - "--no-assert", - action="store_true", - help="do not insert assert_equal opr to check result; " - "this option is useful for benchmarking", - ) - parser.add_argument( - "--maxerr", - type=float, - default=1e-4, - help="max error for assert_equal check during runtime", - ) - parser.add_argument( - "--resize-input", - action="store_true", - help="resize input image to fit input var shape", - ) - parser.add_argument( - "--input-transform", - help="a python expression to transform the input data. 
" - "Example: data / np.std(data)", - ) - parser.add_argument( - "--discard-var-name", - action="store_true", - help="discard variable and param names in the " "generated output", - ) - parser.add_argument( - "--output-strip-info", action="store_true", help="output code strip information" - ) - parser.add_argument( - "--enable-io16xc32", - action="store_true", - help="transform the mode to float16 io float32 compute", - ) - parser.add_argument( - "--enable-ioc16", - action="store_true", - help="transform the dtype of the model to float16 io " "and compute", - ) - parser.add_argument( - "--enable-fuse-conv-bias-nonlinearity", - action="store_true", - help="fuse convolution bias and nonlinearity opr to a " - "conv_bias opr and compute", - ) - parser.add_argument( - "--enable-hwcd4", - action="store_true", - help="transform the model format from NCHW to NHWCD4 " - "for inference; you may need to disable CUDA and set " - "MGB_USE_MEGDNN_DBG=2", - ) - parser.add_argument( - "--enable-nchw4", - action="store_true", - help="transform the model format from NCHW to NCHW4 " "for inference", - ) - parser.add_argument( - "--enable-nchw88", - action="store_true", - help="transform the model format from NCHW to NCHW88 " "for inference", - ) - parser.add_argument( - "--enable-nchw44", - action="store_true", - help="transform the model format from NCHW to NCHW44 " "for inference", - ) - parser.add_argument( - "--enable-nchw44-dot", - action="store_true", - help="transform the model format from NCHW to NCHW44_DOT " - "for optimizing armv8.2 dot in inference", - ) - parser.add_argument( - "--enable-nchw32", - action="store_true", - help="transform the model format from NCHW4 to NCHW32 " - "for inference on nvidia TensoCore", - ) - parser.add_argument( - "--enable-chwn4", - action="store_true", - help="transform the model format to CHWN4 " - "for inference, mainly used for nvidia tensorcore", - ) - parser.add_argument( - "--enable-fuse-conv-bias-with-z", - action="store_true", - help="fuse conv_bias with z input for inference on " - "nvidia GPU (this optimization pass will result in mismatch " - "of the precision of output of training and inference)", - ) - parser.add_argument( - "--enable-fuse-preprocess", - action="store_true", - help="fuse astype\pad_channel\dimshuffle and etc opr " - "from h2d opr", - ) - args = parser.parse_args() - - feeds = make_feeds(args) - - assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" - - output_mgbvars = feeds["outputs"] - output_mgbvars = optimize_for_inference(args, output_mgbvars) - - inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") - inputs = sorted((i.name, i.dtype) for i in inputs) - - if args.discard_var_name: - sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) - else: - sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) - - strip_info_file = args.output + ".json" if args.output_strip_info else None - - with open(args.output, "wb") as fout: - fout.write(b"mgbtest0") - fout.write(struct.pack("I", len(feeds["testcases"]))) - dump_content, stat = G.dump_graph( - output_mgbvars, - append_json=True, - strip_info_file=strip_info_file, - **sereg_kwargs, - ) - fout.write(dump_content) - - logger.info( - "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( - stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 - ) - ) - - def make_dev_tensor(value, dtype=None, device=None): - return tensor(value, dtype=dtype, device=device)._dev_tensor() - - for testcase in feeds["testcases"]: 
-        assert isinstance(testcase, dict)
-        cg = G.Graph()
-        output_mgbvars = []
-        for name, dtype in inputs:
-            output_mgbvars.append(
-                cg.make_const(
-                    make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux")
-                )
-            )
-        assert not testcase, "extra inputs provided in testcase: {}".format(
-            testcase.keys()
-        )
-        with open(args.output, "ab") as fout:
-            dump_content, _ = G.dump_graph(
-                output_mgbvars, strip_info_file=strip_info_file, append_json=True
-            )
-            fout.write(dump_content)
-
-
-if __name__ == "__main__":
-    main()
-
diff --git a/tools/cpu_evaluation_tools.py b/tools/cpu_evaluation_tools.py
new file mode 100644
index 00000000..bd5277d7
--- /dev/null
+++ b/tools/cpu_evaluation_tools.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python3
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+"""
+purpose: measure CPU performance by running several basic models with load_and_run.
+how to use: run python3 cpu_evaluation_tools.py --help for more details; two arguments are required:
+    --load_and_run_file: path of the load_and_run binary; please refer to ../scripts/cmake-build/BUILD_README.md to build it.
+    --models_dir: path of the directory that holds the test models.
+how to configure the test device: edit the device dict below (name/login_name/ip/port/thread_number).
+"""
+import argparse
+import logging
+import os
+import re
+import subprocess
+
+# test device
+device = {
+    "name": "hwmt40p",
+    "login_name": "hwmt40p-K9000-maliG78",
+    "ip": "box86.br.megvii-inc.com",
+    "port": 2200,
+    "thread_number": 3,
+}
+# test models
+test_cpu_models = [
+    "inceptionv2",
+    "mobilenetv1",
+    "mobilenetv2",
+    "resnet18",
+    "resnet50",
+    "shufflenetv2",
+    "vgg16",
+]
+
+
+class SshConnector:
+    """ssh control master connector implementation"""
+
+    ip = None
+    port = None
+    login_name = None
+
+    def setup(self, login_name, ip, port):
+        self.ip = ip
+        self.login_name = login_name
+        self.port = port
+
+    def copy(self, src_list, dst_dir):
+        assert isinstance(src_list, list), "code issue happened!!"
+        assert isinstance(dst_dir, str), "code issue happened!!"
+        for src in src_list:
+            cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format(
+                self.port, src, self.login_name, self.ip, dst_dir
+            )
+            logging.debug("ssh run cmd: {}".format(cmd))
+            subprocess.check_call(cmd, shell=True)
+
+    def cmd(self, cmd):
+        output = ""
+        assert isinstance(cmd, list), "code issue happened!!"
+        for sub_cmd in cmd:
+            p_cmd = 'ssh -p {} {}@{} "{}" '.format(
+                self.port, self.login_name, self.ip, sub_cmd
+            )
+            logging.debug("ssh run cmd: {}".format(p_cmd))
+            output = output + subprocess.check_output(p_cmd, shell=True).decode("utf-8")
+        return output
+
+
+def get_final_bench_result_from_log(raw_log) -> float:
+    # raw_log --> "avg_time=23.331ms " --> "23.331ms "
+    h = re.findall(r"avg_time=.*ms ", raw_log)[-1][9:]
+    # to "23.331"
+    h = h[: h.find("ms")]
+    # to float
+    h = float(h)
+    return h
+
+
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+    parser.add_argument("--models_dir", help="models dir", required=True)
+    parser.add_argument(
+        "--load_and_run_file", help="path for load_and_run", required=True
+    )
+    args = parser.parse_args()
+    assert os.path.isdir(
+        args.models_dir
+    ), "invalid args for models_dir, need a dir for models"
+    assert os.path.isfile(args.load_and_run_file), "invalid args for load_and_run_file"
+    for m in test_cpu_models:
+        assert os.path.isfile(
+            os.path.join(args.models_dir, m)
+        ), "invalid args for models_dir, need to put model: {} into models_dir".format(
+            m
+        )
+
+    # init device
+    ssh = SshConnector()
+    ssh.setup(device["login_name"], device["ip"], device["port"])
+    # create test dir
+    workspace = "cpu_evaluation_workspace"
+    ssh.cmd(["mkdir -p {}".format(workspace)])
+    # copy load_and_run_file
+    ssh.copy([args.load_and_run_file], workspace)
+    # call test
+    result = []
+    for m in test_cpu_models:
+        m_path = os.path.join(args.models_dir, m)
+        # copy model file
+        ssh.copy([m_path], workspace)
+        # run with the single-thread (-cpu) and multi-thread backends
+        sub_b = ["-cpu", "-multithread {}".format(device["thread_number"])]
+        for b in sub_b:
+            cmd = []
+            cmd0 = "cd {} && rm -rf fastrun.cache".format(workspace)
+            cmd1 = "cd {} && ./load_and_run {} {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format(
+                workspace, m, b
+            )
+            cmd2 = "cd {} && ./load_and_run {} {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess --record-comp-seq".format(
+                workspace, m, b
+            )
+            cmd.append(cmd0)
+            cmd.append(cmd1)
+            cmd.append(cmd2)
+            raw_log = ssh.cmd(cmd)
+            # logging.debug(raw_log)
+            ret = get_final_bench_result_from_log(raw_log)
+            logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret))
+            result.append(ret)
+
+    total_time = 0.0
+    for r in result:
+        total_time += r
+    logging.debug("total time is: {}".format(total_time))
+    score = 100000.0 / total_time * 1000
+    logging.debug("device: {} score is: {}".format(device["name"], score))
+
+
+if __name__ == "__main__":
+    LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
+    DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
+    logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
+
+    main()
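
For reference, a minimal standalone sketch of the `avg_time` parsing step that `get_final_bench_result_from_log` performs on load_and_run output. The log fragment below is illustrative only (the `23.331` figure is taken from the comment in the code above, the surrounding text is made up), and the capture-group regex is an assumed equivalent of the slicing-based expression used in the tool, not the exact one:

```python
import re


def parse_avg_time_ms(raw_log: str) -> float:
    """Return the last avg_time value (in ms) reported in a load_and_run log."""
    # Capture the numeric part of fragments like "avg_time=23.331ms",
    # so no manual slicing of the "avg_time=" prefix and "ms" suffix is needed.
    matches = re.findall(r"avg_time=([0-9.]+)ms", raw_log)
    assert matches, "no avg_time found in log"
    return float(matches[-1])


# Illustrative fragment only; real load_and_run output contains more fields.
fake_log = "iter 20: avg_time=23.331ms \n"
print(parse_avg_time_ms(fake_log))  # 23.331
```

With this form a malformed log fails with a clear assertion message instead of an opaque IndexError from the `[-1]` lookup.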