Browse Source

fix(src/tensorrt): trt7 manage all workspace

GitOrigin-RevId: b78c80c8f1
release-1.11.1
Megvii Engine Team 2 years ago
parent
commit
fb9d73b279
2 changed files with 25 additions and 1 deletions
  1. +8
    -1
      src/tensorrt/impl/tensorrt_opr.cpp
  2. +17
    -0
      src/tensorrt/include/megbrain/tensorrt/tensorrt_opr.h

+ 8
- 1
src/tensorrt/impl/tensorrt_opr.cpp View File

@@ -151,7 +151,11 @@ void TensorRTManager::create_trt_context(
nvinfer1::ICudaEngine* engine) { nvinfer1::ICudaEngine* engine) {
bool has_no_context = (!m_context); bool has_no_context = (!m_context);
if (has_no_context) { if (has_no_context) {
#if TENSOR_RT_MANAGE_ALL_WORKSPACE
m_context = {engine->createExecutionContext(), {}};
#else
m_context = {engine->createExecutionContextWithoutDeviceMemory(), {}}; m_context = {engine->createExecutionContextWithoutDeviceMemory(), {}};
#endif
} }
MGB_MARK_USED_VAR(cn); MGB_MARK_USED_VAR(cn);
#if NV_TENSOR_RT_VERSION >= 6001 #if NV_TENSOR_RT_VERSION >= 6001
@@ -333,6 +337,9 @@ void TensorRTManager::exec(
} }
} }
MGB_MARK_USED_VAR(is_trt_opr); MGB_MARK_USED_VAR(is_trt_opr);
#if TENSOR_RT_MANAGE_ALL_WORKSPACE
MGB_MARK_USED_VAR(should_reinit_device_memory);
#else
if (should_reinit_device_memory) { if (should_reinit_device_memory) {
mgb_assert( mgb_assert(
opr->output().back()->shape()[0] == intl::workspace_size(engine) && opr->output().back()->shape()[0] == intl::workspace_size(engine) &&
@@ -340,7 +347,7 @@ void TensorRTManager::exec(
m_context->setDeviceMemory(workspace_ptr); m_context->setDeviceMemory(workspace_ptr);
m_device_workspace_memory_ptr = workspace_ptr; m_device_workspace_memory_ptr = workspace_ptr;
} }
#endif
auto&& env = mgb::CompNodeEnv::from_comp_node(comp_node); auto&& env = mgb::CompNodeEnv::from_comp_node(comp_node);


bool exec_success = false; bool exec_success = false;


+ 17
- 0
src/tensorrt/include/megbrain/tensorrt/tensorrt_opr.h View File

@@ -28,6 +28,18 @@ enum class Empty : int32_t {};
#define TENSORRT_NO_EXCEPT(api) #define TENSORRT_NO_EXCEPT(api)
#endif #endif


#if (NV_TENSOR_RT_VERSION >= 7000)
//! FIXME: trt7.2.2.3 leak memory in setDeviceMemory API, now trt malloc workspace
//! self, megengine do not alloc any workspace
#define TENSOR_RT_MANAGE_ALL_WORKSPACE 1
#else
#define TENSOR_RT_MANAGE_ALL_WORKSPACE 0
#endif

#if NV_TENSOR_RT_VERSION >= 8000
#error "if trt8 fix https://github.com/NVIDIA/TensorRT/issues/2290, try TENSOR_RT_MANAGE_ALL_WORKSPACE=0"
#endif

namespace mgb { namespace mgb {
namespace opr { namespace opr {


@@ -73,7 +85,12 @@ public:
}; };


static inline size_t workspace_size(nvinfer1::ICudaEngine* engine) { static inline size_t workspace_size(nvinfer1::ICudaEngine* engine) {
#if TENSOR_RT_MANAGE_ALL_WORKSPACE
MGB_MARK_USED_VAR(engine);
return 0;
#else
return engine->getDeviceMemorySize(); return engine->getDeviceMemorySize();
#endif
} }
} // namespace intl } // namespace intl




Loading…
Cancel
Save