From 4adba378675f1db5e359713c8ca5ecec67450d74 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 3 Mar 2022 15:27:37 +0800
Subject: [PATCH] feat(lite): add example scripts and some small changes for lar

GitOrigin-RevId: a28ed2f27a8d9fe8fb1ed2960477aeffd35f1857
---
 lite/load_and_run/examples/README.md                |  15 +
 lite/load_and_run/examples/example.sh               | 437 +++++++++++++++++++++
 lite/load_and_run/examples/script/add_demo.py       |  57 +++
 lite/load_and_run/examples/script/conv_demo.py      |  42 ++
 lite/load_and_run/examples/script/dump_trt.py       |  37 ++
 lite/load_and_run/examples/script/gen_trt_model.sh  |  33 ++
 .../load_and_run/examples/script/mge_input_data.py  |  33 ++
 lite/load_and_run/examples/script/resnet50_mgb.py   | 153 ++++++++
 lite/load_and_run/examples/script/resnet50_mge.py   | 351 +++++++++++++++++
 lite/load_and_run/src/options/device_options.cpp    |   4 +-
 .../src/options/extern_c_opr_options.cpp            |   4 +-
 lite/load_and_run/src/options/io_options.cpp        |   9 +-
 lite/load_and_run/src/options/optimize_options.cpp  |  17 -
 lite/load_and_run/src/options/plugin_options.cpp    |  54 ++-
 .../load_and_run/src/strategys/strategy_normal.cpp  |   9 +-
 15 files changed, 1223 insertions(+), 32 deletions(-)
 create mode 100644 lite/load_and_run/examples/README.md
 create mode 100755 lite/load_and_run/examples/example.sh
 create mode 100755 lite/load_and_run/examples/script/add_demo.py
 create mode 100755 lite/load_and_run/examples/script/conv_demo.py
 create mode 100644 lite/load_and_run/examples/script/dump_trt.py
 create mode 100755 lite/load_and_run/examples/script/gen_trt_model.sh
 create mode 100755 lite/load_and_run/examples/script/mge_input_data.py
 create mode 100755 lite/load_and_run/examples/script/resnet50_mgb.py
 create mode 100755 lite/load_and_run/examples/script/resnet50_mge.py

diff --git a/lite/load_and_run/examples/README.md b/lite/load_and_run/examples/README.md
new file mode 100644
index 00000000..080769c5
--- /dev/null
+++ b/lite/load_and_run/examples/README.md
@@ -0,0 +1,15 @@
+# load and run example scripts
+
+These are some typical examples for using load_and_run. You can use
+
+```bash
+./example.sh
+```
+
+to run those examples; more details are in the script itself.
+
+The usage of the script can be seen by running it with:
+
+```bash
+./example.sh -h
+```
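+
+For example, you can prepare the example models and build load_and_run first,
+and then run a single group of examples:
+
+```bash
+./example.sh -p            # prepare the example models and input data
+./example.sh -b            # build load_and_run for x86/CUDA and armv7/arm64
+./example.sh -e fast_run   # run only the fast-run examples
+```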
+ echo "available args detail:" + echo "-p : prepare example model " + echo "-b : build load_and_run for x86/armv7/arm64 cpu and CUDA" + echo "-t : set the ssh arm device " + echo "-w : set the arm device workspace dir" + echo "-c : clean all" + echo "-a : run all test" + echo "-e : set the running target for test (details use \"-e\" to see)" + echo "-h : show usage" + exit -1 +} + + +while getopts "pbcahe:w:t:" arg +do + case $arg in + t) + DEVICE_DESC=$OPTARG + RUN_ARM_DEVICE="true" + echo "config arm device DEVICE_DESC to ${DEVICE_DESC}" + ;; + w) + WORK_DIR_PATH=$OPTARG + echo "config arm device WORK_DIR_PATH to ${WORK_DIR_PATH}" + ;; + e) + tmp_target=null + for target in ${RUN_TARGETS[@]}; do + if [ "$target" = "$OPTARG" ]; then + echo "CONFIG BUILD RUN_TARGET to : $OPTARG" + tmp_target=$OPTARG + RUN_TARGET=$OPTARG + break + fi + done + if [ "$tmp_target" = "null" ]; then + echo "ERR args for target (-e)" + echo "available target usage :" + for target in ${RUN_TARGETS[@]}; do + echo " -e $target" + done + exit -1 + fi + ;; + h) + echo "show usage" + usage + ;; + a) + echo "config RUN_TARGET=all" + RUN_TARGET="all" + ;; + c) + echo "clean all directory generated by script" + CLEAN_ALL="true" + + ;; + b) + echo "run build" + ONLY_BUILD="true" + ;; + p) + echo "prepare model and input" + ONLY_PREPARE_MODEL="true" + ;; + ?) + echo "unkonw argument" + usage + ;; + esac +done + + +function prepare_model_and_data(){ + rm -rf model_source && mkdir model_source + # dump mgb model + python3 script/resnet50_mgb.py -o model_source/resnet50.pkl + ../dump_with_testcase.py model_source/resnet50.pkl -o model_source/resnet50_with_data.mgb -d "#rand(0, 255)" --no-assert + + # prepare simple add model + python3 script/add_demo.py --dir model_source + python3 script/conv_demo.py --dir model_source + + #generate trt model + script/gen_trt_model.sh + + #prepare mge model + python3 script/resnet50_mge.py --dir model_source + python3 script/resnet50_mge.py --dir model_source -d int8 + python3 script/resnet50_mge.py --dir model_source --inputs "#rand(0,255)" + + #make input_data + rm -rf input_data && mkdir input_data + python3 script/mge_input_data.py + + rm -rf tmpdir && mkdir tmpdir +} + +function build_lar(){ + # build cpu and cuda version + ../../../scripts/cmake-build/host_build.sh -r -t -e load_and_run + #WARNING:config the cuda environment before compile + ../../../scripts/cmake-build/host_build.sh -c -t -e load_and_run + + # # build arm version + + ../../../scripts/cmake-build/cross_build_android_arm_inference.sh -r -a arm64-v8a -e load_and_run + ../../../scripts/cmake-build/cross_build_android_arm_inference.sh -r -a armeabi-v7a -e load_and_run + + + # link or for test + ln -s ../../../build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_OFF/Release//build/lite/load_and_run/load_and_run lar_cpu + ln -s ../../../build_dir/host/MGE_WITH_CUDA_ON/MGE_INFERENCE_ONLY_OFF/Release//build/lite/load_and_run/load_and_run lar_cuda + cp ../../../build_dir/android/arm64-v8a/Release/build/lite/load_and_run/load_and_run ./lar_arm64 + cp ../../../build_dir/android/armeabi-v7a/Release/build/lite/load_and_run/load_and_run ./lar_armv7 + + + +} + +function set_arm_device_and_upload(){ + DEVICE_DESC="${1}" + WORK_DIR_PATH="${2}" + RUN_ARM_DEVICE="${3}" + cmd="rsync -aP -zz ./lar_arm64 ./lar_armv7 model_source/resnet50_b1_float32_without_data.mge input_data/resnet50_input.npy $DEVICE_DESC:$WORK_DIR_PATH/" + echo $cmd + bash -c "$cmd" +} + +function test_different_model(){ + CmdArray=("./lar_cpu 
model_source/resnet50_with_data.mgb") + CmdArray+=("./lar_cpu model_source/resnet50_b1_float32_with_data.mge") + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD") + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --lite") + for cmd in ${CmdArray[@]}; do + echo "$cmd" + bash -c "$cmd" + done +} + +function test_different_device(){ + #dispatch时,计算任务会加入一个工作队列,由队列统一管理执行 均值 131.278 ms 标准差 15.197ms m_asyc_exec异步执行 + CmdArray=("./lar_cpu $TYPICAL_MODEL_DATD --cpu") + #dispatch时,计算任务直接执行 均值 131.875 ms 标准差 7.758ms m_asyc_exec同步执行 + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --cpu-default") + #多线程运行 1~8平均运行时间(ms):129.611, 84.266, 76.963, 55.212, 69.283, 77.338, 58.386, 64.585 + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --multithread 4") + #主线程锁核,其他任务在线程池中的线程上运行 132.614, 83.095, 69.792, 54.452, 48.890, 48.206, 46.386, 53.908 + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --multithread-default 4") + #cpu多线程绑核(x86上绑核影响不大) + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --multithread 2 --multi-thread-core-ids 1,5") + + #xpu 设置为cpu上运行 132.740 ms comp_node:cpu + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --cpu") + #xpu 设置为cuda上运行 6.495 ms comp_node:gpu + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --cuda") + + for cmd in ${CmdArray[@]}; do + echo $cmd + bash -c "$cmd" + done + +function test_fast_run(){ + + CmdArray=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run") + + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --full-run") + #fast run 搜参 + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --fast-run-algo-policy tmpdir/algo_cache_file") + + #fast run 带参执行 + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run-algo-policy tmpdir/algo_cache_file") + + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --reproducible") + + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --fast-run-shared-batch-size 1") + + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --binary-equal-between-batch") + + for cmd in ${CmdArray[@]}; do + echo $cmd + bash -c "$cmd" + done +} + +function test_io(){ + rm -rf tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_out_info_cuda tmpdir/io_info.txt + mkdir tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_io_info_cuda + + CmdArray=("./lar_cpu $TYPICAL_MODEL_DATD") + + CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"input_data/add_demo_input.json\"") + #! 
+function test_io(){
+    rm -rf tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_io_info_cuda tmpdir/io_info.txt
+    mkdir tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_io_info_cuda
+
+    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATD")
+
+    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"input_data/add_demo_input.json\"")
+
+    #! the model must support input with nhwc shape
+    CmdArray+=("./lar_cpu model_source/resnet50_b1_int8_without_data.mge --input \"data:input_data/cat.ppm\"")
+
+    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\"")
+
+    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump tmpdir/io_info.txt --iter 1 --warmup-iter 0")
+
+    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump-stdout --iter 1 --warmup-iter 0")
+
+    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump-stderr --iter 1 --warmup-iter 0")
+
+    # each var is dumped to a separate file in the given directory; the file
+    # name is the var id, which matches the ids in the text io-dump information
+    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-io-dump tmpdir/bin_io_info --iter 1 --warmup-iter 0")
+
+    CmdArray+=("./lar_cuda model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-io-dump tmpdir/bin_io_info_cuda --iter 1 --warmup-iter 0")
+
+    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-out-dump tmpdir/bin_out_info --iter 1 --warmup-iter 0")
+
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --copy-to-host")
+
+    for cmd in ${CmdArray[@]}; do
+        echo $cmd
+        bash -c "$cmd"
+    done
+
+    # compare the binary io information dumped on cpu and cuda
+    python3 ../../../imperative/python/megengine/tools/compare_binary_iodump.py tmpdir/bin_io_info tmpdir/bin_io_info_cuda
+}
+
+function test_layout_related(){
+    # very little speedup on this model
+    CmdArray=("./lar_cuda $TYPICAL_MODEL_DATD --enable-nchw4")
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --enable-chwn4")
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --enable-nchw32")
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --enable-nchw64")
+
+    # speedup
+    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --enable-nchw88")
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --cuda --layout-transform cuda")
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --cuda --layout-transform cuda --layout-transform-dump model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cuda.mge")
+    CmdArray+=("./lar_cuda model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cuda.mge")
+
+    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --cpu --layout-transform cpu")
+    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --cpu --layout-transform cpu --layout-transform-dump model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cpu.mge")
+    CmdArray+=("./lar_cpu model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cpu.mge")
+
+    for cmd in ${CmdArray[@]}; do
+        echo $cmd
+        bash -c "$cmd"
+    done
+
+    if [ ${RUN_ARM_DEVICE} == "true" ] ; then
+        # speedup
+        CmdArray=("./lar_arm64 resnet50_b1_float32_without_data.mge --input \"data:resnet50_input.npy\" --cpu --enable-nchw44")
+
+        # speedup
+        CmdArray+=("./lar_arm64 resnet50_b1_float32_without_data.mge --input \"data:resnet50_input.npy\" --cpu --enable-nchw44-dot")
+
+        for cmd in ${CmdArray[@]}; do
+            echo $cmd
+            ssh -t $DEVICE_DESC "unset LD_PRELOAD && cd $WORK_DIR_PATH && LD_LIBRARY_PATH=./ $cmd"
+        done
+    else
+        echo "SET arm device ON : $RUN_ARM_DEVICE"
+    fi
+}
+
+function test_optimize(){
+    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATD --enable-fuse-preprocess")
+    # speeds up after warm-up
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --enable-fuse-conv-bias-nonlinearity")
+
+    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --enable-fuse-conv-bias-with-z")
+
CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt") + + CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt --tensorrt-cache tmpdir/TRT_cache") + + CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt-cache tmpdir/TRT_cache") + + + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --no-sanity-check --record-comp-seq2") + + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --disable_mem_opt") + + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --workspace_limit 10000") + + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --fake-first") + + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --enable_jit ") + + for cmd in ${CmdArray[@]}; do + echo $cmd + bash -c "$cmd" + done + +} + +function test_plugin(){ + + rm -rf tmpdir/staticMemInfoDir tmpdir/staticMemInfoDirLogs + mkdir tmpdir/staticMemInfoDir + + CmdArray=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --check-dispatch") + + CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --check-var-value 5:0") + + CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --range 2") + + CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --profile tmpdir/opr_profile.json + + ") + CmdArray+=("./lar_cuda model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --profile-host tmpdir/opr_profile_host.json") + + + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --model-info") + + CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --verbose") + + CmdArray+=("./lar_cpu model_source/resnet50_with_data.mgb --disable-assert-throw") + + # wait gdb attach to given PID + # CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATD --wait-gdb") + + CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --get-static-mem-info tmpdir/staticMemInfoDir") + + for cmd in ${CmdArray[@]}; do + echo $cmd + bash -c "$cmd" + done + + # view the graph with given url (usally: http://localhost:6006/) + # mkdir tmpdir/staticMemInfoDirLogs && python3 ../../../imperative/python/megengine/tools/graph_info_analyze.py -i tmpdir/staticMemInfoDir -o tmpdir/staticMemInfoDirLogs + # pip3 install tensorboard && tensorboard --logdir tmpdir/staticMemInfoDirLogs +} + +function clean(){ + rm -rf tmpdir model_source input_data lar_cpu lar_cuda lar_arm64 lar_armv7 + +} + +function main(){ + +if [ ${CLEAN_ALL} == "true" ] ; then + clean + exit 0 +fi + +if [ ${ONLY_PREPARE_MODEL} == "true" ] ; then + prepare_model_and_data + MODEL_PREAPRED="true" + exit 0 +fi + +if [ ${ONLY_BUILD} == "true" ] ; then + build_lar + LAR_BUILT="true" + exit 0 +fi + +if [ ${RUN_ARM_DEVICE} == "true" ] ; then + set_arm_device_and_upload $DEVICE_DESC "$WORK_DIR_PATH" "true" +fi + +if [ ${MODEL_PREAPRED} != "true" ] ; then + CHECK_MODEL=$(find . -name add_demo_input.json) + if [ ${CHECK_MODEL} == "" ] ; then + prepare_model_and_data + MODEL_PREAPRED="true" + fi +fi + +if [ ${LAR_BUILT} != "true" ] ; then + CHECK_LAR=$(find . 
+function main(){
+
+    if [ ${CLEAN_ALL} == "true" ] ; then
+        clean
+        exit 0
+    fi
+
+    if [ ${ONLY_PREPARE_MODEL} == "true" ] ; then
+        prepare_model_and_data
+        MODEL_PREAPRED="true"
+        exit 0
+    fi
+
+    if [ ${ONLY_BUILD} == "true" ] ; then
+        build_lar
+        LAR_BUILT="true"
+        exit 0
+    fi
+
+    if [ ${RUN_ARM_DEVICE} == "true" ] ; then
+        set_arm_device_and_upload $DEVICE_DESC "$WORK_DIR_PATH" "true"
+    fi
+
+    if [ ${MODEL_PREAPRED} != "true" ] ; then
+        CHECK_MODEL=$(find . -name add_demo_input.json)
+        if [ "${CHECK_MODEL}" == "" ] ; then
+            prepare_model_and_data
+            MODEL_PREAPRED="true"
+        fi
+    fi
+
+    if [ ${LAR_BUILT} != "true" ] ; then
+        CHECK_LAR=$(find . -name lar_armv7)
+        if [ "${CHECK_LAR}" == "" ] ; then
+            build_lar
+            LAR_BUILT="true"
+        fi
+    fi
+
+    if [ ${RUN_TARGET} == "diff_model" -o ${RUN_TARGET} == "all" ] ; then
+        test_different_model
+    fi
+
+    if [ ${RUN_TARGET} == "diff_device" -o ${RUN_TARGET} == "all" ] ; then
+        test_different_device
+    fi
+
+    if [ ${RUN_TARGET} == "fast_run" -o ${RUN_TARGET} == "all" ] ; then
+        test_fast_run
+    fi
+
+    if [ ${RUN_TARGET} == "io" -o ${RUN_TARGET} == "all" ] ; then
+        test_io
+    fi
+
+    if [ ${RUN_TARGET} == "layout" -o ${RUN_TARGET} == "all" ] ; then
+        test_layout_related
+    fi
+
+    if [ ${RUN_TARGET} == "optimize" -o ${RUN_TARGET} == "all" ] ; then
+        test_optimize
+    fi
+
+    if [ ${RUN_TARGET} == "plugin" -o ${RUN_TARGET} == "all" ] ; then
+        test_plugin
+    fi
+}
+
+main
+IFS=$OLD_IFS
\ No newline at end of file
diff --git a/lite/load_and_run/examples/script/add_demo.py b/lite/load_and_run/examples/script/add_demo.py
new file mode 100755
index 00000000..d0afe5ce
--- /dev/null
+++ b/lite/load_and_run/examples/script/add_demo.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+import argparse
+import math
+
+import megengine.functional as F
+import megengine.hub as hub
+import megengine.module as M
+import numpy as np
+from megengine import jit, tensor, Parameter
+
+
+class Simple(M.Module):
+    def __init__(self):
+        super().__init__()
+        self.a = Parameter([0, 1, 2], dtype=np.float32)
+
+    def forward(self, x):
+        x = x + self.a
+        return x
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="dump mge model for add_demo",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--inputs",
+        help="set the inputs data to get a model with testcase",
+        default="",
+        type=str,
+    )
+
+    parser.add_argument(
+        "--dir",
+        help="set the dir where the model to dump",
+        default=".",
+        type=str,
+    )
+
+    args = parser.parse_args()
+    net = Simple()
+    net.eval()
+
+    @jit.trace(symbolic=True, capture_as_const=True)
+    def fun(data):
+        return net(data)
+
+    data = tensor([3, 4, 5])
+    fun(data)
+    if args.inputs == "":
+        fun.dump(
+            args.dir + "/add_demo_f32_without_data.mge", arg_names=["data"],
+            no_assert=True,
+        )
+    else:
+        fun.dump(
+            args.dir + "/add_demo_f32_with_data.mge", arg_names=["data"],
+            input_data=[args.inputs], no_assert=True,
+        )
\ No newline at end of file
diff --git a/lite/load_and_run/examples/script/conv_demo.py b/lite/load_and_run/examples/script/conv_demo.py
new file mode 100755
index 00000000..65898cd3
--- /dev/null
+++ b/lite/load_and_run/examples/script/conv_demo.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+import argparse
+import math
+
+import megengine.functional as F
+import megengine.module as M
+import numpy as np
+from megengine import jit, tensor
+
+
+class ConvNet(M.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3, bias=False)
+
+    def forward(self, input):
+        x = self.conv1(input)
+        return x
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="dump mge model for conv_demo",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--dir",
+        help="set the dir where the model to dump",
+        default=".",
+        type=str,
+    )
+
+    args = parser.parse_args()
+    net = ConvNet()
+    net.eval()
+
+    @jit.trace(symbolic=True, capture_as_const=True)
+    def fun(data):
+        return net(data)
+
+    inp = tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
+    out = fun(inp)
+    fun.dump(args.dir + "/conv_demo_f32_without_data.mge", arg_names=["data"],
+             no_assert=True)
\ No newline at end of file
diff --git a/lite/load_and_run/examples/script/dump_trt.py b/lite/load_and_run/examples/script/dump_trt.py
new file mode 100644
index 00000000..bdfcd643
--- /dev/null
+++ b/lite/load_and_run/examples/script/dump_trt.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+from megskull.network import RawNetworkBuilder
+import megskull.opr.all as O
+from megskull.opr.external import TensorRTRuntimeOpr
+from meghair.utils.io import dump
+import argparse
+
+def str2tuple(x):
+    x = x.split(',')
+    x = [int(a) for a in x]
+    x = tuple(x)
+    return x
+
+def make_network(model, isizes):
+    data = [O.DataProvider('input{}'.format(i), shape=isizes[i])
+            for i in range(len(isizes))]
+    with open(model, 'rb') as f:
+        engine = f.read()
+
+    opr = TensorRTRuntimeOpr(data, engine, 1)
+
+    net = RawNetworkBuilder(inputs=data, outputs=opr.outputs)
+
+    return net
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(dest='model')
+    parser.add_argument(dest='output')
+    parser.add_argument('--isize', help='input sizes. '
+                        'e.g. for models with two (1,3,224,224) inputs, '
+                        'the option --isize="1,3,224,224;1,3,224,224" should be used')
+
+    args = parser.parse_args()
+    isizes = [str2tuple(x) for x in args.isize.split(';')]
+    net = make_network(args.model, isizes)
+    dump(net, args.output)
\ No newline at end of file
diff --git a/lite/load_and_run/examples/script/gen_trt_model.sh b/lite/load_and_run/examples/script/gen_trt_model.sh
new file mode 100755
index 00000000..f6533bd4
--- /dev/null
+++ b/lite/load_and_run/examples/script/gen_trt_model.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+CUR_DIR="$( cd "$(dirname $0)" >/dev/null 2>&1 ; pwd -P )"
+
+# find the correct trtexec version; this only works for internal ci and brainpp env setups
+CUDA_VERSION=$(nvcc --version | grep -o "[0-9]*\.[0-9]*" | head -n 1)
+SEARCH_PATH=$(echo `which nvcc | xargs dirname`/../../)
+TRT_CANDIDATE=$(find `cd $SEARCH_PATH; pwd` -name "trtexec" | grep "bin/trtexec" | grep $CUDA_VERSION)
+TRT_CANDIDATE=${TRT_CANDIDATE%$'\n'*}
+TRT_LIB_PATH=$(readlink -f "`dirname $TRT_CANDIDATE`/../lib")
+MODELS_PATH=$(readlink -f "${CUR_DIR}/../model_source")
+
+# generate the mge model
+rm -rf $MODELS_PATH/conv_demo_f32_without_data.mge
+python3 ${CUR_DIR}/conv_demo.py --dir $MODELS_PATH
+
+# generate a caffe model with mgeconvert
+# INSTALL mgeconvert:
+#   python3 -m pip install git+https://github.com/MegEngine/mgeconvert.git --user --install-option="--targets=caffe"
+rm -rf $MODELS_PATH/conv_demo.prototxt $MODELS_PATH/conv_demo.caffemodel
+convert mge_to_caffe -i $MODELS_PATH/conv_demo_f32_without_data.mge -c $MODELS_PATH/conv_demo.prototxt -b $MODELS_PATH/conv_demo.caffemodel
+
+# generate the trt model
+rm -rf $MODELS_PATH/conv_demo.trt
+echo "WARNING: config the cudnn and cublas paths before running trtexec"
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_LIB_PATH
+echo $LD_LIBRARY_PATH
+$TRT_CANDIDATE --deploy="$MODELS_PATH/conv_demo.prototxt" --model="$MODELS_PATH/conv_demo.caffemodel" --output="conv1_Convolution" --batch=1 --saveEngine="$MODELS_PATH/conv_demo.trt"
+
+# redump the trt model into an mgb model
+rm -rf $MODELS_PATH/trt_conv_demo.pkl $MODELS_PATH/trt_conv_demo_with_data.mgb
+python3 $CUR_DIR/dump_trt.py $MODELS_PATH/conv_demo.trt $MODELS_PATH/trt_conv_demo.pkl --isize="1,3,4,4"
+$CUR_DIR/../../dump_with_testcase.py $MODELS_PATH/trt_conv_demo.pkl -o $MODELS_PATH/trt_conv_demo_with_data.mgb -d "#rand(0, 255)" --no-assert
\ No newline at end of file
diff --git a/lite/load_and_run/examples/script/mge_input_data.py b/lite/load_and_run/examples/script/mge_input_data.py
new file mode 100755
index 00000000..10a38817
--- /dev/null
+++ b/lite/load_and_run/examples/script/mge_input_data.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+import numpy as np
+import cv2
+import megengine.data.transform as T
+import megengine.functional as F
+import json
+import urllib.request
+
+url, filename = ("https://data.megengine.org.cn/images/cat.jpg", "input_data/cat.jpg")
+urllib.request.urlretrieve(url, filename)
+
+# numpy data
+data = np.random.rand(1, 3, 224, 224)
+np.save("input_data/resnet50_input_uint8.npy", data.astype(np.uint8))
+np.save("input_data/resnet50_input.npy", data.astype(np.float32))
+
+# ppm data
+image = cv2.imread("input_data/cat.jpg")
+transform = T.Compose([
+    T.Resize(256),
+    T.CenterCrop(224),
+])
+processed_img = transform.apply(image)
+cv2.imwrite("input_data/cat.ppm", processed_img)
+
+# json data
+data_obj = {
+    "shape": [1, 3],
+    "type": "int32",
+    "raw": [2, 3, 4]
+}
+with open("input_data/add_demo_input.json", "w") as f:
+    json.dump({"data": data_obj}, f)
\ No newline at end of file
diff --git a/lite/load_and_run/examples/script/resnet50_mgb.py b/lite/load_and_run/examples/script/resnet50_mgb.py
new file mode 100755
index 00000000..be05610b
--- /dev/null
+++ b/lite/load_and_run/examples/script/resnet50_mgb.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+
+import megbrain as mgb
+from megskull.graph import FpropEnv, iter_dep_opr
+import megskull as mgsk
+from megskull.opr.compatible.caffepool import CaffePooling2D
+from megskull.opr.arith import Add, ReLU
+from megskull.opr.all import (
+    DataProvider, Conv2D, Pooling2D, FullyConnected,
+    Softmax, Dropout, BatchNormalization, CrossEntropyLoss,
+    ElementwiseAffine, WarpPerspective, WarpPerspectiveWeightProducer,
+    WeightDecay, ParamProvider, ConvBiasActivation, ElemwiseMultiType)
+from megskull.network import RawNetworkBuilder, NetworkVisitor
+from megskull.utils.debug import CallbackInjector
+from megskull.utils.misc import get_2dshape
+import megskull.opr.helper.param_init as pinit
+from megskull.opr.helper.elemwise_trans import Identity
+from megskull.opr.netsrc import ConstProvider
+import numpy as np
+import functools
+import re
+import fnmatch
+import argparse
+import sys
+
+
+def create_bn_relu_float(conv_name, f_in, ksize, stride, pad, num_outputs,
+                         has_relu, args):
+    f = Conv2D(conv_name, f_in, kernel_shape=ksize, stride=stride,
+               padding=pad, output_nr_channel=num_outputs,
+               nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
+    if has_relu:
+        f = ReLU(f)
+    return f
+
+
+def get_num_inputs(feature, format):
+    if format == 'NCHW':
+        return feature.partial_shape[1]
+    else:
+        assert format == 'NCHW4'
+        return feature.partial_shape[1] * 4
+
+
+def create_bn_relu(prefix, f_in, ksize, stride, pad, num_outputs,
+                   has_relu, conv_name_fun, args):
+    if conv_name_fun:
+        conv_name = conv_name_fun(prefix)
+    else:
+        conv_name = prefix
+    return create_bn_relu_float(conv_name, f_in, ksize, stride, pad,
+                                num_outputs, has_relu, args)
+
+
+def create_bottleneck(prefix, f_in, stride, num_outputs1, num_outputs2, args,
+                      has_proj=False):
+    proj = f_in
+    if has_proj:
+        proj = create_bn_relu(prefix, f_in, ksize=1, stride=stride, pad=0,
+                              num_outputs=num_outputs2, has_relu=False,
+                              conv_name_fun=lambda p: "interstellar{}_branch1".format(p),
+                              args=args)
+
+    f = create_bn_relu(prefix, f_in, ksize=1, stride=1, pad=0,
+                       num_outputs=num_outputs1, has_relu=True,
+                       conv_name_fun=lambda p: "interstellar{}_branch2a".format(p),
+                       args=args)
+
+    f = create_bn_relu(prefix, f, ksize=3, stride=stride, pad=1,
+                       num_outputs=num_outputs1, has_relu=True,
+                       conv_name_fun=lambda p: "interstellar{}_branch2b".format(p),
+                       args=args)
+
+    f = create_bn_relu(prefix, f, ksize=1, stride=1, pad=0,
+                       num_outputs=num_outputs2, has_relu=False,
+                       conv_name_fun=lambda p: "interstellar{}_branch2c".format(p),
+                       args=args)
+
+    f = ReLU(f + proj)
+
+    return f
+
+
+def get(args):
+    img_size = 224
+    num_inputs = 3
+    data = DataProvider('data', shape=(args.batch_size, num_inputs,
+                                       img_size, img_size))
+
+    inp = data
+    f = create_bn_relu("conv1", inp, ksize=7, stride=2, pad=3, num_outputs=64,
+                       has_relu=True, conv_name_fun=None,
+                       args=args)
+    f = Pooling2D("pool1", f, window=3, stride=2, padding=1, mode="MAX",
+                  format=args.format)
+
+    pre = [2, 3, 4, 5]
+    stages = [3, 4, 6, 3]
+    mid_outputs = [64, 128, 256, 512]
+    enable_stride = [False, True, True, True]
+
+    for p, s, o, es in zip(pre, stages, mid_outputs, enable_stride):
+        for i in range(s):
+            has_proj = False if i > 0 else True
+            stride = 1 if not es or i > 0 else 2
+            prefix = "{}{}".format(p, chr(ord("a") + i))
+            f = create_bottleneck(prefix, f, stride, o, o * 4, args, has_proj)
+            print("{}\t{}".format(prefix, f.partial_shape))
+
+    f = Pooling2D("pool5", f, window=7, stride=7, padding=0, mode="AVERAGE",
+                  format=args.format)
+
+    f = FullyConnected("fc1000", f, output_dim=1000,
+                       nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
+
+    f = Softmax("cls_softmax", f)
+    f.init_weights()
+
+    net = RawNetworkBuilder(inputs=[data], outputs=[f])
+
+    return net
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='dump pkl model for resnet50',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-b', '--batch-size', type=int,
+                        help='batch size of the model', default=1)
+    parser.add_argument('-f', '--format', choices=['NCHW', 'NCHW4'],
+                        help='format of conv',
+                        default='NCHW')
+    parser.add_argument('-o', '--output',
+                        help='output pkl path', required=True)
+    args = parser.parse_args()
+    if args.format != 'NCHW':
+        print('Only support NCHW for float model')
+        parser.print_help()
+        sys.exit(1)
+
+    from meghair.utils import io
+    io.dump(get(args), args.output)
+
diff --git a/lite/load_and_run/examples/script/resnet50_mge.py b/lite/load_and_run/examples/script/resnet50_mge.py
new file mode 100755
index 00000000..b9555cdb
--- /dev/null
+++ b/lite/load_and_run/examples/script/resnet50_mge.py
@@ -0,0 +1,351 @@
+#!/usr/bin/env python3
+import argparse
+import math
+
+import megengine.functional as F
+import megengine.hub as hub
+import megengine.module as M
+import numpy as np
+from megengine import jit, tensor
+
+
+class BasicBlock(M.Module):
+    expansion = 1
+
+    def __init__(
+        self,
+        in_channels,
+        channels,
+        stride=1,
+        groups=1,
+        base_width=64,
+        dilation=1,
+        norm=M.BatchNorm2d,
+    ):
+        super().__init__()
+        if groups != 1 or base_width != 64:
+            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        self.conv1 = M.Conv2d(
in_channels, channels, 3, stride, padding=dilation, bias=False + ) + self.bn1 = norm(channels) + self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False) + self.bn2 = norm(channels) + self.downsample = ( + M.Identity() + if in_channels == channels and stride == 1 + else M.Sequential( + M.Conv2d(in_channels, channels, 1, stride, bias=False), norm(channels), + ) + ) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.bn2(x) + identity = self.downsample(identity) + x += identity + x = F.relu(x) + return x + + +class Bottleneck(M.Module): + expansion = 4 + + def __init__( + self, + in_channels, + channels, + stride=1, + groups=1, + base_width=64, + dilation=1, + norm=M.BatchNorm2d, + ): + super().__init__() + width = int(channels * (base_width / 64.0)) * groups + self.conv1 = M.Conv2d(in_channels, width, 1, 1, bias=False) + self.bn1 = norm(width) + self.conv2 = M.Conv2d( + width, + width, + 3, + stride, + padding=dilation, + groups=groups, + dilation=dilation, + bias=False, + ) + self.bn2 = norm(width) + self.conv3 = M.Conv2d(width, channels * self.expansion, 1, 1, bias=False) + self.bn3 = norm(channels * self.expansion) + self.downsample = ( + M.Identity() + if in_channels == channels * self.expansion and stride == 1 + else M.Sequential( + M.Conv2d(in_channels, channels * self.expansion, 1, stride, bias=False), + norm(channels * self.expansion), + ) + ) + + def forward(self, x): + identity = x + + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + + x = self.conv3(x) + x = self.bn3(x) + + identity = self.downsample(identity) + + x += identity + x = F.relu(x) + + return x + + +class ResNet(M.Module): + def __init__( + self, + block, + layers, + num_classes=1000, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm=M.BatchNorm2d, + ): + super().__init__() + self.in_channels = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = M.Conv2d( + 3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm(self.in_channels) + self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=0) + self.layer1 = self._make_layer(block, 64, layers[0], norm=norm) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + norm=norm, + ) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + norm=norm, + ) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2], + norm=norm, + ) + self.fc = M.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, M.Conv2d): + M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) + M.init.uniform_(m.bias, -bound, bound) + elif isinstance(m, M.BatchNorm2d): + 
M.init.ones_(m.weight)
+                M.init.zeros_(m.bias)
+            elif isinstance(m, M.Linear):
+                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
+                if m.bias is not None:
+                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
+                    bound = 1 / math.sqrt(fan_in)
+                    M.init.uniform_(m.bias, -bound, bound)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block
+        # behaves like an identity. According to https://arxiv.org/abs/1706.02677
+        # this improves the model by 0.2~0.3%.
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    M.init.zeros_(m.bn3.weight)
+                elif isinstance(m, BasicBlock):
+                    M.init.zeros_(m.bn2.weight)
+
+    def _make_layer(
+        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
+    ):
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+
+        layers = []
+        layers.append(
+            block(
+                self.in_channels,
+                channels,
+                stride,
+                groups=self.groups,
+                base_width=self.base_width,
+                dilation=previous_dilation,
+                norm=norm,
+            )
+        )
+        self.in_channels = channels * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.in_channels,
+                    channels,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm=norm,
+                )
+            )
+
+        return M.Sequential(*layers)
+
+    def extract_features(self, x):
+        outputs = {}
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = self.maxpool(x)
+        outputs["stem"] = x
+
+        x = self.layer1(x)
+        outputs["res2"] = x
+        x = self.layer2(x)
+        outputs["res3"] = x
+        x = self.layer3(x)
+        outputs["res4"] = x
+        x = self.layer4(x)
+        outputs["res5"] = x
+        return outputs
+
+    def forward(self, x):
+        x = F.reshape(x, (-1, 3, 224, 224))
+        x = self.extract_features(x)["res5"]
+
+        x = F.avg_pool2d(x, 7)
+        x = F.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+
+@hub.pretrained(
+    "https://data.megengine.org.cn/models/weights/resnet50_fbaug_76254_4e14b7d1.pkl"
+)
+def resnet50(**kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+    """
+    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="dump mge model for resnet50",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "-b", "--batch-size", help="batch size of the model", default=1, type=int
+    )
+    parser.add_argument(
+        "-d",
+        "--dtype",
+        help="the dtype of the model, which can be float32 or uint8",
+        default="float32",
+        type=str,
+    )
+    parser.add_argument(
+        "--inputs",
+        help="set the inputs data to get a model with testcase",
+        default="",
+        type=str,
+    )
+
+    parser.add_argument(
+        "--dir",
+        help="set the dir where the model to dump",
+        default=".",
+        type=str,
+    )
+
+    parser.add_argument(
+        "--enable-nchw4",
+        help="enable-nchw4 for NVIDIA CUDNN",
+        action="store_true"
+    )
+
+    parser.add_argument(
+        "--enable-chwn4",
+        help="enable-chwn4 for NVIDIA CUDNN",
+        action="store_true"
+    )
+
+    args = parser.parse_args()
+    net = resnet50()
+    net.eval()
+
+    @jit.trace(symbolic=True, capture_as_const=True)
+    def fun(data):
+        return net(data)
+
+    if args.dtype == "float32" or args.dtype == "uint8":
+        data_type = np.float32
+        if args.dtype == "uint8":
+            data_type = np.uint8
+        data = tensor(
+            (np.random.random([args.batch_size, 3, 224, 224]) * 255).astype(data_type)
+        )
+        fun(data)
+"_without_data.mge", arg_names=["data"], + no_assert=True, enable_nchw4=args.enable_nchw4, enable_chwn4=args.enable_chwn4, + + ) + else: + fun.dump( + args.dir + "/resnet50_b" + str(args.batch_size) + "_"+ args.dtype +"_with_data.mge", arg_names=["data"], + input_data=[args.inputs], no_assert=True, enable_nchw4=args.enable_nchw4, + ) + + else: + raise TypeError("dtype should be float32") \ No newline at end of file diff --git a/lite/load_and_run/src/options/device_options.cpp b/lite/load_and_run/src/options/device_options.cpp index c0832b81..71107c88 100644 --- a/lite/load_and_run/src/options/device_options.cpp +++ b/lite/load_and_run/src/options/device_options.cpp @@ -148,7 +148,9 @@ XPUDeviceOption::XPUDeviceOption() { } if (!FLAGS_multi_thread_core_ids.empty()) { - mgb_assert(enable_multithread, "core ids should be set after --multithread"); + mgb_assert( + enable_multithread || enable_multithread_default, + "core ids should be set after --multithread or --multithread-default"); std::stringstream id_stream(FLAGS_multi_thread_core_ids); std::string id; size_t thread_cnt = 0; diff --git a/lite/load_and_run/src/options/extern_c_opr_options.cpp b/lite/load_and_run/src/options/extern_c_opr_options.cpp index d7131cdc..66915af4 100644 --- a/lite/load_and_run/src/options/extern_c_opr_options.cpp +++ b/lite/load_and_run/src/options/extern_c_opr_options.cpp @@ -126,7 +126,9 @@ void COprLibOption::init_extern_param(std::shared_ptr model_ptr) { void COprLibOption::load_lib() { auto handle = dlopen(lib_path.c_str(), RTLD_LAZY); - mgb_assert(handle, "failed to open c opr lib %s: %s", lib_path.c_str(), dlerror()); + mgb_assert( + handle, "failed to open c opr lib %s:\n errmsg: %s", lib_path.c_str(), + dlerror()); const char* entry = MGB_C_OPR_INIT_FUNC_STR; auto func = dlsym(handle, entry); diff --git a/lite/load_and_run/src/options/io_options.cpp b/lite/load_and_run/src/options/io_options.cpp index 72085eff..855f9138 100644 --- a/lite/load_and_run/src/options/io_options.cpp +++ b/lite/load_and_run/src/options/io_options.cpp @@ -286,8 +286,13 @@ DEFINE_string( DEFINE_string(io_dump, "", "set the io dump file path in text format"); DEFINE_bool(io_dump_stdout, false, "dump io opr to stdout in text format"); DEFINE_bool(io_dump_stderr, false, "dump io opr to stderr in text format"); -DEFINE_string(bin_io_dump, "", "set the io dump file path in binary format"); -DEFINE_string(bin_out_dump, "", "set the out dump file path in binary format"); +DEFINE_string( + bin_io_dump, "", + "set the io dump directory path where variable in binary format located"); +DEFINE_string( + bin_out_dump, "", + "set the out dump directory path where output variable in binary format " + "located"); DEFINE_bool(copy_to_host, false, "copy device data to host"); REGIST_OPTION_CREATOR(input, lar::InputOption::create_option); diff --git a/lite/load_and_run/src/options/optimize_options.cpp b/lite/load_and_run/src/options/optimize_options.cpp index f600aa13..ce2f2a92 100644 --- a/lite/load_and_run/src/options/optimize_options.cpp +++ b/lite/load_and_run/src/options/optimize_options.cpp @@ -77,12 +77,6 @@ void WeightPreprocessOption::config_model_internel( if (weight_preprocess) { LITE_WARN("enable weight-preprocess optimization"); model->get_config().options.weight_preprocess = true; - //! FIXME: algo searcher enable weight preprocess for opencl( - //! 
-            //! implement below has some problem);
-            // #if MGB_OPENCL
-            //     megdnn::opencl::algo_searcher::AlgoSearcherBase::
-            //             enable_weight_preprocess();
-            // #endif
         }
     }
 }
@@ -95,11 +89,6 @@ void WeightPreprocessOption::config_model_internel(
         if (weight_preprocess) {
             mgb_log_warn("enable weight-preprocess optimization");
             graph_option.graph_opt.enable_weight_preprocess();
-            //! FIXME: this implemment is not right
-            // #if MGB_OPENCL
-            //     megdnn::opencl::algo_searcher::AlgoSearcherBase::
-            //             enable_weight_preprocess();
-            // #endif
         }
     }
 }
@@ -570,12 +559,6 @@ DEFINE_bool(
         enable_jit, false,
         " Execute supported operators with JIT(now only support NVRTC). "
        "Can only be used on Nvidia GPUs");
-#if MGB_ENABLE_ANDROID_NN
-DEFINE_bool(
-        android_nn, false,
-        "Execute supported operators with Android NN. Can only be used "
-        "with --cpu.");
-#endif
 #if MGB_ENABLE_TENSOR_RT
 DEFINE_bool(
         tensorrt, false,
diff --git a/lite/load_and_run/src/options/plugin_options.cpp b/lite/load_and_run/src/options/plugin_options.cpp
index a3d622a6..2c244b92 100644
--- a/lite/load_and_run/src/options/plugin_options.cpp
+++ b/lite/load_and_run/src/options/plugin_options.cpp
@@ -8,10 +8,10 @@
  */
 
 #include "plugin_options.h"
+#include <map>
 #include "misc.h"
 #include "models/model_lite.h"
 #include "models/model_mdl.h"
-
 ///////////////////// Plugin options///////////////////////////
 namespace lar {
 
@@ -153,7 +153,12 @@ void DebugOption::format_and_print(
     auto table = mgb::TextTable(tablename);
     auto&& network = model->get_lite_network();
     table.padding(1);
-    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();
+    table.align(mgb::TextTable::Align::Mid)
+            .add("type")
+            .add("name")
+            .add("shape")
+            .add("dtype")
+            .eor();
 
     auto to_string = [&](lite::Layout& layout) {
         std::string shape("{");
@@ -165,6 +170,19 @@ void DebugOption::format_and_print(
         shape.append("}");
         return shape;
     };
+    auto get_dtype = [&](lite::Layout& layout) {
+        std::map<LiteDataType, std::string> type_map = {
+                {LiteDataType::LITE_FLOAT, "float32"},
+                {LiteDataType::LITE_HALF, "float16"},
+                {LiteDataType::LITE_INT64, "int64"},
+                {LiteDataType::LITE_INT, "int32"},
+                {LiteDataType::LITE_UINT, "uint32"},
+                {LiteDataType::LITE_INT16, "int16"},
+                {LiteDataType::LITE_UINT16, "uint16"},
+                {LiteDataType::LITE_INT8, "int8"},
+                {LiteDataType::LITE_UINT8, "uint8"}};
+        return type_map[layout.data_type];
+    };
 
     auto input_name = network->get_all_input_name();
     for (auto& i : input_name) {
@@ -173,6 +191,7 @@ void DebugOption::format_and_print(
                 .add("INPUT")
                 .add(i)
                 .add(to_string(layout))
+                .add(get_dtype(layout))
                 .eor();
     }
 
@@ -183,6 +202,7 @@ void DebugOption::format_and_print(
                 .add("OUTPUT")
                 .add(i)
                 .add(to_string(layout))
+                .add(get_dtype(layout))
                 .eor();
     }
 
@@ -196,13 +216,28 @@ void DebugOption::format_and_print(
         const std::string& tablename, std::shared_ptr<ModelMdl> model) {
     auto table = mgb::TextTable(tablename);
     table.padding(1);
-    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();
-
+    table.align(mgb::TextTable::Align::Mid)
+            .add("type")
+            .add("name")
+            .add("shape")
+            .add("dtype")
+            .eor();
+    auto get_dtype = [&](megdnn::DType data_type) {
+        std::map<megdnn::DTypeEnum, std::string> type_map = {
+                {mgb::dtype::Float32().enumv(), "float32"},
+                {mgb::dtype::Int32().enumv(), "int32"},
+                {mgb::dtype::Int16().enumv(), "int16"},
+                {mgb::dtype::Uint16().enumv(), "uint16"},
+                {mgb::dtype::Int8().enumv(), "int8"},
+                {mgb::dtype::Uint8().enumv(), "uint8"}};
+        return type_map[data_type.enumv()];
+    };
     for (auto&& i : model->get_mdl_load_result().tensor_map) {
         table.align(mgb::TextTable::Align::Mid)
                 .add("INPUT")
                 .add(i.first)
                 .add(i.second->shape().to_string())
+                .add(get_dtype(i.second->dtype()))
                 .eor();
     }
 
@@ -211,6 +246,7 @@ void DebugOption::format_and_print(
                 .add("OUTPUT")
                 .add(i.node()->name())
                 .add(i.shape().to_string())
+                .add(get_dtype(i.dtype()))
                 .eor();
     }
 
@@ -358,18 +394,22 @@ DEFINE_double(
 
 DEFINE_bool(
         check_dispatch, false,
-        "check whether an operator call dispatch on cpu comp nodes");
+        "check whether an operator calls dispatch on cpu comp nodes. This is "
+        "used to find potential bugs in MegDNN");
 
 DEFINE_string(
         check_var_value, "",
         "--check-var-value [interval]|[interval:init_idx], Enable "
-        "VarValueChecker plugin. Refer to its doc for more details");
+        "VarValueChecker plugin. Check the values of all vars in a graph, "
+        "starting from the given var ID (init_idx) with the given step interval");
#if MGB_ENABLE_JSON
 DEFINE_string(
         profile, "",
         "Write profiling result to given file. The output file is in "
         "JSON format");
-DEFINE_string(profile_host, "", "focus on host time profiling For some backends");
+DEFINE_string(
+        profile_host, "",
+        "focus on host time profiling for some backends (such as OpenCL)");
 #endif
 
 ///////////////////// Debug gflags///////////////////////////
diff --git a/lite/load_and_run/src/strategys/strategy_normal.cpp b/lite/load_and_run/src/strategys/strategy_normal.cpp
index 591be03d..e5f5170f 100644
--- a/lite/load_and_run/src/strategys/strategy_normal.cpp
+++ b/lite/load_and_run/src/strategys/strategy_normal.cpp
@@ -94,15 +94,16 @@ void NormalStrategy::run_subline() {
         m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
         stage_config_model();
         auto cur = timer.get_msecs();
-        printf("iter %lu/%lu: %.3fms (exec=%.3fms)\n", i, run_num, cur, exec_time);
+        printf("iter %lu/%lu: %.3f ms (exec=%.3f ms)\n", i, run_num, cur,
+               exec_time);
         time_sum += cur;
         time_sqrsum += cur * cur;
         fflush(stdout);
         min_time = std::min(min_time, cur);
         max_time = std::max(max_time, cur);
     }
-    printf("\n=== finished test #%u: time=%.3fms avg_time=%.3fms "
-           "sexec=%.3fms min=%.3fms max=%.3fms\n\n",
+    printf("\n=== finished test #%u: time=%.3f ms avg_time=%.3f ms "
+           "sexec=%.3f ms min=%.3f ms max=%.3f ms\n\n",
            idx, time_sum, time_sum / run_num,
            std::sqrt(
                    (time_sqrsum * run_num - time_sum * time_sum) /
@@ -121,7 +122,7 @@ void NormalStrategy::run_subline() {
     m_runtime_param.stage = RunStage::MODEL_RUNNING;
     stage_config_model();
 
-    if (!idx) {
+    if (idx == 0) {
         warm_up();
     }
     tot_time += run_iter(idx);