|
|
@@ -194,21 +194,25 @@ void TensorRTRuntimeOpr::init_output_dtype() { |
|
|
|
idx++; |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = 0; i < output().size(); ++i) { |
|
|
|
size_t out = 0; |
|
|
|
for (; out < output().size() - 1; ++out) { |
|
|
|
dt_trt = get_dtype_from_trt(m_engine->getBindingDataType(idx)); |
|
|
|
mgb_assert( |
|
|
|
dt_trt.valid(), |
|
|
|
"output dtype checking failed: invalid dtype returned."); |
|
|
|
if (dt_trt.enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
mgb_assert( |
|
|
|
output(i)->dtype().valid(), |
|
|
|
output(out)->dtype().valid(), |
|
|
|
"user should specify scale of output tensor of " |
|
|
|
"TensorRTRuntimeOpr."); |
|
|
|
} |
|
|
|
if (!output(i)->dtype().valid()) |
|
|
|
output(i)->dtype(dt_trt); |
|
|
|
if (!output(out)->dtype().valid()) |
|
|
|
output(out)->dtype(dt_trt); |
|
|
|
idx++; |
|
|
|
} |
|
|
|
//! workspace |
|
|
|
if (!output(out)->dtype().valid()) |
|
|
|
output(out)->dtype(dtype::Byte()); |
|
|
|
} |
|
|
|
|
|
|
|
SymbolVarArray TensorRTRuntimeOpr::make( |
|
|
|