Browse Source

feat(mgb): add tensorrt plugin support

GitOrigin-RevId: 5428b4f665
release-1.7
Megvii Engine Team 3 years ago
parent
commit
84baf3df1b
7 changed files with 190 additions and 3 deletions
  1. +3
    -3
      CMakeLists.txt
  2. +29
    -0
      cmake/tensorrt.cmake
  3. +2
    -0
      scripts/whl/windows/windows_build_whl.sh
  4. +2
    -0
      src/tensorrt/impl/tensorrt_runtime_opr.cpp
  5. +79
    -0
      src/tensorrt/test/make_trt_net.cpp
  6. +12
    -0
      src/tensorrt/test/make_trt_net.h
  7. +63
    -0
      src/tensorrt/test/tensorrt_runtime.cpp

+ 3
- 3
CMakeLists.txt View File

@@ -659,9 +659,9 @@ if(MGE_WITH_CUDA)
if(MGE_WITH_TRT)
if(MSVC OR WIN32)
message(STATUS "windows TRT_LIBRARY: ${TRT_LIBRARY}")
list(APPEND MGE_CUDA_LIBS ${TRT_LIBRARY})
list(APPEND MGE_CUDA_LIBS ${TRT_LIBRARY} ${TRT_PLUGIN_LIBRARY})
else()
list(APPEND MGE_CUDA_LIBS -Wl,--whole-archive libnvinfer -Wl,--no-whole-archive)
list(APPEND MGE_CUDA_LIBS -Wl,--whole-archive libnvinfer libnvinfer_plugin -Wl,--no-whole-archive)
endif()
if(TensorRT_VERSION_MAJOR GREATER_EQUAL 7)
message(STATUS "handle trt myelin lib after trt7")
@@ -738,7 +738,7 @@ if(MGE_WITH_CUDA)
endif()
else()
if(MGE_WITH_TRT)
list(APPEND MGE_CUDA_LIBS libnvinfer)
list(APPEND MGE_CUDA_LIBS libnvinfer libnvinfer_plugin)
if(TensorRT_VERSION_MAJOR GREATER_EQUAL 7)
message(STATUS "handle trt myelin lib after trt7")
list(APPEND MGE_CUDA_LIBS libmyelin)


+ 29
- 0
cmake/tensorrt.cmake View File

@@ -9,6 +9,12 @@ if(MGE_CUDA_USE_STATIC)
HINTS ${ALTER_LIBRARY_PATHS}
PATH_SUFFIXES lib lib64
DOC "TRT library." )
find_library(TRT_PLUGIN_LIBRARY
NAMES libnvinfer_plugin_static.a nvinfer_plugin.lib
PATHS ${ALTER_LD_LIBRARY_PATHS} ${TRT_ROOT_DIR} ${CMAKE_INSTALL_PREFIX}
HINTS ${ALTER_LIBRARY_PATHS}
PATH_SUFFIXES lib lib64
DOC "TRT plugin library." )
else()
find_library(TRT_LIBRARY
NAMES libnvinfer.so libnvinfer.dylib nvinfer.dll
@@ -16,11 +22,20 @@ else()
HINTS ${ALTER_LIBRARY_PATHS}
PATH_SUFFIXES lib lib64
DOC "TRT library." )
find_library(TRT_PLUGIN_LIBRARY
NAMES libnvinfer_plugin.so libnvinfer_plugin.dylib nvinfer_plugin.dll
PATHS ${ALTER_LD_LIBRARY_PATHS} ${TRT_ROOT_DIR} ${CMAKE_INSTALL_PREFIX}
HINTS ${ALTER_LIBRARY_PATHS}
PATH_SUFFIXES lib lib64
DOC "TRT plugin library." )
endif()

if(TRT_LIBRARY STREQUAL "TRT_LIBRARY-NOTFOUND")
message(FATAL_ERROR "Can not find TensorRT Library, please refer to scripts/cmake-build/BUILD_README.md to init TRT env")
endif()
if(TRT_PLUGIN_LIBRARY STREQUAL "TRT_PLUGIN_LIBRARY-NOTFOUND")
message(FATAL_ERROR "Can not find TensorRT Plugin Library, please refer to scripts/cmake-build/BUILD_README.md to init TRT env")
endif()

