Browse Source

fix(mgb): fix TensorRT missing cudaSetDevice

GitOrigin-RevId: 40eb119e48
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
39bd66fc63
2 changed files with 3 additions and 0 deletions
  1. +1
    -0
      src/tensorrt/impl/opr_replace.cpp
  2. +2
    -0
      src/tensorrt/impl/tensorrt_opr.cpp

+ 1
- 0
src/tensorrt/impl/opr_replace.cpp View File

@@ -1404,6 +1404,7 @@ void TensorRTReplacePass::Impl::detect_replace() {


m_graph_map[opr] = max; m_graph_map[opr] = max;
if (max > m_tensorrt_graphs.size()) { if (max > m_tensorrt_graphs.size()) {
opr->output(0)->comp_node().activate();
m_tensorrt_graphs.push_back( m_tensorrt_graphs.push_back(
std::make_shared<TensorRTGraph>(feature_bits)); std::make_shared<TensorRTGraph>(feature_bits));
} }


+ 2
- 0
src/tensorrt/impl/tensorrt_opr.cpp View File

@@ -533,6 +533,7 @@ void TensorRTOpr::get_output_var_shape(const TensorShapeArray& inp_shape,
} }


if (!engine_valid) { if (!engine_valid) {
comp_node().activate();
// If a context created by a cuda engine, the context must be destroyed // If a context created by a cuda engine, the context must be destroyed
// before the corresponding cuda engine. Otherwise, a segmentfault will // before the corresponding cuda engine. Otherwise, a segmentfault will
// occur. // occur.
@@ -576,6 +577,7 @@ void TensorRTOpr::build_engine_from_cache() {
TensorRTEngineCache::make_key_from_trt_opr(this)); TensorRTEngineCache::make_key_from_trt_opr(this));
if (!ret.valid()) if (!ret.valid())
return; return;
comp_node().activate();
auto engine = runtime->deserializeCudaEngine( auto engine = runtime->deserializeCudaEngine(
reinterpret_cast<const void*>(ret->ptr), ret->size, nullptr); reinterpret_cast<const void*>(ret->ptr), ret->size, nullptr);
mgb_assert(engine, "failed to deserialize ICudaEngine"); mgb_assert(engine, "failed to deserialize ICudaEngine");


Loading…
Cancel
Save