/**
 * \file src/tensorrt/test/make_trt_net.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/test/helper.h"
#include "megbrain/utils/debug.h"

#if MGB_ENABLE_TENSOR_RT

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"

#include "make_trt_net.h"
#include "megbrain/tensorrt/tensorrt_opr.h"

#include <NvInferPlugin.h>

#include <random>

using namespace mgb;
using namespace opr;
using namespace nvinfer1;
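
// SimpleTensorRTNetwork: a single convolution plus bias (y = conv(x, w) + b)
// expressed as a MegBrain graph; create_trt_network() below builds the
// equivalent network with the TensorRT definition API.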
intl::SimpleTensorRTNetwork::SimpleTensorRTNetwork() {
    host_x = gen({5, 23, 28, 28});
    host_w = gen({32, 23, 3, 3});
    host_b = gen({1, 32, 1, 1});

    graph = ComputingGraph::make();
    x = Host2DeviceCopy::make(*graph, host_x);
    auto w = Host2DeviceCopy::make(*graph, host_w),
         b = Host2DeviceCopy::make(*graph, host_b),
         y0 = opr::Convolution::make(x, w);
    y = y0 + b;
}
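
// Builds the same conv + bias network directly through the TensorRT API.
// With TensorRT >= 6.0 (NV_TENSOR_RT_VERSION >= 6001) an explicit-batch
// network is created when has_batch_dim is set; older versions fall back to
// the implicit-batch createNetwork().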
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::SimpleTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
    CompNode::load("xpu0").activate();
    Weights wt_filter{DataType::kFLOAT, nullptr, 0},
            wt_bias{DataType::kFLOAT, nullptr, 0};
    wt_filter.type = DataType::kFLOAT;
    wt_bias.type = DataType::kFLOAT;
    wt_filter.values = host_w->raw_ptr();
    wt_bias.values = host_b->raw_ptr();
    wt_filter.count = host_w->shape().total_nr_elems();
    wt_bias.count = host_b->shape().total_nr_elems();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
        flags = 1 << static_cast<int>(
                nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    nvinfer1::ITensor* data;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
        data = network->addInput("data", DataType::kFLOAT, Dims4{5, 23, 28, 28});
    } else {
        data = network->addInput("data", DataType::kFLOAT, Dims3{23, 28, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
        data = network->addInput("data", DataType::kFLOAT, DimsNCHW{5, 23, 28, 28});
    } else {
        data = network->addInput("data", DataType::kFLOAT, DimsCHW{23, 28, 28});
    }
#endif
    mgb_assert(data != nullptr, "data is invalid");
    auto conv1 = network->addConvolution(*data, 32, DimsHW{3, 3}, wt_filter, wt_bias);
    mgb_assert(conv1 != nullptr, "conv1 is invalid");
    conv1->setStride(DimsHW{1, 1});
    conv1->getOutput(0)->setName("prob");
    network->markOutput(*conv1->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        conv1->getOutput(0)->setAllowedFormats(formats);
    }
#endif
    return std::make_pair(builder, network);
}
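
// BatchedTensorRTNetwork: sums a (23, 28, 28) tensor over its first two axes
// and reshapes the result to (1, 28); the MegBrain graph carries no explicit
// batch dimension.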
intl::BatchedTensorRTNetwork::BatchedTensorRTNetwork() {
    host_x = gen({23, 28, 28});

    graph = ComputingGraph::make();
    x = Host2DeviceCopy::make(*graph, host_x);
    opr::Reduce::Param param1{Reduce::Mode::SUM, 0, Reduce::Param::DataType::DEFAULT};
    opr::Reduce::Param param2{Reduce::Mode::SUM, 1, Reduce::Param::DataType::DEFAULT};
    auto y0 = opr::Reduce::make(x, param1);
    auto y1 = opr::Reduce::make(y0, param2);
    TensorShape tshp{1, 28};
    y = opr::Reshape::make(y1, tshp);
}
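
// TensorRT counterpart of the batched network: a single IReduceLayer with
// reduce-axes bitmask 3 (axes 0 and 1) and keepDimensions = false.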
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::BatchedTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
    CompNode::load("xpu0").activate();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
        flags = 1 << static_cast<int>(
                nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    nvinfer1::ITensor* data;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
        data = network->addInput("data", DataType::kFLOAT, Dims4{1, 23, 28, 28});
    } else {
        data = network->addInput("data", DataType::kFLOAT, Dims3{23, 28, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
        data = network->addInput("data", DataType::kFLOAT, DimsNCHW{1, 23, 28, 28});
    } else {
        data = network->addInput("data", DataType::kFLOAT, DimsCHW{23, 28, 28});
    }
#endif
    mgb_assert(data != nullptr, "data is invalid");
    auto reduce1 = network->addReduce(*data, nvinfer1::ReduceOperation::kSUM, 3, false);
    mgb_assert(reduce1 != nullptr, "reduce1 is invalid");
    reduce1->getOutput(0)->setName("prob");
    network->markOutput(*reduce1->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        reduce1->getOutput(0)->setAllowedFormats(formats);
    }
#endif
    return std::make_pair(builder, network);
}
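
// SimpleQuantizedTensorRTNetwork: builds a float NCHW conv-bias graph and a
// QuantizedS8 NCHW4 conv-bias graph side by side; the reshape + dimshuffle
// blocks below convert x, w and b from NCHW to the NCHW4 layout, and the
// final block converts the quantized result back to NCHW float.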
intl::SimpleQuantizedTensorRTNetwork::SimpleQuantizedTensorRTNetwork() {
    host_x = range_gen({32, 8, 28, 28});
    host_w = weight_gen({8, 8, 3, 3});
    host_b = range_gen({1, 8, 1, 1});
    {
        void* w_ptr = host_w->raw_ptr();
        float* ptr = reinterpret_cast<float*>(w_ptr);
        ptr[0] = -127 * 1.1f;
        ptr[1] = 127 * 1.1f;
    }

    graph = ComputingGraph::make();
    auto mkvar = [this](const char* name, const std::shared_ptr<HostTensorND>& host_ts,
                        const DType& dtype) {
        return opr::TypeCvt::make(
                opr::Host2DeviceCopy::make(*graph, host_ts).rename(name), dtype);
    };
    auto mkcvar = [this](const char* name, const std::shared_ptr<HostTensorND>& host_ts,
                         const DType& dtype) {
        return opr::TypeCvt::make(
                opr::SharedDeviceTensor::make(*graph, *host_ts).rename(name), dtype);
    };

    x = mkvar("x", host_x, dtype::Float32());
    quantized_x = mkvar("quantized_x", host_x, dtype::QuantizedS8(1.2f));
    auto float_w = mkcvar("float_w", host_w, dtype::Float32()),
         float_b = mkcvar("float_b", host_b, dtype::Float32()),
         w = opr::TypeCvt::make(float_w, dtype::QuantizedS8(1.1f)),
         b = opr::TypeCvt::make(float_b, dtype::QuantizedS32(1.2f * 1.1f));
    {
        auto xshp = opr::GetVarShape::make(quantized_x);
        auto cv = [this](int v) { return quantized_x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
        quantized_x = opr::Reshape::make(quantized_x, tshp);
        quantized_x = opr::Dimshuffle::make(quantized_x, {0, 1, 3, 4, 2});
    }
    {
        auto wshp = opr::GetVarShape::make(w);
        auto cv = [&w](int v) { return w.make_scalar(v); };
        auto sub = [&wshp, &cv](int idx) {
            return opr::IndexAt::make(wshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
        w = opr::Reshape::make(w, tshp);
        w = opr::Dimshuffle::make(w, {0, 1, 3, 4, 2});
    }
    {
        auto bshp = opr::GetVarShape::make(b);
        auto cv = [&b](int v) { return b.make_scalar(v); };
        auto sub = [&bshp, &cv](int idx) {
            return opr::IndexAt::make(bshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
        b = opr::Reshape::make(b, tshp);
        b = opr::Dimshuffle::make(b, {0, 1, 3, 4, 2});
    }

    opr::ConvBias::Param param;
    param.format = opr::ConvBias::Param::Format::NCHW4;
    param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
    param.stride_h = param.stride_w = 1;
    param.pad_h = param.pad_w = 1;

    quantized_y = opr::ConvBias::make(
            quantized_x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(1.1f)});

    param.format = opr::ConvBias::Param::Format::NCHW;
    y = opr::ConvBias::make(
            x, float_w, float_b, param, {}, OperatorNodeConfig{dtype::Float32()});

    auto yshp = opr::GetVarShape::make(quantized_y);
    auto cv = [this](int v) { return quantized_y.make_scalar(v); };
    auto sub = [&yshp, &cv](int idx) {
        return opr::IndexAt::make(yshp, {{0, cv(idx)}});
    };
    auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
    quantized_y = opr::Dimshuffle::make(quantized_y, {0, 1, 4, 2, 3});
    quantized_y = opr::Reshape::make(quantized_y, tshp);
    quantized_y = TypeCvt::make(quantized_y, dtype::Float32());
}
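
// TensorRT version of the quantized network. The dynamic ranges mirror the
// graph's quantization scales: +/-127 * 1.2 on the input and +/-127 * 1.1 on
// the convolution output.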
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::
        SimpleQuantizedTensorRTNetwork::create_trt_network(bool has_batch_dim) {
    CompNode::load("xpu0").activate();
    Weights wt_filter{DataType::kFLOAT, nullptr, 0},
            wt_bias{DataType::kFLOAT, nullptr, 0};
    wt_filter.type = DataType::kFLOAT;
    wt_bias.type = DataType::kFLOAT;
    wt_filter.values = host_w->raw_ptr();
    wt_bias.values = host_b->raw_ptr();
    wt_filter.count = host_w->shape().total_nr_elems();
    wt_bias.count = host_b->shape().total_nr_elems();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
        flags = 1 << static_cast<int>(
                nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    nvinfer1::ITensor* data;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
        data = network->addInput("data", DataType::kFLOAT, Dims4{32, 8, 28, 28});
    } else {
        data = network->addInput("data", DataType::kFLOAT, Dims3{8, 28, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
        data = network->addInput("data", DataType::kFLOAT, DimsNCHW{32, 8, 28, 28});
    } else {
        data = network->addInput("data", DataType::kFLOAT, DimsCHW{8, 28, 28});
    }
#endif
    // check the input before dereferencing it
    mgb_assert(data != nullptr, "data is invalid");
    data->setDynamicRange(-127.f * 1.2f, 127.f * 1.2f);
    auto add_conv = [&](const char* name, nvinfer1::ITensor* inp) {
        auto conv = network->addConvolution(*inp, 8, DimsHW{3, 3}, wt_filter, wt_bias);
        mgb_assert(conv != nullptr, "conv1 is invalid");
        conv->setName(name);
        conv->setStride(DimsHW{1, 1});
        conv->setPadding(DimsHW{1, 1});
        conv->getOutput(0)->setDynamicRange(-127.f * 1.1f, 127.f * 1.1f);
        // conv->setPrecision(nvinfer1::DataType::kINT8);
        return conv->getOutput(0);
    };
    auto out = add_conv("conv1", data);
    out->setName("prob");
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        out->setAllowedFormats(formats);
    }
#endif
    network->markOutput(*out);
    return std::make_pair(builder, network);
}
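
// ConcatConvTensorRTNetwork: concatenates two inputs along the channel axis,
// then applies conv + bias.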
intl::ConcatConvTensorRTNetwork::ConcatConvTensorRTNetwork() {
    host_x0 = gen({5, 23, 14, 28});
    host_x1 = gen({5, 23, 14, 28});
    host_w = gen({32, 46, 3, 3});
    host_b = gen({1, 32, 1, 1});

    graph = ComputingGraph::make();
    x0 = Host2DeviceCopy::make(*graph, host_x0);
    x1 = Host2DeviceCopy::make(*graph, host_x1);
    auto y0 = opr::Concat::make({x0, x1}, 1),
         w = Host2DeviceCopy::make(*graph, host_w),
         b = Host2DeviceCopy::make(*graph, host_b),
         y1 = opr::Convolution::make(y0, w);
    y = y1 + b;
}
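
// TensorRT counterpart: an IConcatenationLayer (axis 1 with an explicit batch
// dimension, axis 0 otherwise) feeding an IConvolutionLayer.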
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::ConcatConvTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
    CompNode::load("xpu0").activate();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
        flags = 1 << static_cast<int>(
                nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    ITensor *data0, *data1;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
        data0 = network->addInput("x0", DataType::kFLOAT, Dims4{5, 23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, Dims4{5, 23, 14, 28});
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, Dims3{23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, Dims3{23, 14, 28});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data0->setAllowedFormats(formats);
        data1->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
        data0 = network->addInput("x0", DataType::kFLOAT, DimsNCHW{5, 23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, DimsNCHW{5, 23, 14, 28});
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, DimsCHW{23, 14, 28});
        data1 = network->addInput("x1", DataType::kFLOAT, DimsCHW{23, 14, 28});
    }
#endif
    ITensor* inputTensors[] = {data0, data1};
    auto concat = network->addConcatenation(inputTensors, 2);
    mgb_assert(concat != nullptr, "concat is null!");
    concat->setName("concat0");
    if (has_batch_dim) {
        concat->setAxis(1);
    } else {
        concat->setAxis(0);
    }
    Weights wt_filter{DataType::kFLOAT, host_w->raw_ptr(), 0},
            wt_bias{DataType::kFLOAT, host_b->raw_ptr(), 0};
    wt_filter.count = host_w->shape().total_nr_elems();
    wt_bias.count = host_b->shape().total_nr_elems();
    auto conv1 = network->addConvolution(
            *concat->getOutput(0), 32, DimsHW{3, 3}, wt_filter, wt_bias);
    mgb_assert(conv1 != nullptr, "conv1 is invalid");
    conv1->setName("conv1");
    conv1->setStride(DimsHW{1, 1});
    conv1->getOutput(0)->setName("convOut");
    network->markOutput(*conv1->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        conv1->getOutput(0)->setAllowedFormats(formats);
    }
#endif
    return std::make_pair(builder, network);
}
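
// ReshapeConcatTensorRTNetwork: flattens two inputs to (N, C, 1, 1) and
// concatenates them along the channel axis.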
intl::ReshapeConcatTensorRTNetwork::ReshapeConcatTensorRTNetwork() {
    host_x0 = gen({2, 2, 2, 2});
    host_y0 = gen({2, 3, 2, 2});

    graph = ComputingGraph::make();
    x0 = Host2DeviceCopy::make(*graph, host_x0);
    y0 = Host2DeviceCopy::make(*graph, host_y0);
    auto x1 = opr::Reshape::make(x0, {2, 8, 1, 1}),
         y1 = opr::Reshape::make(y0, {2, 12, 1, 1});
    z = opr::Concat::make({x1, y1}, 1);
}
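
// TensorRT counterpart built with the FlattenConcat_TRT plugin from
// libnvinfer_plugin; initLibNvInferPlugins() registers the plugin creator
// before it is looked up in the plugin registry.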
std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::ReshapeConcatTensorRTNetwork::
        create_trt_network(bool has_batch_dim) {
    initLibNvInferPlugins(&TensorRTOpr::Logger::instance(), "");
    CompNode::load("xpu0").activate();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
    nvinfer1::NetworkDefinitionCreationFlags flags;
    ::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
    if (has_batch_dim)
        flags = 1 << static_cast<int>(
                nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = builder->createNetworkV2(flags);
#else
    auto network = builder->createNetwork();
#endif
    nvinfer1::ITensor *data0, *data1;
#if NV_TENSOR_RT_VERSION >= 6001
    if (has_batch_dim) {
        data0 = network->addInput("x0", DataType::kFLOAT, Dims4{2, 2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, Dims4{2, 3, 2, 2});
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, Dims3{2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, Dims3{3, 2, 2});
    }
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data0->setAllowedFormats(formats);
        data1->setAllowedFormats(formats);
    }
#else
    if (has_batch_dim) {
        data0 = network->addInput("x0", DataType::kFLOAT, DimsNCHW{2, 2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, DimsNCHW{2, 3, 2, 2});
    } else {
        data0 = network->addInput("x0", DataType::kFLOAT, DimsCHW{2, 2, 2});
        data1 = network->addInput("y0", DataType::kFLOAT, DimsCHW{3, 2, 2});
    }
#endif
    int axis = 1;
    bool ignoreBatch = false;
    nvinfer1::PluginField fields[2] = {
            nvinfer1::PluginField{"axis", &axis, nvinfer1::PluginFieldType::kINT32, 1},
            nvinfer1::PluginField{
                    "ignoreBatch", &ignoreBatch, nvinfer1::PluginFieldType::kINT32, 1},
    };
    nvinfer1::PluginFieldCollection fc{2, fields};
    auto creator = getPluginRegistry()->getPluginCreator("FlattenConcat_TRT", "1", "");
    TensorRTUniquePtr<nvinfer1::IPluginV2> plugin(
            creator->createPlugin("FlattenConcat_TRT", &fc));
    ITensor* inputTensors[] = {data0, data1};
    auto flt_cct = network->addPluginV2(inputTensors, 2, *plugin);
    mgb_assert(flt_cct != nullptr, "FlattenConcat_TRT is invalid");
    network->markOutput(*flt_cct->getOutput(0));
#if NV_TENSOR_RT_VERSION >= 6001
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        flt_cct->getOutput(0)->setAllowedFormats(formats);
    }
#endif
    return std::make_pair(builder, network);
}
#if NV_TENSOR_RT_VERSION >= 6001
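
// DynamicShapeTensorRTNetwork: conv + bias with a runtime-determined input
// shape; requires the explicit-batch API, hence TensorRT >= 6.0 only.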
intl::DynamicShapeTensorRTNetwork::DynamicShapeTensorRTNetwork(
        size_t n, size_t c, size_t h, size_t w) {
    host_x = gen({n, c, h, w});
    host_w1 = gen({32, 23, 3, 3});
    host_b1 = gen({1, 32, 1, 1});

    graph = ComputingGraph::make();
    x = Host2DeviceCopy::make(*graph, host_x);
    auto w1 = Host2DeviceCopy::make(*graph, host_w1),
         b1 = Host2DeviceCopy::make(*graph, host_b1),
         y01 = opr::Convolution::make(x, w1);
    y1 = y01 + b1;
}
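
// Builds an explicit-batch network whose "data" input is (-1, 23, -1, -1) and
// attaches two optimization profiles covering different batch/spatial ranges,
// then compiles the network directly into an ICudaEngine.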
TensorRTUniquePtr<ICudaEngine> intl::DynamicShapeTensorRTNetwork::create_trt_network() {
    CompNode::load("xpu0").activate();
    Weights wt_filter_1{DataType::kFLOAT, nullptr, 0},
            wt_bias_1{DataType::kFLOAT, nullptr, 0};
    wt_filter_1.type = DataType::kFLOAT;
    wt_bias_1.type = DataType::kFLOAT;
    wt_filter_1.values = host_w1->raw_ptr();
    wt_bias_1.values = host_b1->raw_ptr();
    wt_filter_1.count = host_w1->shape().total_nr_elems();
    wt_bias_1.count = host_b1->shape().total_nr_elems();
    auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
    auto network = builder->createNetworkV2(
            1 << static_cast<int>(
                    nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
    nvinfer1::ITensor* data;
    data = network->addInput("data", DataType::kFLOAT, Dims4{-1, 23, -1, -1});
    // check the input before dereferencing it
    mgb_assert(data != nullptr, "data is invalid");
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    nvinfer1::IOptimizationProfile* profile1 = builder->createOptimizationProfile();
    profile1->setDimensions(
            "data", nvinfer1::OptProfileSelector::kMIN, Dims4(1, 23, 10, 10));
    profile1->setDimensions(
            "data", nvinfer1::OptProfileSelector::kOPT, Dims4(2, 23, 12, 12));
    profile1->setDimensions(
            "data", nvinfer1::OptProfileSelector::kMAX, Dims4(3, 23, 14, 14));
    config->addOptimizationProfile(profile1);
    nvinfer1::IOptimizationProfile* profile2 = builder->createOptimizationProfile();
    profile2->setDimensions(
            "data", nvinfer1::OptProfileSelector::kMIN, Dims4(3, 23, 16, 16));
    profile2->setDimensions(
            "data", nvinfer1::OptProfileSelector::kOPT, Dims4(4, 23, 24, 24));
    profile2->setDimensions(
            "data", nvinfer1::OptProfileSelector::kMAX, Dims4(5, 23, 28, 28));
    config->addOptimizationProfile(profile2);
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        data->setAllowedFormats(formats);
    }
    auto conv1 =
            network->addConvolution(*data, 32, DimsHW{3, 3}, wt_filter_1, wt_bias_1);
    mgb_assert(conv1 != nullptr, "conv1 is invalid");
    conv1->setStride(DimsHW{1, 1});
    conv1->getOutput(0)->setName("prob1");
    network->markOutput(*conv1->getOutput(0));
    {
        nvinfer1::TensorFormats formats =
                1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
        conv1->getOutput(0)->setAllowedFormats(formats);
    }
    TensorRTUniquePtr<ICudaEngine> cuda_engine{
            builder->buildEngineWithConfig(*network, *config)};
    return cuda_engine;
}
#endif
#pragma GCC diagnostic pop

#endif  // MGB_ENABLE_TENSOR_RT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}