GitOrigin-RevId: a28ed2f27a
tags/v1.9.0
@@ -0,0 +1,15 @@
# load_and_run example scripts

This directory contains typical examples of using load_and_run. You can use

```bash
./example.sh
```

to run those examples; more details are in the script itself.

The usage of the script can be seen by running it with:

```bash
./example.sh -h
```
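
For example, to run a single group of tests, pass one of the targets accepted by `-e`; to run the ARM examples on an ssh device, also pass `-t` and `-w` (the device address and workspace below are placeholders):

```bash
./example.sh -e fast_run        # run only the fast-run examples
./example.sh -a                 # run all test targets
./example.sh -t root@arm-device -w /data/local/tmp -e layout  # hypothetical ssh device
```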
@@ -0,0 +1,437 @@
#!/bin/bash
OLD_IFS="$IFS"
IFS=$'\n'
TYPICAL_MODEL_DATA="model_source/resnet50_b1_float32_without_data.mge --input \"data:input_data/resnet50_input.npy\""
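# --input accepts "tensor_name:source" pairs; the source may be a .npy/.ppm/.json
# file or a literal list such as "data:[2,3,4]" (see test_io below)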
DEVICE_DESC=""
WORK_DIR_PATH="."
RUN_ARM_DEVICE="false"
RUN_TARGET="diff_model"
ONLY_PREPARE_MODEL="false"
MODEL_PREPARED="false"
ONLY_BUILD="false"
LAR_BUILT="false"
CLEAN_ALL="false"
RUN_TARGETS=("diff_model")
RUN_TARGETS+=("diff_device")
RUN_TARGETS+=("fast_run")
RUN_TARGETS+=("io")
RUN_TARGETS+=("layout")
RUN_TARGETS+=("optimize")
RUN_TARGETS+=("plugin")
RUN_TARGETS+=("all")
function usage() {
    echo "$0 args1 args2 .."
    echo "available args detail:"
    echo "-p : prepare example models"
    echo "-b : build load_and_run for x86/armv7/arm64 CPU and CUDA"
    echo "-t : set the ssh arm device"
    echo "-w : set the arm device workspace dir"
    echo "-c : clean all"
    echo "-a : run all tests"
    echo "-e : set the running target for tests (pass an invalid value to -e to list all targets)"
    echo "-h : show usage"
    exit 1
}
while getopts "pbcahe:w:t:" arg; do
    case $arg in
        t)
            DEVICE_DESC=$OPTARG
            RUN_ARM_DEVICE="true"
            echo "config arm device DEVICE_DESC to ${DEVICE_DESC}"
            ;;
        w)
            WORK_DIR_PATH=$OPTARG
            echo "config arm device WORK_DIR_PATH to ${WORK_DIR_PATH}"
            ;;
        e)
            tmp_target=null
            for target in "${RUN_TARGETS[@]}"; do
                if [ "$target" = "$OPTARG" ]; then
                    echo "CONFIG BUILD RUN_TARGET to : $OPTARG"
                    tmp_target=$OPTARG
                    RUN_TARGET=$OPTARG
                    break
                fi
            done
            if [ "$tmp_target" = "null" ]; then
                echo "ERR args for target (-e)"
                echo "available target usage :"
                for target in "${RUN_TARGETS[@]}"; do
                    echo " -e $target"
                done
                exit 1
            fi
            ;;
        h)
            echo "show usage"
            usage
            ;;
        a)
            echo "config RUN_TARGET=all"
            RUN_TARGET="all"
            ;;
        c)
            echo "clean all directories generated by this script"
            CLEAN_ALL="true"
            ;;
        b)
            echo "run build"
            ONLY_BUILD="true"
            ;;
        p)
            echo "prepare model and input"
            ONLY_PREPARE_MODEL="true"
            ;;
        ?)
            echo "unknown argument"
            usage
            ;;
    esac
done
function prepare_model_and_data(){
    rm -rf model_source && mkdir model_source
    # dump mgb model
    python3 script/resnet50_mgb.py -o model_source/resnet50.pkl
    ../dump_with_testcase.py model_source/resnet50.pkl -o model_source/resnet50_with_data.mgb -d "#rand(0, 255)" --no-assert
    # prepare simple add and conv models
    python3 script/add_demo.py --dir model_source
    python3 script/conv_demo.py --dir model_source
    # generate trt model
    script/gen_trt_model.sh
    # prepare mge models
    python3 script/resnet50_mge.py --dir model_source
    python3 script/resnet50_mge.py --dir model_source -d int8
    python3 script/resnet50_mge.py --dir model_source --inputs "#rand(0,255)"
    # make input data
    rm -rf input_data && mkdir input_data
    python3 script/mge_input_data.py
    rm -rf tmpdir && mkdir tmpdir
}
function build_lar(){
    # build cpu and cuda versions
    ../../../scripts/cmake-build/host_build.sh -r -t -e load_and_run
    # WARNING: config the cuda environment before compiling
    ../../../scripts/cmake-build/host_build.sh -c -t -e load_and_run
    # build arm versions
    ../../../scripts/cmake-build/cross_build_android_arm_inference.sh -r -a arm64-v8a -e load_and_run
    ../../../scripts/cmake-build/cross_build_android_arm_inference.sh -r -a armeabi-v7a -e load_and_run
    # link the binaries for the tests below
    ln -s ../../../build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_OFF/Release/build/lite/load_and_run/load_and_run lar_cpu
    ln -s ../../../build_dir/host/MGE_WITH_CUDA_ON/MGE_INFERENCE_ONLY_OFF/Release/build/lite/load_and_run/load_and_run lar_cuda
    cp ../../../build_dir/android/arm64-v8a/Release/build/lite/load_and_run/load_and_run ./lar_arm64
    cp ../../../build_dir/android/armeabi-v7a/Release/build/lite/load_and_run/load_and_run ./lar_armv7
}
function set_arm_device_and_upload(){
    DEVICE_DESC="${1}"
    WORK_DIR_PATH="${2}"
    RUN_ARM_DEVICE="${3}"
    cmd="rsync -aP -zz ./lar_arm64 ./lar_armv7 model_source/resnet50_b1_float32_without_data.mge input_data/resnet50_input.npy $DEVICE_DESC:$WORK_DIR_PATH/"
    echo "$cmd"
    bash -c "$cmd"
}
function test_different_model(){
    CmdArray=("./lar_cpu model_source/resnet50_with_data.mgb")
    CmdArray+=("./lar_cpu model_source/resnet50_b1_float32_with_data.mge")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --lite")
    for cmd in "${CmdArray[@]}"; do
        echo "$cmd"
        bash -c "$cmd"
    done
}
function test_different_device(){
    # with --cpu, dispatched kernels are pushed to a worker queue and executed by
    # the queue (mean 131.278 ms, std 15.197 ms; m_asyc_exec runs asynchronously)
    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATA --cpu")
    # with --cpu-default, dispatched kernels are executed directly
    # (mean 131.875 ms, std 7.758 ms; m_asyc_exec runs synchronously)
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --cpu-default")
    # multithreaded run; average runtime (ms) with 1~8 threads:
    # 129.611, 84.266, 76.963, 55.212, 69.283, 77.338, 58.386, 64.585
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --multithread 4")
    # the main thread is pinned to a core while other tasks run on the thread
    # pool: 132.614, 83.095, 69.792, 54.452, 48.890, 48.206, 46.386, 53.908
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --multithread-default 4")
    # multithreaded cpu run with core binding (binding matters little on x86)
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --multithread 2 --multi-thread-core-ids 1,5")
    # xpu comp nodes run on cpu: 132.740 ms (comp_node: cpu)
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cpu")
    # xpu comp nodes run on cuda: 6.495 ms (comp_node: gpu)
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cuda")
    for cmd in "${CmdArray[@]}"; do
        echo "$cmd"
        bash -c "$cmd"
    done
}
function test_fast_run(){
CmdArray=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run") | |||||
CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --full-run") | |||||
#fast run 搜参 | |||||
CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --fast-run-algo-policy tmpdir/algo_cache_file") | |||||
#fast run 带参执行 | |||||
CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run-algo-policy tmpdir/algo_cache_file") | |||||
CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --reproducible") | |||||
CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --fast-run-shared-batch-size 1") | |||||
CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATD --fast-run --binary-equal-between-batch") | |||||
for cmd in ${CmdArray[@]}; do | |||||
echo $cmd | |||||
bash -c "$cmd" | |||||
done | |||||
} | |||||
function test_io(){
    rm -rf tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_io_info_cuda tmpdir/io_info.txt
    mkdir tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_io_info_cuda
    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATA")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"input_data/add_demo_input.json\"")
    #! the model must support input with nhwc shape
    CmdArray+=("./lar_cpu model_source/resnet50_b1_int8_without_data.mge --input \"data:input_data/cat.ppm\"")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\"")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump tmpdir/io_info.txt --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump-stdout --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump-stderr --iter 1 --warmup-iter 0")
    # each var is dumped to its own file in the given directory; the file name is
    # the var id, matching the text dump information
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-io-dump tmpdir/bin_io_info --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cuda model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-io-dump tmpdir/bin_io_info_cuda --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-out-dump tmpdir/bin_out_info --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --copy-to-host")
    for cmd in "${CmdArray[@]}"; do
        echo "$cmd"
        bash -c "$cmd"
    done
    # compare the binary io information dumped on cpu and cuda
    python3 ../../../imperative/python/megengine/tools/compare_binary_iodump.py tmpdir/bin_io_info tmpdir/bin_io_info_cuda
}
function test_layout_related(){
    # very little speed up
    CmdArray=("./lar_cuda $TYPICAL_MODEL_DATA --enable-nchw4")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-chwn4")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-nchw32")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-nchw64")
    # speed up
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --enable-nchw88")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cuda --layout-transform cuda")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cuda --layout-transform cuda --layout-transform-dump model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cuda.mge")
    CmdArray+=("./lar_cuda model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cuda.mge")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --cpu --layout-transform cpu")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --cpu --layout-transform cpu --layout-transform-dump model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cpu.mge")
    CmdArray+=("./lar_cpu model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cpu.mge")
    for cmd in "${CmdArray[@]}"; do
        echo "$cmd"
        bash -c "$cmd"
    done
    if [ "${RUN_ARM_DEVICE}" = "true" ]; then
        # speed up
        CmdArray=("./lar_arm64 resnet50_b1_float32_without_data.mge --input \"data:resnet50_input.npy\" --cpu --enable-nchw44")
        # speed up
        CmdArray+=("./lar_arm64 resnet50_b1_float32_without_data.mge --input \"data:resnet50_input.npy\" --cpu --enable-nchw44-dot")
        for cmd in "${CmdArray[@]}"; do
            echo "$cmd"
            ssh -t $DEVICE_DESC "unset LD_PRELOAD && cd $WORK_DIR_PATH && LD_LIBRARY_PATH=./ $cmd"
        done
    else
        echo "arm device tests skipped (RUN_ARM_DEVICE=$RUN_ARM_DEVICE); use -t to set a device"
    fi
}
function test_optimize(){
    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATA --enable-fuse-preprocess")
    # speed up (after warm-up)
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-fuse-conv-bias-nonlinearity")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-fuse-conv-bias-with-z")
    CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt")
    CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt --tensorrt-cache tmpdir/TRT_cache")
    CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt-cache tmpdir/TRT_cache")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --no-sanity-check --record-comp-seq2")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --disable_mem_opt")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --workspace_limit 10000")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --fake-first")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable_jit")
    for cmd in "${CmdArray[@]}"; do
        echo "$cmd"
        bash -c "$cmd"
    done
}
function test_plugin(){
    rm -rf tmpdir/staticMemInfoDir tmpdir/staticMemInfoDirLogs
    mkdir tmpdir/staticMemInfoDir
    CmdArray=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --check-dispatch")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --check-var-value 5:0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --range 2")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --profile tmpdir/opr_profile.json")
    CmdArray+=("./lar_cuda model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --profile-host tmpdir/opr_profile_host.json")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --model-info")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --verbose")
    CmdArray+=("./lar_cpu model_source/resnet50_with_data.mgb --disable-assert-throw")
    # wait for gdb to attach to the given PID
    # CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --wait-gdb")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --get-static-mem-info tmpdir/staticMemInfoDir")
    for cmd in "${CmdArray[@]}"; do
        echo "$cmd"
        bash -c "$cmd"
    done
    # view the graph at the given url (usually http://localhost:6006/)
    # mkdir tmpdir/staticMemInfoDirLogs && python3 ../../../imperative/python/megengine/tools/graph_info_analyze.py -i tmpdir/staticMemInfoDir -o tmpdir/staticMemInfoDirLogs
    # pip3 install tensorboard && tensorboard --logdir tmpdir/staticMemInfoDirLogs
}
function clean(){
    rm -rf tmpdir model_source input_data lar_cpu lar_cuda lar_arm64 lar_armv7
}
function main(){
    if [ "${CLEAN_ALL}" = "true" ]; then
        clean
        exit 0
    fi
    if [ "${ONLY_PREPARE_MODEL}" = "true" ]; then
        prepare_model_and_data
        MODEL_PREPARED="true"
        exit 0
    fi
    if [ "${ONLY_BUILD}" = "true" ]; then
        build_lar
        LAR_BUILT="true"
        exit 0
    fi
    if [ "${RUN_ARM_DEVICE}" = "true" ]; then
        set_arm_device_and_upload "$DEVICE_DESC" "$WORK_DIR_PATH" "true"
    fi
    if [ "${MODEL_PREPARED}" != "true" ]; then
        CHECK_MODEL=$(find . -name add_demo_input.json)
        if [ "${CHECK_MODEL}" = "" ]; then
            prepare_model_and_data
            MODEL_PREPARED="true"
        fi
    fi
    if [ "${LAR_BUILT}" != "true" ]; then
        CHECK_LAR=$(find . -name lar_armv7)
        if [ "${CHECK_LAR}" = "" ]; then
            build_lar
            LAR_BUILT="true"
        fi
    fi
    if [ "${RUN_TARGET}" = "diff_model" ] || [ "${RUN_TARGET}" = "all" ]; then
        test_different_model
    fi
    if [ "${RUN_TARGET}" = "diff_device" ] || [ "${RUN_TARGET}" = "all" ]; then
        test_different_device
    fi
    if [ "${RUN_TARGET}" = "fast_run" ] || [ "${RUN_TARGET}" = "all" ]; then
        test_fast_run
    fi
    if [ "${RUN_TARGET}" = "io" ] || [ "${RUN_TARGET}" = "all" ]; then
        test_io
    fi
    if [ "${RUN_TARGET}" = "layout" ] || [ "${RUN_TARGET}" = "all" ]; then
        test_layout_related
    fi
    if [ "${RUN_TARGET}" = "optimize" ] || [ "${RUN_TARGET}" = "all" ]; then
        test_optimize
    fi
    if [ "${RUN_TARGET}" = "plugin" ] || [ "${RUN_TARGET}" = "all" ]; then
        test_plugin
    fi
}
main
IFS=$OLD_IFS
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
import argparse

import megengine.module as M
import numpy as np
from megengine import jit, tensor, Parameter


class Simple(M.Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter([0, 1, 2], dtype=np.float32)

    def forward(self, x):
        x = x + self.a
        return x


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="dump mge model for add_demo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--inputs",
        help="set the input data used to dump a model with testcase",
        default="",
        type=str,
    )
    parser.add_argument(
        "--dir",
        help="set the directory where the model is dumped",
        default=".",
        type=str,
    )
    args = parser.parse_args()

    net = Simple()
    net.eval()

    @jit.trace(symbolic=True, capture_as_const=True)
    def fun(data):
        return net(data)

    data = tensor([3, 4, 5])
    fun(data)
    if args.inputs == "":
        fun.dump(
            args.dir + "/add_demo_f32_without_data.mge",
            arg_names=["data"],
            no_assert=True,
        )
    else:
        fun.dump(
            args.dir + "/add_demo_f32_with_data.mge",
            arg_names=["data"],
            input_data=[args.inputs],
            no_assert=True,
        )
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
import argparse

import megengine.module as M
import numpy as np
from megengine import jit, tensor


class ConvNet(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3, bias=False)

    def forward(self, input):
        x = self.conv1(input)
        return x


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="dump mge model for conv_demo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dir",
        help="set the directory where the model is dumped",
        default=".",
        type=str,
    )
    args = parser.parse_args()

    net = ConvNet()
    net.eval()

    @jit.trace(symbolic=True, capture_as_const=True)
    def fun(data):
        return net(data)

    inp = tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
    out = fun(inp)
    fun.dump(
        args.dir + "/conv_demo_f32_without_data.mge",
        arg_names=["data"],
        no_assert=True,
    )
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import argparse

from meghair.utils.io import dump
from megskull.network import RawNetworkBuilder
from megskull.opr.external import TensorRTRuntimeOpr
import megskull.opr.all as O


def str2tuple(x):
    x = x.split(',')
    x = [int(a) for a in x]
    x = tuple(x)
    return x


def make_network(model, isizes):
    data = [O.DataProvider('input{}'.format(i), shape=isizes[i])
            for i in range(len(isizes))]
    with open(model, 'rb') as f:
        engine = f.read()
    opr = TensorRTRuntimeOpr(data, engine, 1)
    net = RawNetworkBuilder(inputs=data, outputs=opr.outputs)
    return net


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(dest='model')
    parser.add_argument(dest='output')
    parser.add_argument('--isize', help='input sizes. '
                        'e.g. for models with two (1,3,224,224) inputs, '
                        'the option --isize="1,3,224,224;1,3,224,224" should be used')
    args = parser.parse_args()

    isizes = [str2tuple(x) for x in args.isize.split(';')]
    net = make_network(args.model, isizes)
    dump(net, args.output)
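
# Example invocation (as used by gen_trt_model.sh):
#   python3 dump_trt.py model_source/conv_demo.trt model_source/trt_conv_demo.pkl --isize="1,3,4,4"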
@@ -0,0 +1,33 @@
#!/bin/bash
CUR_DIR="$( cd "$(dirname $0)" >/dev/null 2>&1 ; pwd -P )"
# find the matching trtexec version; this only works for internal ci and brainpp env setups
CUDA_VERSION=$(nvcc --version | grep -o "[0-9].\.[0-9]*" | head -n 1)
SEARCH_PATH="$(dirname "$(which nvcc)")/../../"
TRT_CANDIDATE=$(find "$(cd "$SEARCH_PATH"; pwd)" -name "trtexec" | grep "bin/trtexec" | grep "$CUDA_VERSION")
TRT_CANDIDATE=${TRT_CANDIDATE%$'\n'*}
TRT_LIB_PATH=$(readlink -f "$(dirname "$TRT_CANDIDATE")/../lib")
MODELS_PATH=$(readlink -f "${CUR_DIR}/../model_source")
# generate mge model
rm -rf $MODELS_PATH/conv_demo_f32_without_data.mge
python3 ${CUR_DIR}/conv_demo.py --dir $MODELS_PATH
# generate caffe model with mgeconvert
# INSTALL mgeconvert:
#   python3 -m pip install git+https://github.com/MegEngine/mgeconvert.git --user --install-option="--targets=caffe"
rm -rf $MODELS_PATH/conv_demo.prototxt $MODELS_PATH/conv_demo.caffemodel
convert mge_to_caffe -i $MODELS_PATH/conv_demo_f32_without_data.mge -c $MODELS_PATH/conv_demo.prototxt -b $MODELS_PATH/conv_demo.caffemodel
# generate trt model
rm -rf $MODELS_PATH/conv_demo.trt
echo "WARNING: config the cudnn and cublas paths before running trtexec"
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_LIB_PATH
echo $LD_LIBRARY_PATH
$TRT_CANDIDATE --deploy="$MODELS_PATH/conv_demo.prototxt" --model="$MODELS_PATH/conv_demo.caffemodel" --output="conv1_Convolution" --batch=1 --saveEngine="$MODELS_PATH/conv_demo.trt"
# redump the trt model into an mgb model
rm -rf $MODELS_PATH/trt_conv_demo.pkl $MODELS_PATH/trt_conv_demo_with_data.mgb
python3 $CUR_DIR/dump_trt.py $MODELS_PATH/conv_demo.trt $MODELS_PATH/trt_conv_demo.pkl --isize="1,3,4,4"
$CUR_DIR/../../dump_with_testcase.py $MODELS_PATH/trt_conv_demo.pkl -o $MODELS_PATH/trt_conv_demo_with_data.mgb -d "#rand(0, 255)" --no-assert
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
import json
import urllib.request

import cv2
import numpy as np

import megengine.data.transform as T

url, filename = ("https://data.megengine.org.cn/images/cat.jpg", "input_data/cat.jpg")
urllib.request.urlretrieve(url, filename)
# numpy data
data = np.random.rand(1, 3, 224, 224)
np.save("input_data/resnet50_input_uint8.npy", data.astype(np.uint8))
np.save("input_data/resnet50_input.npy", data.astype(np.float32))
# ppm data
image = cv2.imread("input_data/cat.jpg")
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
])
processed_img = transform.apply(image)
cv2.imwrite("input_data/cat.ppm", processed_img)
# json data
data_obj = {
    "shape": [1, 3],
    "type": "int32",
    "raw": [2, 3, 4],
}
with open("input_data/add_demo_input.json", "w") as f:
    json.dump({"data": data_obj}, f)
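
# For reference, the dumped input_data/add_demo_input.json looks like:
#   {"data": {"shape": [1, 3], "type": "int32", "raw": [2, 3, 4]}}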
@@ -0,0 +1,153 @@
#!/usr/bin/env python3
import argparse
import sys

import megskull as mgsk
from megskull.network import RawNetworkBuilder
from megskull.opr.arith import ReLU
from megskull.opr.cnn import Conv2D, FullyConnected, Pooling2D, Softmax
from megskull.opr.netsrc import DataProvider

def create_bn_relu_float(conv_name, f_in, ksize, stride, pad, num_outputs,
                         has_relu, args):
    f = Conv2D(conv_name, f_in, kernel_shape=ksize, stride=stride,
               padding=pad, output_nr_channel=num_outputs,
               nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
    if has_relu:
        f = ReLU(f)
    return f


def get_num_inputs(feature, format):
    if format == 'NCHW':
        return feature.partial_shape[1]
    else:
        assert format == 'NCHW4'
        return feature.partial_shape[1] * 4


def create_bn_relu(prefix, f_in, ksize, stride, pad, num_outputs,
                   has_relu, conv_name_fun, args):
    if conv_name_fun:
        conv_name = conv_name_fun(prefix)
    else:
        conv_name = prefix
    return create_bn_relu_float(conv_name, f_in, ksize, stride, pad,
                                num_outputs, has_relu, args)


def create_bottleneck(prefix, f_in, stride, num_outputs1, num_outputs2, args,
                      has_proj=False):
    proj = f_in
    if has_proj:
        proj = create_bn_relu(prefix, f_in, ksize=1, stride=stride, pad=0,
                              num_outputs=num_outputs2, has_relu=False,
                              conv_name_fun=lambda p: "interstellar{}_branch1".format(p),
                              args=args)
    f = create_bn_relu(prefix, f_in, ksize=1, stride=1, pad=0,
                       num_outputs=num_outputs1, has_relu=True,
                       conv_name_fun=lambda p: "interstellar{}_branch2a".format(p),
                       args=args)
    f = create_bn_relu(prefix, f, ksize=3, stride=stride, pad=1,
                       num_outputs=num_outputs1, has_relu=True,
                       conv_name_fun=lambda p: "interstellar{}_branch2b".format(p),
                       args=args)
    f = create_bn_relu(prefix, f, ksize=1, stride=1, pad=0,
                       num_outputs=num_outputs2, has_relu=False,
                       conv_name_fun=lambda p: "interstellar{}_branch2c".format(p),
                       args=args)
    f = ReLU(f + proj)
    return f

def get(args):
    img_size = 224
    num_inputs = 3
    data = DataProvider('data', shape=(args.batch_size, num_inputs,
                                       img_size, img_size))
    inp = data
    f = create_bn_relu("conv1", inp, ksize=7, stride=2, pad=3, num_outputs=64,
                       has_relu=True, conv_name_fun=None, args=args)
    f = Pooling2D("pool1", f, window=3, stride=2, padding=1, mode="MAX",
                  format=args.format)
    pre = [2, 3, 4, 5]
    stages = [3, 4, 6, 3]
    mid_outputs = [64, 128, 256, 512]
    enable_stride = [False, True, True, True]
    for p, s, o, es in zip(pre, stages, mid_outputs, enable_stride):
        for i in range(s):
            has_proj = False if i > 0 else True
            stride = 1 if not es or i > 0 else 2
            prefix = "{}{}".format(p, chr(ord("a") + i))
            f = create_bottleneck(prefix, f, stride, o, o * 4, args, has_proj)
            print("{}\t{}".format(prefix, f.partial_shape))
    f = Pooling2D("pool5", f, window=7, stride=7, padding=0, mode="AVERAGE",
                  format=args.format)
    f = FullyConnected("fc1000", f, output_dim=1000,
                       nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
    f = Softmax("cls_softmax", f)
    f.init_weights()
    net = RawNetworkBuilder(inputs=[data], outputs=[f])
    return net

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='dump pkl model for resnet50',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-b', '--batch-size', type=int,
                        help='batch size of the model', default=1)
    parser.add_argument('-f', '--format', choices=['NCHW', 'NCHW4'],
                        help='format of conv',
                        default='NCHW')
    parser.add_argument('-o', '--output',
                        help='output pkl path', required=True)
    args = parser.parse_args()

    if args.format != 'NCHW':
        print('Only support NCHW for float model')
        parser.print_help()
        sys.exit(1)

    from meghair.utils import io
    io.dump(get(args), args.output)
@@ -0,0 +1,351 @@
#!/usr/bin/env python3
import argparse
import math

import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
import numpy as np
from megengine import jit, tensor

class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.Sequential(
                M.Conv2d(in_channels, channels, 1, stride, bias=False), norm(channels),
            )
        )

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        identity = self.downsample(identity)
        x += identity
        x = F.relu(x)
        return x

class Bottleneck(M.Module):
    expansion = 4

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        width = int(channels * (base_width / 64.0)) * groups
        self.conv1 = M.Conv2d(in_channels, width, 1, 1, bias=False)
        self.bn1 = norm(width)
        self.conv2 = M.Conv2d(
            width,
            width,
            3,
            stride,
            padding=dilation,
            groups=groups,
            dilation=dilation,
            bias=False,
        )
        self.bn2 = norm(width)
        self.conv3 = M.Conv2d(width, channels * self.expansion, 1, 1, bias=False)
        self.bn3 = norm(channels * self.expansion)
        self.downsample = (
            M.Identity()
            if in_channels == channels * self.expansion and stride == 1
            else M.Sequential(
                M.Conv2d(in_channels, channels * self.expansion, 1, stride, bias=False),
                norm(channels * self.expansion),
            )
        )

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        identity = self.downsample(identity)
        x += identity
        x = F.relu(x)
        return x

class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            norm=norm,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            norm=norm,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            norm=norm,
        )
        self.fc = M.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block
        # behaves like an identity. According to https://arxiv.org/abs/1706.02677
        # this improves the model by 0.2~0.3%.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    M.init.zeros_(m.bn3.weight)
                elif isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)
    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )
        return M.Sequential(*layers)
    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        # note: the input is reshaped to a fixed (1, 3, 224, 224) here, so the
        # dumped graph always runs with batch size 1
        x = F.reshape(x, (1, 3, 224, 224))
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x

@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/resnet50_fbaug_76254_4e14b7d1.pkl"
)
def resnet50(**kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="dump mge model for resnet50",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-b", "--batch-size", help="batch size of the model", default=1, type=int
    )
    parser.add_argument(
        "-d",
        "--dtype",
        help="the dtype of the model input data: float32 or uint8",
        default="float32",
        type=str,
    )
    parser.add_argument(
        "--inputs",
        help="set the input data used to dump a model with testcase",
        default="",
        type=str,
    )
    parser.add_argument(
        "--dir",
        help="set the directory where the model is dumped",
        default=".",
        type=str,
    )
    parser.add_argument(
        "--enable-nchw4",
        help="enable-nchw4 for NVIDIA CUDNN",
        action='store_true',
    )
    parser.add_argument(
        "--enable-chwn4",
        help="enable-chwn4 for NVIDIA CUDNN",
        action='store_true',
    )
    args = parser.parse_args()

    net = resnet50()
    net.eval()

    @jit.trace(symbolic=True, capture_as_const=True)
    def fun(data):
        return net(data)

    if args.dtype in ("float32", "uint8"):
        data_type = np.float32
        if args.dtype == "uint8":
            data_type = np.uint8
        data = tensor(
            (np.random.random([args.batch_size, 3, 224, 224]) * 255).astype(data_type)
        )
        fun(data)
        if args.inputs == "":
            fun.dump(
                args.dir + "/resnet50_b" + str(args.batch_size) + "_" + args.dtype + "_without_data.mge",
                arg_names=["data"],
                no_assert=True,
                enable_nchw4=args.enable_nchw4,
                enable_chwn4=args.enable_chwn4,
            )
        else:
            fun.dump(
                args.dir + "/resnet50_b" + str(args.batch_size) + "_" + args.dtype + "_with_data.mge",
                arg_names=["data"],
                input_data=[args.inputs],
                no_assert=True,
                enable_nchw4=args.enable_nchw4,
            )
    else:
        raise TypeError("dtype should be float32 or uint8")
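
# Example invocations (matching prepare_model_and_data in example.sh):
#   python3 script/resnet50_mge.py --dir model_source
#   python3 script/resnet50_mge.py --dir model_source --inputs "#rand(0,255)"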
@@ -148,7 +148,9 @@ XPUDeviceOption::XPUDeviceOption() {
     }
     if (!FLAGS_multi_thread_core_ids.empty()) {
-        mgb_assert(enable_multithread, "core ids should be set after --multithread");
+        mgb_assert(
+                enable_multithread || enable_multithread_default,
+                "core ids should be set after --multithread or --multithread-default");
         std::stringstream id_stream(FLAGS_multi_thread_core_ids);
         std::string id;
         size_t thread_cnt = 0;
@@ -126,7 +126,9 @@ void COprLibOption::init_extern_param(std::shared_ptr<ModelBase> model_ptr) {
 void COprLibOption::load_lib() {
     auto handle = dlopen(lib_path.c_str(), RTLD_LAZY);
-    mgb_assert(handle, "failed to open c opr lib %s: %s", lib_path.c_str(), dlerror());
+    mgb_assert(
+            handle, "failed to open c opr lib %s:\n errmsg: %s", lib_path.c_str(),
+            dlerror());

     const char* entry = MGB_C_OPR_INIT_FUNC_STR;
     auto func = dlsym(handle, entry);
@@ -286,8 +286,13 @@ DEFINE_string(
 DEFINE_string(io_dump, "", "set the io dump file path in text format");
 DEFINE_bool(io_dump_stdout, false, "dump io opr to stdout in text format");
 DEFINE_bool(io_dump_stderr, false, "dump io opr to stderr in text format");
-DEFINE_string(bin_io_dump, "", "set the io dump file path in binary format");
-DEFINE_string(bin_out_dump, "", "set the out dump file path in binary format");
+DEFINE_string(
+        bin_io_dump, "",
+        "set the io dump directory path where variables in binary format are located");
+DEFINE_string(
+        bin_out_dump, "",
+        "set the out dump directory path where output variables in binary format "
+        "are located");
 DEFINE_bool(copy_to_host, false, "copy device data to host");

 REGIST_OPTION_CREATOR(input, lar::InputOption::create_option);
@@ -77,12 +77,6 @@ void WeightPreprocessOption::config_model_internel<ModelLite>(
         if (weight_preprocess) {
             LITE_WARN("enable weight-preprocess optimization");
             model->get_config().options.weight_preprocess = true;
-            //! FIXME: algo searcher enable weight preprocess for opencl(
-            //! implement below has some problem);
-            // #if MGB_OPENCL
-            //     megdnn::opencl::algo_searcher::AlgoSearcherBase::
-            //             enable_weight_preprocess();
-            // #endif
         }
     }
 }
@@ -95,11 +89,6 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>(
         if (weight_preprocess) {
             mgb_log_warn("enable weight-preprocess optimization");
             graph_option.graph_opt.enable_weight_preprocess();
-            //! FIXME: this implemment is not right
-            // #if MGB_OPENCL
-            //     megdnn::opencl::algo_searcher::AlgoSearcherBase::
-            //             enable_weight_preprocess();
-            // #endif
         }
     }
 }
@@ -570,12 +559,6 @@ DEFINE_bool(
         enable_jit, false,
         " Execute supported operators with JIT(now only support NVRTC). "
         "Can only be used on Nvidia GPUs");
-#if MGB_ENABLE_ANDROID_NN
-DEFINE_bool(
-        android_nn, false,
-        "Execute supported operators with Android NN. Can only be used "
-        "with --cpu.");
-#endif
 #if MGB_ENABLE_TENSOR_RT
 DEFINE_bool(
         tensorrt, false,
@@ -8,10 +8,10 @@
  */
 #include "plugin_options.h"

+#include <map>
 #include "misc.h"
 #include "models/model_lite.h"
 #include "models/model_mdl.h"

 ///////////////////// Plugin options///////////////////////////
 namespace lar {
@@ -153,7 +153,12 @@ void DebugOption::format_and_print(
     auto table = mgb::TextTable(tablename);
     auto&& network = model->get_lite_network();
     table.padding(1);
-    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();
+    table.align(mgb::TextTable::Align::Mid)
+            .add("type")
+            .add("name")
+            .add("shape")
+            .add("dtype")
+            .eor();

     auto to_string = [&](lite::Layout& layout) {
         std::string shape("{");
@@ -165,6 +170,19 @@ void DebugOption::format_and_print(
         shape.append("}");
         return shape;
     };
+    auto get_dtype = [&](lite::Layout& layout) {
+        std::map<LiteDataType, std::string> type_map = {
+                {LiteDataType::LITE_FLOAT, "float32"},
+                {LiteDataType::LITE_HALF, "float16"},
+                {LiteDataType::LITE_INT64, "int64"},
+                {LiteDataType::LITE_INT, "int32"},
+                {LiteDataType::LITE_UINT, "uint32"},
+                {LiteDataType::LITE_INT16, "int16"},
+                {LiteDataType::LITE_UINT16, "uint16"},
+                {LiteDataType::LITE_INT8, "int8"},
+                {LiteDataType::LITE_UINT8, "uint8"}};
+        return type_map[layout.data_type];
+    };

     auto input_name = network->get_all_input_name();
     for (auto& i : input_name) {
@@ -173,6 +191,7 @@ void DebugOption::format_and_print(
                 .add("INPUT")
                 .add(i)
                 .add(to_string(layout))
+                .add(get_dtype(layout))
                 .eor();
     }
@@ -183,6 +202,7 @@ void DebugOption::format_and_print(
                 .add("OUTPUT")
                 .add(i)
                 .add(to_string(layout))
+                .add(get_dtype(layout))
                 .eor();
     }
@@ -196,13 +216,28 @@ void DebugOption::format_and_print(
         const std::string& tablename, std::shared_ptr<ModelMdl> model) {
     auto table = mgb::TextTable(tablename);
     table.padding(1);
-    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();
+    table.align(mgb::TextTable::Align::Mid)
+            .add("type")
+            .add("name")
+            .add("shape")
+            .add("dtype")
+            .eor();
+    auto get_dtype = [&](megdnn::DType data_type) {
+        std::map<megdnn::DTypeEnum, std::string> type_map = {
+                {mgb::dtype::Float32().enumv(), "float32"},
+                {mgb::dtype::Int32().enumv(), "int32"},
+                {mgb::dtype::Int16().enumv(), "int16"},
+                {mgb::dtype::Uint16().enumv(), "uint16"},
+                {mgb::dtype::Int8().enumv(), "int8"},
+                {mgb::dtype::Uint8().enumv(), "uint8"}};
+        return type_map[data_type.enumv()];
+    };

     for (auto&& i : model->get_mdl_load_result().tensor_map) {
         table.align(mgb::TextTable::Align::Mid)
                 .add("INPUT")
                 .add(i.first)
                 .add(i.second->shape().to_string())
+                .add(get_dtype(i.second->dtype()))
                 .eor();
     }
@@ -211,6 +246,7 @@ void DebugOption::format_and_print(
                 .add("OUTPUT")
                 .add(i.node()->name())
                 .add(i.shape().to_string())
+                .add(get_dtype(i.dtype()))
                 .eor();
     }
@@ -358,18 +394,22 @@ DEFINE_double(
 DEFINE_bool(
         check_dispatch, false,
-        "check whether an operator call dispatch on cpu comp nodes");
+        "check whether an operator call dispatch on cpu comp nodes. This is used to "
+        "find potential bugs in MegDNN");
 DEFINE_string(
         check_var_value, "",
         "--check-var-value [interval]|[interval:init_idx], Enable "
-        "VarValueChecker plugin. Refer to its doc for more details");
+        "VarValueChecker plugin. check values of all vars in a graph from given var "
+        "ID(init_idx) with step interval");
 #if MGB_ENABLE_JSON
 DEFINE_string(
         profile, "",
         "Write profiling result to given file. The output file is in "
         "JSON format");
-DEFINE_string(profile_host, "", "focus on host time profiling For some backends");
+DEFINE_string(
+        profile_host, "",
+        "focus on host time profiling for some backends (such as OpenCL)");
 #endif

 ///////////////////// Debug gflags///////////////////////////
@@ -94,15 +94,16 @@ void NormalStrategy::run_subline() {
         m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
         stage_config_model();
         auto cur = timer.get_msecs();
-        printf("iter %lu/%lu: %.3fms (exec=%.3fms)\n", i, run_num, cur, exec_time);
+        printf("iter %lu/%lu: %.3f ms (exec=%.3f ms)\n", i, run_num, cur,
+               exec_time);
         time_sum += cur;
         time_sqrsum += cur * cur;
         fflush(stdout);
         min_time = std::min(min_time, cur);
         max_time = std::max(max_time, cur);
     }
-    printf("\n=== finished test #%u: time=%.3fms avg_time=%.3fms "
-           "sexec=%.3fms min=%.3fms max=%.3fms\n\n",
+    printf("\n=== finished test #%u: time=%.3f ms avg_time=%.3f ms "
+           "sexec=%.3f ms min=%.3f ms max=%.3f ms\n\n",
            idx, time_sum, time_sum / run_num,
            std::sqrt(
                    (time_sqrsum * run_num - time_sum * time_sum) /
@@ -121,7 +122,7 @@ void NormalStrategy::run_subline() {
         m_runtime_param.stage = RunStage::MODEL_RUNNING;
         stage_config_model();

-        if (!idx) {
+        if (idx == 0) {
             warm_up();
         }
         tot_time += run_iter(idx);