diff --git a/src/tensorrt/test/make_trt_net.cpp b/src/tensorrt/test/make_trt_net.cpp index 2fa3c44b..0d33314f 100644 --- a/src/tensorrt/test/make_trt_net.cpp +++ b/src/tensorrt/test/make_trt_net.cpp @@ -46,6 +46,7 @@ intl::SimpleTensorRTNetwork::SimpleTensorRTNetwork() { std::pair intl::SimpleTensorRTNetwork::create_trt_network(bool has_batch_dim) { + CompNode::load("xpu0").activate(); Weights wt_filter{DataType::kFLOAT, nullptr, 0}, wt_bias{DataType::kFLOAT, nullptr, 0}; wt_filter.type = DataType::kFLOAT; @@ -205,6 +206,7 @@ intl::SimpleQuantizedTensorRTNetwork::SimpleQuantizedTensorRTNetwork() { std::pair intl::SimpleQuantizedTensorRTNetwork::create_trt_network( bool has_batch_dim) { + CompNode::load("xpu0").activate(); Weights wt_filter{DataType::kFLOAT, nullptr, 0}, wt_bias{DataType::kFLOAT, nullptr, 0}; wt_filter.type = DataType::kFLOAT; @@ -290,6 +292,7 @@ intl::ConcatConvTensorRTNetwork::ConcatConvTensorRTNetwork() { std::pair intl::ConcatConvTensorRTNetwork::create_trt_network(bool has_batch_dim) { + CompNode::load("xpu0").activate(); auto builder = createInferBuilder(TensorRTOpr::Logger::instance()); #if NV_TENSOR_RT_VERSION >= 6001 nvinfer1::NetworkDefinitionCreationFlags flags;