get_filename_component(__found_trt_root ${TRT_LIBRARY}/../.. REALPATH)
find_path(TRT_INCLUDE_DIR
@@ -28,10 +43,18 @@ find_path(TRT_INCLUDE_DIR
HINTS ${TRT_ROOT_DIR} ${CUDA_TOOLKIT_INCLUDE} ${__found_trt_root}
PATH_SUFFIXES include
DOC "Path to TRT include directory." )
find_path(TRT_PLUGIN_INCLUDE_DIR
NAMES NvInferPlugin.h
HINTS ${TRT_ROOT_DIR} ${CUDA_TOOLKIT_INCLUDE} ${__found_trt_root}
PATH_SUFFIXES include
DOC "Path to TRT plugin include directory." )

if(TRT_INCLUDE_DIR STREQUAL "TRT_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find TensorRT INCLUDE, please refer to scripts/cmake-build/BUILD_README.md to init TRT env")
endif()
if(TRT_PLUGIN_INCLUDE_DIR STREQUAL "TRT_PLUGIN_INCLUDE_DIR-NOTFOUND")
message(FATAL_ERROR "Can not find TensorRT Plugin INCLUDE, please refer to scripts/cmake-build/BUILD_README.md to init TRT env")
endif()

file(STRINGS "${TRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$")
file(STRINGS "${TRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$")
@@ -50,14 +73,20 @@ set(TRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${Te

if(MGE_CUDA_USE_STATIC)
add_library(libnvinfer STATIC IMPORTED)
add_library(libnvinfer_plugin STATIC IMPORTED)
else()
add_library(libnvinfer SHARED IMPORTED)
add_library(libnvinfer_plugin SHARED IMPORTED)
endif()

set_target_properties(libnvinfer PROPERTIES
IMPORTED_LOCATION ${TRT_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${TRT_INCLUDE_DIR}
)
set_target_properties(libnvinfer_plugin PROPERTIES
IMPORTED_LOCATION ${TRT_PLUGIN_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${TRT_PLUGIN_INCLUDE_DIR}
)

message(STATUS "Found TensorRT: ${__found_trt_root} (found version: ${TRT_VERSION_STRING})")



+ 2
- 0
scripts/whl/windows/windows_build_whl.sh View File

@@ -70,6 +70,7 @@ fi

# config NVIDIA libs
TRT_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/TensorRT-6.0.1.5/lib/nvinfer.dll"
TRT_PLUGIN_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/TensorRT-6.0.1.5/lib/nvinfer_plugin.dll"
CUDNN_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/cudnn-10.1-windows10-x64-v7.6.5.32/cuda/bin/cudnn64_7.dll"
CUSOLVER_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cusolver64_10.dll"
CUBLAS_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublas64_10.dll"
@@ -86,6 +87,7 @@ function depend_real_copy() {
if [ ${BUILD_WHL_CPU_ONLY} = "OFF" ]; then
echo "copy nvidia lib...."
cp "${TRT_LIB}" ${REAL_DST}
cp "${TRT_PLUGIN_LIB}" ${REAL_DST}
cp "${CUDNN_LIB}" ${REAL_DST}
cp "${CUSOLVER_LIB}" ${REAL_DST}
cp "${CUBLAS_LIB}" ${REAL_DST}


+ 2
- 0
src/tensorrt/impl/tensorrt_runtime_opr.cpp View File

@@ -19,6 +19,7 @@
#include <cinttypes>

#if MGB_ENABLE_TENSOR_RT
#include <NvInferPlugin.h>

using namespace mgb;
using namespace opr;
@@ -208,6 +209,7 @@ SymbolVarArray TensorRTRuntimeOpr::make(
!CompNode::get_device_count(CompNode::DeviceType::CUDA), SystemError,
"can not create TensorRTRuntimeOpr when CUDA is not available");
mgb_assert(!src.empty(), "no inputs provided");
initLibNvInferPlugins(&TensorRTOpr::Logger::instance(), "");
TensorRTUniquePtr<nvinfer1::IRuntime> runtime{
nvinfer1::createInferRuntime(TensorRTOpr::Logger::instance()), {}};
auto gpu_allocator = std::make_shared<GpuAllocator>(src[0].node()->comp_node());


+ 79
- 0
src/tensorrt/test/make_trt_net.cpp View File

@@ -25,6 +25,7 @@
#include "make_trt_net.h"
#include "megbrain/tensorrt/tensorrt_opr.h"

#include <NvInferPlugin.h>
#include <random>

using namespace mgb;
@@ -404,6 +405,84 @@ std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::ConcatConvTensorRTNetw
return std::make_pair(builder, network);
}

intl::ReshapeConcatTensorRTNetwork::ReshapeConcatTensorRTNetwork() {
host_x0 = gen({2, 2, 2, 2});
host_y0 = gen({2, 3, 2, 2});

graph = ComputingGraph::make();
x0 = Host2DeviceCopy::make(*graph, host_x0);
y0 = Host2DeviceCopy::make(*graph, host_y0);
auto x1 = opr::Reshape::make(x0, {2, 8, 1, 1}),
y1 = opr::Reshape::make(y0, {2, 12, 1, 1});
z = opr::Concat::make({x1, y1}, 1);
}

std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::ReshapeConcatTensorRTNetwork::
create_trt_network(bool has_batch_dim) {
initLibNvInferPlugins(&TensorRTOpr::Logger::instance(), "");

CompNode::load("xpu0").activate();
auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
nvinfer1::NetworkDefinitionCreationFlags flags;
::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
if (has_batch_dim)
flags = 1 << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto network = builder->createNetworkV2(flags);
#else
auto network = builder->createNetwork();
#endif
nvinfer1::ITensor *data0, *data1;
#if NV_TENSOR_RT_VERSION >= 6001
if (has_batch_dim) {
data0 = network->addInput("x0", DataType::kFLOAT, Dims4{2, 2, 2, 2});
data1 = network->addInput("y0", DataType::kFLOAT, Dims4{2, 3, 2, 2});
} else {
data0 = network->addInput("x0", DataType::kFLOAT, Dims3{2, 2, 2});
data1 = network->addInput("y0", DataType::kFLOAT, Dims3{3, 2, 2});
}
{
nvinfer1::TensorFormats formats =
1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
data0->setAllowedFormats(formats);
data1->setAllowedFormats(formats);
}
#else
if (has_batch_dim) {
data0 = network->addInput("x0", DataType::kFLOAT, DimsNCHW{2, 2, 2, 2});
data1 = network->addInput("y0", DataType::kFLOAT, DimsNCHW{2, 3, 2, 2});
} else {
data0 = network->addInput("x0", DataType::kFLOAT, DimsCHW{2, 2, 2});
data1 = network->addInput("y0", DataType::kFLOAT, DimsCHW{3, 2, 2});
}
#endif
int axis = 1;
bool ignoreBatch = false;
nvinfer1::PluginField fields[2] = {
nvinfer1::PluginField{"axis", &axis, nvinfer1::PluginFieldType::kINT32, 1},
nvinfer1::PluginField{
"ignoreBatch", &ignoreBatch, nvinfer1::PluginFieldType::kINT32, 1},
};
nvinfer1::PluginFieldCollection fc{2, fields};

auto creator = getPluginRegistry()->getPluginCreator("FlattenConcat_TRT", "1", "");
TensorRTUniquePtr<nvinfer1::IPluginV2> plugin(
creator->createPlugin("FlattenConcat_TRT", &fc));
ITensor* inputTensors[] = {data0, data1};
auto flt_cct = network->addPluginV2(inputTensors, 2, *plugin);
mgb_assert(flt_cct != nullptr, "FlattenConcat_TRT is invalid");
network->markOutput(*flt_cct->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
{
nvinfer1::TensorFormats formats =
1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
flt_cct->getOutput(0)->setAllowedFormats(formats);
}
#endif
return std::make_pair(builder, network);
}

#pragma GCC diagnostic pop
#endif // MGB_ENABLE_TENSOR_RT



+ 12
- 0
src/tensorrt/test/make_trt_net.h View File

@@ -92,6 +92,18 @@ struct ConcatConvTensorRTNetwork {
bool has_batch_dim);
};

struct ReshapeConcatTensorRTNetwork {
HostTensorGenerator<> gen;
std::shared_ptr<HostTensorND> host_x0, host_y0;
std::shared_ptr<ComputingGraph> graph;
SymbolVar x0, y0, z;

ReshapeConcatTensorRTNetwork();

std::pair<nvinfer1::IBuilder*, INetworkDefinition*> create_trt_network(
bool has_batch_dim);
};

} // namespace intl
} // namespace opr
} // namespace mgb


+ 63
- 0
src/tensorrt/test/tensorrt_runtime.cpp View File

@@ -23,6 +23,7 @@
#include "megbrain/tensorrt/tensorrt_opr.h"
#include "megbrain/tensorrt/tensorrt_runtime_opr.h"

#include <fstream>
#include <random>

using namespace mgb;
@@ -244,6 +245,68 @@ TEST(TestOprTensorRT, IOFormatFree) {
}
#endif

TEST(TestOprTensorRT, FlattenConcatPlugin) {
REQUIRE_GPU(1);
intl::ReshapeConcatTensorRTNetwork net;
auto make_trt = [&net]() {
auto p = net.create_trt_network(false);
TensorRTUniquePtr<INetworkDefinition> trt_net{p.second, {}};
TensorRTUniquePtr<IBuilder> builder{p.first, {}};
builder->setMaxBatchSize(5);
#if NV_TENSOR_RT_VERSION >= 6001
TensorRTUniquePtr<IBuilderConfig> build_config{builder->createBuilderConfig()};
TensorRTUniquePtr<ICudaEngine> cuda_engine{
builder->buildEngineWithConfig(*trt_net, *build_config)};
#else
TensorRTUniquePtr<ICudaEngine> cuda_engine{builder->buildCudaEngine(*trt_net)};
#endif
TensorRTUniquePtr<IHostMemory> mem{cuda_engine->serialize(), {}};
return TensorRTRuntimeOpr::make(mem->data(), mem->size(), {net.x0, net.y0})[0];
};
auto z2 = make_trt();

HostTensorND host_z1;
HostTensorND host_z2;
auto func = net.graph->compile(
{make_callback_copy(net.z, host_z1), make_callback_copy(z2, host_z2)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_z1, host_z2);
}

TEST(TestOprTensorRT, ICudaEngine) {
REQUIRE_GPU(1);
CompNode::load("xpu0").activate();
std::ifstream engineFile("model.trt", std::ios::binary);
if (!engineFile)
return;

engineFile.seekg(0, engineFile.end);
long int fsize = engineFile.tellg();
engineFile.seekg(0, engineFile.beg);

std::vector<char> engineData(fsize);
engineFile.read(engineData.data(), fsize);
if (!engineFile)
return;

std::shared_ptr<ComputingGraph> graph;
graph = ComputingGraph::make();

HostTensorGenerator<> gen;
std::shared_ptr<HostTensorND> host_x0, host_y0;
host_x0 = gen({2, 3, 375, 500});
host_y0 = gen({2, 1, 1, 3});

SymbolVar x0 = Host2DeviceCopy::make(*graph, host_x0);
SymbolVar y0 = Host2DeviceCopy::make(*graph, host_y0);

auto z = TensorRTRuntimeOpr::make(engineData.data(), fsize, {x0, y0})[0];
HostTensorND host_z;

auto func = graph->compile({make_callback_copy(z, host_z)});
func->execute();
}

#endif // MGB_ENABLE_TENSOR_RT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save