diff --git a/src/tensorrt/impl/tensorrt_runtime_opr.cpp b/src/tensorrt/impl/tensorrt_runtime_opr.cpp index f1864ee9..ac6d7dfa 100644 --- a/src/tensorrt/impl/tensorrt_runtime_opr.cpp +++ b/src/tensorrt/impl/tensorrt_runtime_opr.cpp @@ -194,21 +194,25 @@ void TensorRTRuntimeOpr::init_output_dtype() { idx++; } - for (size_t i = 0; i < output().size(); ++i) { + size_t out = 0; + for (; out < output().size() - 1; ++out) { dt_trt = get_dtype_from_trt(m_engine->getBindingDataType(idx)); mgb_assert( dt_trt.valid(), "output dtype checking failed: invalid dtype returned."); if (dt_trt.enumv() == DTypeEnum::QuantizedS8) { mgb_assert( - output(i)->dtype().valid(), + output(out)->dtype().valid(), "user should specify scale of output tensor of " "TensorRTRuntimeOpr."); } - if (!output(i)->dtype().valid()) - output(i)->dtype(dt_trt); + if (!output(out)->dtype().valid()) + output(out)->dtype(dt_trt); idx++; } + //! workspace + if (!output(out)->dtype().valid()) + output(out)->dtype(dtype::Byte()); } SymbolVarArray TensorRTRuntimeOpr::make(