|
|
@@ -107,6 +107,7 @@ TensorRTRuntimeOpr::TensorRTRuntimeOpr( |
|
|
|
void TensorRTRuntimeOpr::get_output_var_shape( |
|
|
|
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { |
|
|
|
auto batch = inp_shape.at(0)[0]; |
|
|
|
m_manager.clear_trt_context(); |
|
|
|
m_manager.create_trt_context(this->comp_node(), inp_shape, m_engine.get()); |
|
|
|
auto get_mgb_shape = [&](int binding_idx) -> TensorShape { |
|
|
|
auto dims = m_engine->getBindingDimensions(binding_idx); |
|
|
|