diff --git a/src/tensorrt/impl/tensorrt_runtime_opr.cpp b/src/tensorrt/impl/tensorrt_runtime_opr.cpp index f0849eb0..822116bb 100644 --- a/src/tensorrt/impl/tensorrt_runtime_opr.cpp +++ b/src/tensorrt/impl/tensorrt_runtime_opr.cpp @@ -107,6 +107,7 @@ TensorRTRuntimeOpr::TensorRTRuntimeOpr( void TensorRTRuntimeOpr::get_output_var_shape( const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { auto batch = inp_shape.at(0)[0]; + m_manager.clear_trt_context(); m_manager.create_trt_context(this->comp_node(), inp_shape, m_engine.get()); auto get_mgb_shape = [&](int binding_idx) -> TensorShape { auto dims = m_engine->getBindingDimensions(binding_idx);