Browse Source

fix(mgb): fix "TRT_ERROR: INVALID_ARGUMENT: Get binding data type failed."

GitOrigin-RevId: d9601cb15b
release-1.10
Megvii Engine Team 3 years ago
parent
commit
b3f79966fd
1 changed files with 8 additions and 4 deletions
  1. +8
    -4
      src/tensorrt/impl/tensorrt_runtime_opr.cpp

+ 8
- 4
src/tensorrt/impl/tensorrt_runtime_opr.cpp View File

@@ -194,21 +194,25 @@ void TensorRTRuntimeOpr::init_output_dtype() {
idx++;
}

for (size_t i = 0; i < output().size(); ++i) {
size_t out = 0;
for (; out < output().size() - 1; ++out) {
dt_trt = get_dtype_from_trt(m_engine->getBindingDataType(idx));
mgb_assert(
dt_trt.valid(),
"output dtype checking failed: invalid dtype returned.");
if (dt_trt.enumv() == DTypeEnum::QuantizedS8) {
mgb_assert(
output(i)->dtype().valid(),
output(out)->dtype().valid(),
"user should specify scale of output tensor of "
"TensorRTRuntimeOpr.");
}
if (!output(i)->dtype().valid())
output(i)->dtype(dt_trt);
if (!output(out)->dtype().valid())
output(out)->dtype(dt_trt);
idx++;
}
//! workspace
if (!output(out)->dtype().valid())
output(out)->dtype(dtype::Byte());
}

SymbolVarArray TensorRTRuntimeOpr::make(


Loading…
Cancel
Save