Browse Source

test(mgb): fix tensorrt tests missing cudaSetDevice

GitOrigin-RevId: faeb6ae070
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
ab9dfbcefc
1 changed files with 3 additions and 0 deletions
  1. +3
    -0
      src/tensorrt/test/make_trt_net.cpp

+ 3
- 0
src/tensorrt/test/make_trt_net.cpp View File

@@ -46,6 +46,7 @@ intl::SimpleTensorRTNetwork::SimpleTensorRTNetwork() {

std::pair<nvinfer1::IBuilder*, INetworkDefinition*>
intl::SimpleTensorRTNetwork::create_trt_network(bool has_batch_dim) {
CompNode::load("xpu0").activate();
Weights wt_filter{DataType::kFLOAT, nullptr, 0},
wt_bias{DataType::kFLOAT, nullptr, 0};
wt_filter.type = DataType::kFLOAT;
@@ -205,6 +206,7 @@ intl::SimpleQuantizedTensorRTNetwork::SimpleQuantizedTensorRTNetwork() {
std::pair<nvinfer1::IBuilder*, INetworkDefinition*>
intl::SimpleQuantizedTensorRTNetwork::create_trt_network(
bool has_batch_dim) {
CompNode::load("xpu0").activate();
Weights wt_filter{DataType::kFLOAT, nullptr, 0},
wt_bias{DataType::kFLOAT, nullptr, 0};
wt_filter.type = DataType::kFLOAT;
@@ -290,6 +292,7 @@ intl::ConcatConvTensorRTNetwork::ConcatConvTensorRTNetwork() {

std::pair<nvinfer1::IBuilder*, INetworkDefinition*>
intl::ConcatConvTensorRTNetwork::create_trt_network(bool has_batch_dim) {
CompNode::load("xpu0").activate();
auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
nvinfer1::NetworkDefinitionCreationFlags flags;


Loading…
Cancel
Save