GitOrigin-RevId: a28ed2f27a
tags/v1.9.0
@@ -0,0 +1,15 @@
# Load and run example scripts

These are some typical examples of using load_and_run. You can use

```bash
./example.sh
```

to run the examples; more details are in the script itself. The usage of the script can be seen by running it with:

```bash
./example.sh -h
```
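
For example, after preparing the models with `-p` and building the binaries with `-b`, a single group of examples (here the IO-related ones) can be run with:

```bash
./example.sh -e io
```

The valid `-e` targets defined in the script are `diff_model`, `diff_device`, `fast_run`, `io`, `layout`, `optimize`, `plugin`, and `all`.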
@@ -0,0 +1,437 @@
#!/bin/bash

OLD_IFS="$IFS"
IFS=$'\n'

TYPICAL_MODEL_DATA="model_source/resnet50_b1_float32_without_data.mge --input \"data:input_data/resnet50_input.npy\""
DEVICE_DESC=""
WORK_DIR_PATH="."
RUN_ARM_DEVICE="false"
RUN_TARGET="diff_model"
ONLY_PREPARE_MODEL="false"
MODEL_PREPARED="false"
ONLY_BUILD="false"
LAR_BUILT="false"
CLEAN_ALL="false"

RUN_TARGETS=("diff_model")
RUN_TARGETS+=("diff_device")
RUN_TARGETS+=("fast_run")
RUN_TARGETS+=("io")
RUN_TARGETS+=("layout")
RUN_TARGETS+=("optimize")
RUN_TARGETS+=("plugin")
RUN_TARGETS+=("all")
function usage() {
    echo "$0 args1 args2 .."
    echo "available args detail:"
    echo "-p : prepare example model"
    echo "-b : build load_and_run for x86/armv7/arm64 cpu and CUDA"
    echo "-t : set the ssh arm device"
    echo "-w : set the arm device workspace dir"
    echo "-c : clean all"
    echo "-a : run all tests"
    echo "-e : set the running target for the test (pass an invalid \"-e\" target to list the options)"
    echo "-h : show usage"
    exit -1
}

while getopts "pbcahe:w:t:" arg
do
    case $arg in
        t)
            DEVICE_DESC=$OPTARG
            RUN_ARM_DEVICE="true"
            echo "config arm device DEVICE_DESC to ${DEVICE_DESC}"
            ;;
        w)
            WORK_DIR_PATH=$OPTARG
            echo "config arm device WORK_DIR_PATH to ${WORK_DIR_PATH}"
            ;;
        e)
            tmp_target=null
            for target in ${RUN_TARGETS[@]}; do
                if [ "$target" = "$OPTARG" ]; then
                    echo "CONFIG BUILD RUN_TARGET to : $OPTARG"
                    tmp_target=$OPTARG
                    RUN_TARGET=$OPTARG
                    break
                fi
            done
            if [ "$tmp_target" = "null" ]; then
                echo "ERR args for target (-e)"
                echo "available target usage :"
                for target in ${RUN_TARGETS[@]}; do
                    echo " -e $target"
                done
                exit -1
            fi
            ;;
        h)
            echo "show usage"
            usage
            ;;
        a)
            echo "config RUN_TARGET=all"
            RUN_TARGET="all"
            ;;
        c)
            echo "clean all directories generated by this script"
            CLEAN_ALL="true"
            ;;
        b)
            echo "run build"
            ONLY_BUILD="true"
            ;;
        p)
            echo "prepare model and input"
            ONLY_PREPARE_MODEL="true"
            ;;
        ?)
            echo "unknown argument"
            usage
            ;;
    esac
done
function prepare_model_and_data(){
    rm -rf model_source && mkdir model_source
    # dump mgb model
    python3 script/resnet50_mgb.py -o model_source/resnet50.pkl
    ../dump_with_testcase.py model_source/resnet50.pkl -o model_source/resnet50_with_data.mgb -d "#rand(0, 255)" --no-assert
    # prepare simple add model
    python3 script/add_demo.py --dir model_source
    python3 script/conv_demo.py --dir model_source
    # generate trt model
    script/gen_trt_model.sh
    # prepare mge model
    python3 script/resnet50_mge.py --dir model_source
    python3 script/resnet50_mge.py --dir model_source -d int8
    python3 script/resnet50_mge.py --dir model_source --inputs "#rand(0,255)"
    # make input_data
    rm -rf input_data && mkdir input_data
    python3 script/mge_input_data.py
    rm -rf tmpdir && mkdir tmpdir
}

function build_lar(){
    # build cpu and cuda version
    ../../../scripts/cmake-build/host_build.sh -r -t -e load_and_run
    # WARNING: config the cuda environment before compiling
    ../../../scripts/cmake-build/host_build.sh -c -t -e load_and_run
    # build arm version
    ../../../scripts/cmake-build/cross_build_android_arm_inference.sh -r -a arm64-v8a -e load_and_run
    ../../../scripts/cmake-build/cross_build_android_arm_inference.sh -r -a armeabi-v7a -e load_and_run
    # link or copy the binaries for testing
    ln -s ../../../build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_OFF/Release/build/lite/load_and_run/load_and_run lar_cpu
    ln -s ../../../build_dir/host/MGE_WITH_CUDA_ON/MGE_INFERENCE_ONLY_OFF/Release/build/lite/load_and_run/load_and_run lar_cuda
    cp ../../../build_dir/android/arm64-v8a/Release/build/lite/load_and_run/load_and_run ./lar_arm64
    cp ../../../build_dir/android/armeabi-v7a/Release/build/lite/load_and_run/load_and_run ./lar_armv7
}

function set_arm_device_and_upload(){
    DEVICE_DESC="${1}"
    WORK_DIR_PATH="${2}"
    RUN_ARM_DEVICE="${3}"
    cmd="rsync -aP -zz ./lar_arm64 ./lar_armv7 model_source/resnet50_b1_float32_without_data.mge input_data/resnet50_input.npy $DEVICE_DESC:$WORK_DIR_PATH/"
    echo $cmd
    bash -c "$cmd"
}
function test_different_model(){
    CmdArray=("./lar_cpu model_source/resnet50_with_data.mgb")
    CmdArray+=("./lar_cpu model_source/resnet50_b1_float32_with_data.mge")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --lite")
    for cmd in ${CmdArray[@]}; do
        echo "$cmd"
        bash -c "$cmd"
    done
}
function test_different_device(){
    # on dispatch, compute tasks are pushed into a work queue that manages their execution;
    # mean 131.278 ms, stddev 15.197 ms, m_asyc_exec runs asynchronously
    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATA --cpu")
    # on dispatch, compute tasks are executed directly;
    # mean 131.875 ms, stddev 7.758 ms, m_asyc_exec runs synchronously
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --cpu-default")
    # multi-threaded run; average run time for 1~8 threads (ms): 129.611, 84.266, 76.963, 55.212, 69.283, 77.338, 58.386, 64.585
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --multithread 4")
    # the main thread is pinned to a core and the other tasks run on the thread pool:
    # 132.614, 83.095, 69.792, 54.452, 48.890, 48.206, 46.386, 53.908
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --multithread-default 4")
    # multi-threaded CPU run with core binding (binding has little effect on x86)
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --multithread 2 --multi-thread-core-ids 1,5")
    # xpu set to run on cpu: 132.740 ms, comp_node: cpu
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cpu")
    # xpu set to run on cuda: 6.495 ms, comp_node: gpu
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cuda")
    for cmd in ${CmdArray[@]}; do
        echo $cmd
        bash -c "$cmd"
    done
}

function test_fast_run(){
    CmdArray=("./lar_cuda $TYPICAL_MODEL_DATA --fast-run")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --full-run")
    # fast-run algorithm search, writing the chosen algorithms into a cache file
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --fast-run --fast-run-algo-policy tmpdir/algo_cache_file")
    # run with the cached fast-run algorithm policy
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --fast-run-algo-policy tmpdir/algo_cache_file")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --fast-run --reproducible")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --fast-run --fast-run-shared-batch-size 1")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --fast-run --binary-equal-between-batch")
    for cmd in ${CmdArray[@]}; do
        echo $cmd
        bash -c "$cmd"
    done
}
function test_io(){
    rm -rf tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_io_info_cuda tmpdir/io_info.txt
    mkdir tmpdir/bin_io_info tmpdir/bin_out_info tmpdir/bin_io_info_cuda
    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATA")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"input_data/add_demo_input.json\"")
    #! the model must support input with nhwc shape
    CmdArray+=("./lar_cpu model_source/resnet50_b1_int8_without_data.mge --input \"data:input_data/cat.ppm\"")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\"")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump tmpdir/io_info.txt --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump-stdout --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --io-dump-stderr --iter 1 --warmup-iter 0")
    # dump each var into a separate file in the given directory; the file name is the var id, matching the text-dump information
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-io-dump tmpdir/bin_io_info --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cuda model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-io-dump tmpdir/bin_io_info_cuda --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --bin-out-dump tmpdir/bin_out_info --iter 1 --warmup-iter 0")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --copy-to-host")
    for cmd in ${CmdArray[@]}; do
        echo $cmd
        bash -c "$cmd"
    done
    # compare the binary io information
    python3 ../../../imperative/python/megengine/tools/compare_binary_iodump.py tmpdir/bin_io_info tmpdir/bin_io_info_cuda
}
function test_layout_related(){
    # very little speed up
    CmdArray=("./lar_cuda $TYPICAL_MODEL_DATA --enable-nchw4")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-chwn4")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-nchw32")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-nchw64")
    # speed up
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --enable-nchw88")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cuda --layout-transform cuda")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --cuda --layout-transform cuda --layout-transform-dump model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cuda.mge")
    CmdArray+=("./lar_cuda model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cuda.mge")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --cpu --layout-transform cpu")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --cpu --layout-transform cpu --layout-transform-dump model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cpu.mge")
    CmdArray+=("./lar_cpu model_source/resnet50_b1_float32_without_data_with_global_layout_trans_cpu.mge")
    for cmd in ${CmdArray[@]}; do
        echo $cmd
        bash -c "$cmd"
    done
    if [ "${RUN_ARM_DEVICE}" == "true" ] ; then
        # speed up
        CmdArray=("./lar_arm64 resnet50_b1_float32_without_data.mge --input \"data:resnet50_input.npy\" --cpu --enable-nchw44")
        # speed up
        CmdArray+=("./lar_arm64 resnet50_b1_float32_without_data.mge --input \"data:resnet50_input.npy\" --cpu --enable-nchw44-dot")
        for cmd in ${CmdArray[@]}; do
            echo $cmd
            ssh -t $DEVICE_DESC "unset LD_PRELOAD && cd $WORK_DIR_PATH && LD_LIBRARY_PATH=./ $cmd"
        done
    else
        echo "SET arm device ON : $RUN_ARM_DEVICE"
    fi
}
function test_optimize(){
    CmdArray=("./lar_cpu $TYPICAL_MODEL_DATA --enable-fuse-preprocess")
    # speed up after warm up
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-fuse-conv-bias-nonlinearity")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable-fuse-conv-bias-with-z")
    CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt")
    CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt --tensorrt-cache tmpdir/TRT_cache")
    CmdArray+=("./lar_cuda model_source/trt_conv_demo_with_data.mgb --tensorrt-cache tmpdir/TRT_cache")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --no-sanity-check --record-comp-seq2")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --disable_mem_opt")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --workspace_limit 10000")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --fake-first")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --enable_jit")
    for cmd in ${CmdArray[@]}; do
        echo $cmd
        bash -c "$cmd"
    done
}
function test_plugin(){
    rm -rf tmpdir/staticMemInfoDir tmpdir/staticMemInfoDirLogs
    mkdir tmpdir/staticMemInfoDir
    CmdArray=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --check-dispatch")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --check-var-value 5:0")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --range 2")
    CmdArray+=("./lar_cpu model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --profile tmpdir/opr_profile.json")
    CmdArray+=("./lar_cuda model_source/add_demo_f32_without_data.mge --input \"data:[2,3,4]\" --profile-host tmpdir/opr_profile_host.json")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --model-info")
    CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --verbose")
    CmdArray+=("./lar_cpu model_source/resnet50_with_data.mgb --disable-assert-throw")
    # wait for gdb to attach to the given PID
    # CmdArray+=("./lar_cpu $TYPICAL_MODEL_DATA --wait-gdb")
    CmdArray+=("./lar_cuda $TYPICAL_MODEL_DATA --get-static-mem-info tmpdir/staticMemInfoDir")
    for cmd in ${CmdArray[@]}; do
        echo $cmd
        bash -c "$cmd"
    done
    # view the graph at the given url (usually: http://localhost:6006/)
    # mkdir tmpdir/staticMemInfoDirLogs && python3 ../../../imperative/python/megengine/tools/graph_info_analyze.py -i tmpdir/staticMemInfoDir -o tmpdir/staticMemInfoDirLogs
    # pip3 install tensorboard && tensorboard --logdir tmpdir/staticMemInfoDirLogs
}
function clean(){
    rm -rf tmpdir model_source input_data lar_cpu lar_cuda lar_arm64 lar_armv7
}
function main(){
    if [ ${CLEAN_ALL} == "true" ] ; then
        clean
        exit 0
    fi
    if [ ${ONLY_PREPARE_MODEL} == "true" ] ; then
        prepare_model_and_data
        MODEL_PREPARED="true"
        exit 0
    fi
    if [ ${ONLY_BUILD} == "true" ] ; then
        build_lar
        LAR_BUILT="true"
        exit 0
    fi
    if [ ${RUN_ARM_DEVICE} == "true" ] ; then
        set_arm_device_and_upload "$DEVICE_DESC" "$WORK_DIR_PATH" "true"
    fi
    if [ ${MODEL_PREPARED} != "true" ] ; then
        CHECK_MODEL=$(find . -name add_demo_input.json)
        if [ "${CHECK_MODEL}" == "" ] ; then
            prepare_model_and_data
            MODEL_PREPARED="true"
        fi
    fi
    if [ ${LAR_BUILT} != "true" ] ; then
        CHECK_LAR=$(find . -name lar_armv7)
        if [ "${CHECK_LAR}" == "" ] ; then
            build_lar
            LAR_BUILT="true"
        fi
    fi
    if [ ${RUN_TARGET} == "diff_model" -o ${RUN_TARGET} == "all" ] ; then
        test_different_model
    fi
    if [ ${RUN_TARGET} == "diff_device" -o ${RUN_TARGET} == "all" ] ; then
        test_different_device
    fi
    if [ ${RUN_TARGET} == "fast_run" -o ${RUN_TARGET} == "all" ] ; then
        test_fast_run
    fi
    if [ ${RUN_TARGET} == "io" -o ${RUN_TARGET} == "all" ] ; then
        test_io
    fi
    if [ ${RUN_TARGET} == "layout" -o ${RUN_TARGET} == "all" ] ; then
        test_layout_related
    fi
    if [ ${RUN_TARGET} == "optimize" -o ${RUN_TARGET} == "all" ] ; then
        test_optimize
    fi
    if [ ${RUN_TARGET} == "plugin" -o ${RUN_TARGET} == "all" ] ; then
        test_plugin
    fi
}

main
IFS=$OLD_IFS
@@ -0,0 +1,57 @@
#!/usr/bin/env python3

import argparse

import megengine.module as M
import numpy as np
from megengine import jit, tensor, Parameter


class Simple(M.Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter([0, 1, 2], dtype=np.float32)

    def forward(self, x):
        x = x + self.a
        return x


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="dump mge model for add_demo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--inputs",
        help="set the input data to get a model with testcase",
        default="",
        type=str,
    )
    parser.add_argument(
        "--dir",
        help="set the dir where the model is dumped",
        default=".",
        type=str,
    )
    args = parser.parse_args()

    net = Simple()
    net.eval()

    @jit.trace(symbolic=True, capture_as_const=True)
    def fun(data):
        return net(data)

    data = tensor([3, 4, 5])
    fun(data)
    if args.inputs == "":
        fun.dump(
            args.dir + "/add_demo_f32_without_data.mge", arg_names=["data"],
            no_assert=True,
        )
    else:
        fun.dump(
            args.dir + "/add_demo_f32_with_data.mge", arg_names=["data"],
            input_data=[args.inputs], no_assert=True,
        )
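
# Usage sketch (matching how example.sh invokes this script):
#   python3 script/add_demo.py --dir model_source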
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

import argparse

import megengine.module as M
import numpy as np
from megengine import jit, tensor


class ConvNet(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3, bias=False)

    def forward(self, input):
        x = self.conv1(input)
        return x


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="dump mge model for conv_demo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dir",
        help="set the dir where the model is dumped",
        default=".",
        type=str,
    )
    args = parser.parse_args()

    net = ConvNet()
    net.eval()

    @jit.trace(symbolic=True, capture_as_const=True)
    def fun(data):
        return net(data)

    inp = tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
    out = fun(inp)
    fun.dump(args.dir + "/conv_demo_f32_without_data.mge", arg_names=["data"],
             no_assert=True)
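
# Usage sketch (matching how example.sh and gen_trt_model.sh invoke this script):
#   python3 script/conv_demo.py --dir model_source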
@@ -0,0 +1,37 @@
#!/usr/bin/env python3

import argparse

from megskull.network import RawNetworkBuilder
import megskull.opr.all as O
from megskull.opr.external import TensorRTRuntimeOpr
from meghair.utils.io import dump


def str2tuple(x):
    x = x.split(',')
    x = [int(a) for a in x]
    x = tuple(x)
    return x


def make_network(model, isizes):
    data = [O.DataProvider('input{}'.format(i), shape=isizes[i])
            for i in range(len(isizes))]
    with open(model, 'rb') as f:
        engine = f.read()
    opr = TensorRTRuntimeOpr(data, engine, 1)
    net = RawNetworkBuilder(inputs=data, outputs=opr.outputs)
    return net


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(dest='model')
    parser.add_argument(dest='output')
    parser.add_argument('--isize', help='input sizes. '
                        'e.g. for models with two (1,3,224,224) inputs, '
                        'the option --isize="1,3,224,224;1,3,224,224" should be used')
    args = parser.parse_args()

    isizes = [str2tuple(x) for x in args.isize.split(';')]
    net = make_network(args.model, isizes)
    dump(net, args.output)
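
# Usage sketch (matching how gen_trt_model.sh invokes this script):
#   python3 script/dump_trt.py model_source/conv_demo.trt model_source/trt_conv_demo.pkl --isize="1,3,4,4"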
@@ -0,0 +1,33 @@
#!/bin/bash

CUR_DIR="$( cd "$(dirname $0)" >/dev/null 2>&1 ; pwd -P )"

# find the correct trtexec version; only works for internal ci and brainpp env setups
CUDA_VERSION=$(nvcc --version | grep -o "[0-9].\.[0-9]*" | head -n 1)
SEARCH_PATH=$(echo `which nvcc | xargs dirname`/../../)
TRT_CANDIDATE=$(find `cd $SEARCH_PATH; pwd` -name "trtexec" | grep "bin/trtexec" | grep $CUDA_VERSION)
TRT_CANDIDATE=${TRT_CANDIDATE%$'\n'*}
TRT_LIB_PATH=$(readlink -f "`dirname $TRT_CANDIDATE`/../lib")
MODELS_PATH=$(readlink -f "${CUR_DIR}/../model_source")

# generate mge model
rm -rf $MODELS_PATH/conv_demo_f32_without_data.mge
python3 ${CUR_DIR}/conv_demo.py --dir $MODELS_PATH

# generate caffe model with mgeconvert
# INSTALL mgeconvert:
#   python3 -m pip install git+https://github.com/MegEngine/mgeconvert.git --user --install-option="--targets=caffe"
rm -rf $MODELS_PATH/conv_demo.prototxt $MODELS_PATH/conv_demo.caffemodel
convert mge_to_caffe -i $MODELS_PATH/conv_demo_f32_without_data.mge -c $MODELS_PATH/conv_demo.prototxt -b $MODELS_PATH/conv_demo.caffemodel

# generate trt model
rm -rf $MODELS_PATH/conv_demo.trt
echo "WARNING: config the cudnn and cublas paths before running trtexec"
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_LIB_PATH
echo $LD_LIBRARY_PATH
$TRT_CANDIDATE --deploy="$MODELS_PATH/conv_demo.prototxt" --model="$MODELS_PATH/conv_demo.caffemodel" --output="conv1_Convolution" --batch=1 --saveEngine="$MODELS_PATH/conv_demo.trt"

# re-dump the trt model into an mgb model
rm -rf $MODELS_PATH/trt_conv_demo.pkl $MODELS_PATH/trt_conv_demo_with_data.mgb
python3 $CUR_DIR/dump_trt.py $MODELS_PATH/conv_demo.trt $MODELS_PATH/trt_conv_demo.pkl --isize="1,3,4,4"
$CUR_DIR/../../dump_with_testcase.py $MODELS_PATH/trt_conv_demo.pkl -o $MODELS_PATH/trt_conv_demo_with_data.mgb -d "#rand(0, 255)" --no-assert
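
# Invoked by example.sh during model preparation as: script/gen_trt_model.sh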
@@ -0,0 +1,33 @@
#!/usr/bin/env python3

import json
import urllib.request

import cv2
import numpy as np

import megengine.data.transform as T

url, filename = ("https://data.megengine.org.cn/images/cat.jpg", "input_data/cat.jpg")
urllib.request.urlretrieve(url, filename)

# numpy data
data = np.random.rand(1, 3, 224, 224)
np.save("input_data/resnet50_input_uint8.npy", data.astype(np.uint8))
np.save("input_data/resnet50_input.npy", data.astype(np.float32))

# ppm data
image = cv2.imread("input_data/cat.jpg")
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
])
processed_img = transform.apply(image)
cv2.imwrite("input_data/cat.ppm", processed_img)

# json data
data_obj = {
    "shape": [1, 3],
    "type": "int32",
    "raw": [2, 3, 4]
}
with open("input_data/add_demo_input.json", "w") as f:
    json.dump({"data": data_obj}, f)
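
# For reference, the dumped add_demo_input.json contains:
#   {"data": {"shape": [1, 3], "type": "int32", "raw": [2, 3, 4]}}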
@@ -0,0 +1,153 @@
#!/usr/bin/env python3

import argparse
import sys

from megskull.network import RawNetworkBuilder
from megskull.opr.arith import ReLU
from megskull.opr.cnn import Conv2D, FullyConnected, Pooling2D, Softmax
from megskull.opr.helper.elemwise_trans import Identity
from megskull.opr.netsrc import DataProvider

def create_bn_relu_float(conv_name, f_in, ksize, stride, pad, num_outputs,
                         has_relu, args):
    f = Conv2D(conv_name, f_in, kernel_shape=ksize, stride=stride,
               padding=pad, output_nr_channel=num_outputs,
               nonlinearity=Identity())
    if has_relu:
        f = ReLU(f)
    return f


def get_num_inputs(feature, format):
    if format == 'NCHW':
        return feature.partial_shape[1]
    else:
        assert format == 'NCHW4'
        return feature.partial_shape[1] * 4


def create_bn_relu(prefix, f_in, ksize, stride, pad, num_outputs,
                   has_relu, conv_name_fun, args):
    if conv_name_fun:
        conv_name = conv_name_fun(prefix)
    else:
        conv_name = prefix
    return create_bn_relu_float(conv_name, f_in, ksize, stride, pad,
                                num_outputs, has_relu, args)


def create_bottleneck(prefix, f_in, stride, num_outputs1, num_outputs2, args,
                      has_proj=False):
    proj = f_in
    if has_proj:
        proj = create_bn_relu(prefix, f_in, ksize=1, stride=stride, pad=0,
                              num_outputs=num_outputs2, has_relu=False,
                              conv_name_fun=lambda p: "interstellar{}_branch1".format(p),
                              args=args)
    f = create_bn_relu(prefix, f_in, ksize=1, stride=1, pad=0,
                       num_outputs=num_outputs1, has_relu=True,
                       conv_name_fun=lambda p: "interstellar{}_branch2a".format(p),
                       args=args)
    f = create_bn_relu(prefix, f, ksize=3, stride=stride, pad=1,
                       num_outputs=num_outputs1, has_relu=True,
                       conv_name_fun=lambda p: "interstellar{}_branch2b".format(p),
                       args=args)
    f = create_bn_relu(prefix, f, ksize=1, stride=1, pad=0,
                       num_outputs=num_outputs2, has_relu=False,
                       conv_name_fun=lambda p: "interstellar{}_branch2c".format(p),
                       args=args)
    f = ReLU(f + proj)
    return f


def get(args):
    img_size = 224
    num_inputs = 3
    data = DataProvider('data', shape=(args.batch_size, num_inputs,
                                       img_size, img_size))
    inp = data

    f = create_bn_relu("conv1", inp, ksize=7, stride=2, pad=3, num_outputs=64,
                       has_relu=True, conv_name_fun=None,
                       args=args)
    f = Pooling2D("pool1", f, window=3, stride=2, padding=1, mode="MAX",
                  format=args.format)

    pre = [2, 3, 4, 5]
    stages = [3, 4, 6, 3]
    mid_outputs = [64, 128, 256, 512]
    enable_stride = [False, True, True, True]

    for p, s, o, es in zip(pre, stages, mid_outputs, enable_stride):
        for i in range(s):
            has_proj = i == 0
            stride = 1 if not es or i > 0 else 2
            prefix = "{}{}".format(p, chr(ord("a") + i))
            f = create_bottleneck(prefix, f, stride, o, o * 4, args, has_proj)
            print("{}\t{}".format(prefix, f.partial_shape))

    f = Pooling2D("pool5", f, window=7, stride=7, padding=0, mode="AVERAGE",
                  format=args.format)
    f = FullyConnected("fc1000", f, output_dim=1000,
                       nonlinearity=Identity())
    f = Softmax("cls_softmax", f)
    f.init_weights()

    net = RawNetworkBuilder(inputs=[data], outputs=[f])
    return net

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='dump pkl model for resnet50',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-b', '--batch-size', type=int,
                        help='batch size of the model', default=1)
    parser.add_argument('-f', '--format', choices=['NCHW', 'NCHW4'],
                        help='format of conv',
                        default='NCHW')
    parser.add_argument('-o', '--output',
                        help='output pkl path', required=True)
    args = parser.parse_args()

    if args.format != 'NCHW':
        print('Only support NCHW for float model')
        parser.print_help()
        sys.exit(1)

    from meghair.utils import io
    io.dump(get(args), args.output)
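
# Usage sketch (matching how example.sh invokes this script):
#   python3 script/resnet50_mgb.py -o model_source/resnet50.pkl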
@@ -0,0 +1,351 @@
#!/usr/bin/env python3

import argparse
import math

import numpy as np

import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine import jit, tensor

class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.Sequential(
                M.Conv2d(in_channels, channels, 1, stride, bias=False), norm(channels),
            )
        )

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        identity = self.downsample(identity)
        x += identity
        x = F.relu(x)
        return x


class Bottleneck(M.Module):
    expansion = 4

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        width = int(channels * (base_width / 64.0)) * groups
        self.conv1 = M.Conv2d(in_channels, width, 1, 1, bias=False)
        self.bn1 = norm(width)
        self.conv2 = M.Conv2d(
            width,
            width,
            3,
            stride,
            padding=dilation,
            groups=groups,
            dilation=dilation,
            bias=False,
        )
        self.bn2 = norm(width)
        self.conv3 = M.Conv2d(width, channels * self.expansion, 1, 1, bias=False)
        self.bn3 = norm(channels * self.expansion)
        self.downsample = (
            M.Identity()
            if in_channels == channels * self.expansion and stride == 1
            else M.Sequential(
                M.Conv2d(in_channels, channels * self.expansion, 1, stride, bias=False),
                norm(channels * self.expansion),
            )
        )

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        identity = self.downsample(identity)
        x += identity
        x = F.relu(x)
        return x


class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            norm=norm,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            norm=norm,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            norm=norm,
        )
        self.fc = M.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block
        # behaves like an identity. According to https://arxiv.org/abs/1706.02677
        # this improves the model by 0.2~0.3%.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    M.init.zeros_(m.bn3.weight)
                elif isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )
        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = F.reshape(x, (1, 3, 224, 224))
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x


@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/resnet50_fbaug_76254_4e14b7d1.pkl"
)
def resnet50(**kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="dump mge model for resnet50",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-b", "--batch-size", help="batch size of the model", default=1, type=int
    )
    parser.add_argument(
        "-d",
        "--dtype",
        help="the dtype of the model, which can be float32 or uint8",
        default="float32",
        type=str,
    )
    parser.add_argument(
        "--inputs",
        help="set the input data to get a model with testcase",
        default="",
        type=str,
    )
    parser.add_argument(
        "--dir",
        help="set the dir where the model is dumped",
        default=".",
        type=str,
    )
    parser.add_argument(
        "--enable-nchw4",
        help="enable nchw4 layout for NVIDIA CUDNN",
        action='store_true'
    )
    parser.add_argument(
        "--enable-chwn4",
        help="enable chwn4 layout for NVIDIA CUDNN",
        action='store_true'
    )
    args = parser.parse_args()

    net = resnet50()
    net.eval()

    @jit.trace(symbolic=True, capture_as_const=True)
    def fun(data):
        return net(data)

    if args.dtype in ("float32", "uint8"):
        data_type = np.float32
        if args.dtype == "uint8":
            data_type = np.uint8
        data = tensor(
            (np.random.random([args.batch_size, 3, 224, 224]) * 255).astype(data_type)
        )
        fun(data)
        if args.inputs == "":
            fun.dump(
                args.dir + "/resnet50_b" + str(args.batch_size) + "_" + args.dtype + "_without_data.mge",
                arg_names=["data"], no_assert=True,
                enable_nchw4=args.enable_nchw4, enable_chwn4=args.enable_chwn4,
            )
        else:
            fun.dump(
                args.dir + "/resnet50_b" + str(args.batch_size) + "_" + args.dtype + "_with_data.mge",
                arg_names=["data"], input_data=[args.inputs], no_assert=True,
                enable_nchw4=args.enable_nchw4,
            )
    else:
        raise TypeError("dtype should be float32 or uint8")
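
# Usage sketch (matching how example.sh invokes this script):
#   python3 script/resnet50_mge.py --dir model_source
#   python3 script/resnet50_mge.py --dir model_source --inputs "#rand(0,255)"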
@@ -148,7 +148,9 @@ XPUDeviceOption::XPUDeviceOption() {
     }
     if (!FLAGS_multi_thread_core_ids.empty()) {
-        mgb_assert(enable_multithread, "core ids should be set after --multithread");
+        mgb_assert(
+                enable_multithread || enable_multithread_default,
+                "core ids should be set after --multithread or --multithread-default");
         std::stringstream id_stream(FLAGS_multi_thread_core_ids);
         std::string id;
         size_t thread_cnt = 0;
@@ -126,7 +126,9 @@ void COprLibOption::init_extern_param(std::shared_ptr<ModelBase> model_ptr) {
 void COprLibOption::load_lib() {
     auto handle = dlopen(lib_path.c_str(), RTLD_LAZY);
-    mgb_assert(handle, "failed to open c opr lib %s: %s", lib_path.c_str(), dlerror());
+    mgb_assert(
+            handle, "failed to open c opr lib %s:\n errmsg: %s", lib_path.c_str(),
+            dlerror());
     const char* entry = MGB_C_OPR_INIT_FUNC_STR;
     auto func = dlsym(handle, entry);
@@ -286,8 +286,13 @@ DEFINE_string(
 DEFINE_string(io_dump, "", "set the io dump file path in text format");
 DEFINE_bool(io_dump_stdout, false, "dump io opr to stdout in text format");
 DEFINE_bool(io_dump_stderr, false, "dump io opr to stderr in text format");
-DEFINE_string(bin_io_dump, "", "set the io dump file path in binary format");
-DEFINE_string(bin_out_dump, "", "set the out dump file path in binary format");
+DEFINE_string(
+        bin_io_dump, "",
+        "set the io dump directory path where variables in binary format are located");
+DEFINE_string(
+        bin_out_dump, "",
+        "set the out dump directory path where output variables in binary format "
+        "are located");
 DEFINE_bool(copy_to_host, false, "copy device data to host");

 REGIST_OPTION_CREATOR(input, lar::InputOption::create_option);
@@ -77,12 +77,6 @@ void WeightPreprocessOption::config_model_internel<ModelLite>(
         if (weight_preprocess) {
             LITE_WARN("enable weight-preprocess optimization");
             model->get_config().options.weight_preprocess = true;
-            //! FIXME: algo searcher enable weight preprocess for opencl(
-            //! implement below has some problem);
-            // #if MGB_OPENCL
-            //     megdnn::opencl::algo_searcher::AlgoSearcherBase::
-            //             enable_weight_preprocess();
-            // #endif
         }
     }
 }
@@ -95,11 +89,6 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>(
         if (weight_preprocess) {
             mgb_log_warn("enable weight-preprocess optimization");
             graph_option.graph_opt.enable_weight_preprocess();
-            //! FIXME: this implemment is not right
-            // #if MGB_OPENCL
-            //     megdnn::opencl::algo_searcher::AlgoSearcherBase::
-            //             enable_weight_preprocess();
-            // #endif
         }
     }
 }
@@ -570,12 +559,6 @@ DEFINE_bool(
         enable_jit, false,
         " Execute supported operators with JIT(now only support NVRTC). "
         "Can only be used on Nvidia GPUs");
-#if MGB_ENABLE_ANDROID_NN
-DEFINE_bool(
-        android_nn, false,
-        "Execute supported operators with Android NN. Can only be used "
-        "with --cpu.");
-#endif
 #if MGB_ENABLE_TENSOR_RT
 DEFINE_bool(
         tensorrt, false,
@@ -8,10 +8,10 @@
  */
 #include "plugin_options.h"
+#include <map>
 #include "misc.h"
 #include "models/model_lite.h"
 #include "models/model_mdl.h"

 ///////////////////// Plugin options///////////////////////////
 namespace lar {
@@ -153,7 +153,12 @@ void DebugOption::format_and_print(
     auto table = mgb::TextTable(tablename);
     auto&& network = model->get_lite_network();
     table.padding(1);
-    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();
+    table.align(mgb::TextTable::Align::Mid)
+            .add("type")
+            .add("name")
+            .add("shape")
+            .add("dtype")
+            .eor();
     auto to_string = [&](lite::Layout& layout) {
         std::string shape("{");
@@ -165,6 +170,19 @@ void DebugOption::format_and_print(
         shape.append("}");
         return shape;
     };
+    auto get_dtype = [&](lite::Layout& layout) {
+        std::map<LiteDataType, std::string> type_map = {
+                {LiteDataType::LITE_FLOAT, "float32"},
+                {LiteDataType::LITE_HALF, "float16"},
+                {LiteDataType::LITE_INT64, "int64"},
+                {LiteDataType::LITE_INT, "int32"},
+                {LiteDataType::LITE_UINT, "uint32"},
+                {LiteDataType::LITE_INT16, "int16"},
+                {LiteDataType::LITE_UINT16, "uint16"},
+                {LiteDataType::LITE_INT8, "int8"},
+                {LiteDataType::LITE_UINT8, "uint8"}};
+        return type_map[layout.data_type];
+    };
     auto input_name = network->get_all_input_name();
     for (auto& i : input_name) {
@@ -173,6 +191,7 @@ void DebugOption::format_and_print(
                 .add("INPUT")
                 .add(i)
                 .add(to_string(layout))
+                .add(get_dtype(layout))
                 .eor();
     }

@@ -183,6 +202,7 @@ void DebugOption::format_and_print(
                 .add("OUTPUT")
                 .add(i)
                 .add(to_string(layout))
+                .add(get_dtype(layout))
                 .eor();
     }
@@ -196,13 +216,28 @@ void DebugOption::format_and_print(
         const std::string& tablename, std::shared_ptr<ModelMdl> model) {
     auto table = mgb::TextTable(tablename);
     table.padding(1);
-    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();
+    table.align(mgb::TextTable::Align::Mid)
+            .add("type")
+            .add("name")
+            .add("shape")
+            .add("dtype")
+            .eor();
+    auto get_dtype = [&](megdnn::DType data_type) {
+        std::map<megdnn::DTypeEnum, std::string> type_map = {
+                {mgb::dtype::Float32().enumv(), "float32"},
+                {mgb::dtype::Int32().enumv(), "int32"},
+                {mgb::dtype::Int16().enumv(), "int16"},
+                {mgb::dtype::Uint16().enumv(), "uint16"},
+                {mgb::dtype::Int8().enumv(), "int8"},
+                {mgb::dtype::Uint8().enumv(), "uint8"}};
+        return type_map[data_type.enumv()];
+    };
     for (auto&& i : model->get_mdl_load_result().tensor_map) {
         table.align(mgb::TextTable::Align::Mid)
                 .add("INPUT")
                 .add(i.first)
                 .add(i.second->shape().to_string())
+                .add(get_dtype(i.second->dtype()))
                 .eor();
     }
@@ -211,6 +246,7 @@ void DebugOption::format_and_print(
                 .add("OUTPUT")
                 .add(i.node()->name())
                 .add(i.shape().to_string())
+                .add(get_dtype(i.dtype()))
                 .eor();
     }
@@ -358,18 +394,22 @@ DEFINE_double(
 DEFINE_bool(
         check_dispatch, false,
-        "check whether an operator call dispatch on cpu comp nodes");
+        "check whether an operator calls dispatch on cpu comp nodes. This is used "
+        "to find potential bugs in MegDNN");
 DEFINE_string(
         check_var_value, "",
         "--check-var-value [interval]|[interval:init_idx], Enable "
-        "VarValueChecker plugin. Refer to its doc for more details");
+        "VarValueChecker plugin. Check the values of all vars in a graph, starting "
+        "from the given var ID (init_idx) with step interval");
 #if MGB_ENABLE_JSON
 DEFINE_string(
         profile, "",
         "Write profiling result to given file. The output file is in "
         "JSON format");
-DEFINE_string(profile_host, "", "focus on host time profiling For some backends");
+DEFINE_string(
+        profile_host, "",
+        "focus on host time profiling for some backends (such as OpenCL)");
 #endif

 ///////////////////// Debug gflags///////////////////////////
@@ -94,15 +94,16 @@ void NormalStrategy::run_subline() {
         m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
         stage_config_model();
         auto cur = timer.get_msecs();
-        printf("iter %lu/%lu: %.3fms (exec=%.3fms)\n", i, run_num, cur, exec_time);
+        printf("iter %lu/%lu: %.3f ms (exec=%.3f ms)\n", i, run_num, cur,
+               exec_time);
         time_sum += cur;
         time_sqrsum += cur * cur;
         fflush(stdout);
         min_time = std::min(min_time, cur);
         max_time = std::max(max_time, cur);
     }
-    printf("\n=== finished test #%u: time=%.3fms avg_time=%.3fms "
-           "sexec=%.3fms min=%.3fms max=%.3fms\n\n",
+    printf("\n=== finished test #%u: time=%.3f ms avg_time=%.3f ms "
+           "sexec=%.3f ms min=%.3f ms max=%.3f ms\n\n",
            idx, time_sum, time_sum / run_num,
            std::sqrt(
                    (time_sqrsum * run_num - time_sum * time_sum) /
@@ -121,7 +122,7 @@ void NormalStrategy::run_subline() {
         m_runtime_param.stage = RunStage::MODEL_RUNNING;
         stage_config_model();
-        if (!idx) {
+        if (idx == 0) {
             warm_up();
         }
         tot_time += run_iter(idx);