You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

opr_replace.cpp 83 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758
  1. /**
  2. * \file src/tensorrt/impl/opr_replace.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <cstring>
  12. #include "megbrain/opr/basic_arith.h"
  13. #include "megbrain/opr/blas.h"
  14. #include "megbrain/opr/dnn/convolution.h"
  15. #include "megbrain/opr/dnn/pooling.h"
  16. #include "megbrain/opr/tensor_manip.h"
  17. #include "megbrain/utils/arith_helper.h"
  18. #include "megbrain/opr/nn_int.h"
  19. #include "megbrain/dtype.h"
  20. #if MGB_ENABLE_TENSOR_RT
  21. #include "megbrain/tensorrt/opr_replace.h"
  22. #include "megbrain/tensorrt/tensorrt_opr.h"
  23. #include "megbrain/tensorrt/tensorrt_engine_cache.h"
  24. #include "megbrain/gopt/basic_arith.h"
  25. #include "megbrain/gopt/inference.h"
  26. #include "megbrain/gopt/misc.h"
  27. #pragma GCC diagnostic push
  28. #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
  29. using namespace mgb;
  30. using namespace gopt;
  31. using namespace cg;
  32. template <typename T>
  33. using TensorRTUniquePtr = opr::intl::TensorRTUniquePtr<T>;
  34. namespace {
  35. nvinfer1::DataType mgb_dtype_to_trt_dtype(DType dtype) {
  36. switch (dtype.enumv()) {
  37. case DTypeEnum::Float32:
  38. return nvinfer1::DataType::kFLOAT;
  39. case DTypeEnum::Float16:
  40. return nvinfer1::DataType::kHALF;
  41. case DTypeEnum::QuantizedS8:
  42. return nvinfer1::DataType::kINT8;
  43. case DTypeEnum::Int32:
  44. return nvinfer1::DataType::kINT32;
  45. default:
  46. mgb_throw(
  47. InternalError,
  48. "invalid data type which is not supported in TensorRT: %s",
  49. dtype.name());
  50. }
  51. }
  52. }
  53. class TensorRTReplacePass::Impl final {
  54. static constexpr size_t OPR_FAIL_LOG_NUM = 10;
  55. static constexpr float i8_max = std::numeric_limits<int8_t>::max();
  56. using TensorRTGraphFeatureBits = opr::intl::TensorRTGraphFeatureBits;
  57. using ConvFormat = opr::Convolution::Param::Format;
  58. using ExtraDep = ThinHashMap<OperatorNodeBase*, VarNodeArray>;
  59. const Pass& m_pass;
  60. OptState& m_opt_state;
  61. SubGraph::Rewriter m_rewriter;
  62. struct TensorRTGraph {
  63. using Callback = cg::DepOprIter::Callback;
  64. nvinfer1::IBuilder* builder;
  65. nvinfer1::INetworkDefinition* network;
  66. ThinHashSet<VarNode*> inputs;
  67. ThinHashSet<VarNode*> outputs;
  68. // is used for mapping output varnode in original computing graph to
  69. // output varnode of TensorRTOpr
  70. ThinHashMap<VarNode*, size_t> output2idx;
  71. // mark input and output tensor as nchw4 format, we should insert
  72. // dimshuffle and typecvt to make the TensorRTOpr's inputs and outputs
  73. // match with those of non fused operators.
  74. ThinHashSet<VarNode*> mark_input_varnode_nchw4;
  75. ThinHashSet<VarNode*> mark_output_varnode_nchw4;
  76. VarNodeArray trt_inputs;
  77. VarNodeArray trt_outputs;
  78. // Every tensor rt graph should own a map from var node to infer tensor.
  79. // Because a var node can belong to two different tensor rt subgraph
  80. ThinHashMap<VarNode*, nvinfer1::ITensor*> varnode2itensor;
  81. TensorRTGraphFeatureBits feature_bits;
  82. TensorRTGraph(TensorRTGraphFeatureBits feature_bits =
  83. TensorRTGraphFeatureBits::NCHW_FLOAT)
  84. : builder{nvinfer1::createInferBuilder(
  85. opr::TensorRTOpr::Logger::instance())},
  86. network{nullptr},
  87. feature_bits{feature_bits} {}
  88. void mark_varnode_format_nchw4();
  89. };
  90. struct FailInfo {
  91. OperatorNodeBase* opr;
  92. std::string fail_msg;
  93. };
  94. class HostTensorKeeper : public UserDataContainer::UserData {
  95. MGB_TYPEINFO_OBJ_DECL;
  96. public:
  97. std::vector<HostTensorND> htr;
  98. };
  99. std::unique_ptr<ConstVarPropogate> m_const_var_propogate;
  100. std::vector<std::shared_ptr<TensorRTGraph>> m_tensorrt_graphs;
  101. // use ThinHashMap instead of std::unordered_map
  102. ThinHashMap<OperatorNodeBase*, size_t> m_graph_map;
  103. ThinHashMap<OperatorNodeBase*, nvinfer1::IConvolutionLayer*>
  104. m_opr2convlayer;
  105. ThinHashMap<OperatorNodeBase*, nvinfer1::IDeconvolutionLayer*>
  106. m_opr2deconvlayer;
  107. size_t m_opr_num;
  108. size_t m_opr_fail_num;
  109. std::vector<FailInfo> m_opr_fail;
  110. struct OprTrait {
  111. // judge if supported, not exist means not support
  112. thin_function<Maybe<std::string>(OperatorNodeBase*)>
  113. get_replace_fail_msg;
  114. // replace opr by trt opr, ditto
  115. thin_function<void(nvinfer1::INetworkDefinition*, OperatorNodeBase*)>
  116. add_to_nvinfer;
  117. };
  118. ThinHashMap<Typeinfo*, OprTrait> m_opr_trait;
  119. // Find parent conv of elemwise ADD opr.
  120. VarNodeArray find_parent_conv(OperatorNodeBase* opr);
  121. // Make a trt tensor for Varnode var and add it as input of trt buffer.
  122. // Return false if a tensor of var is previously made and added.
  123. // True if var is encountered for the first time.
  124. bool check_input(VarNode* var, OperatorNodeBase* opr,
  125. mgb::SmallVector<nvinfer1::DimensionType> dimtypes = {});
  126. HostTensorND get_value(VarNode* var, ConvFormat format = ConvFormat::NCHW);
  127. void set_itensor_dynamic_range(VarNode* var, OperatorNodeBase* opr);
  128. float get_scale(DType data_type);
  129. // Check whether an operator is a quantized operator. If an operator is a
  130. // quantized operator, this operator can be fused into a quantized TensorRT
  131. // subgraph
  132. bool is_quantized_int8_operator(OperatorNodeBase* opr);
  133. Maybe<std::string> has_fail_msg(OperatorNodeBase* opr);
  134. static nvinfer1::ITensor& replace(nvinfer1::INetworkDefinition* newtwork,
  135. nvinfer1::ITensor& pre_output,
  136. OperatorNodeBase* opr);
  137. void update_graph();
  138. void mark_varnode_format_nchw4();
  139. void detect_replace();
  140. public:
  141. Impl(const Pass& pass, OptState& opt_state)
  142. : m_pass{pass},
  143. m_opt_state{opt_state},
  144. m_rewriter{opt_state.graph().make_rewriter()},
  145. m_const_var_propogate{std::make_unique<ConstVarPropogate>(
  146. ConstVarType::IMMUTABLE_AND_PARAM)} {
  147. #define REPLACE_FAIL_MSG_EPILOGUE \
  148. { \
  149. auto&& mgr = opr->owner_graph()->static_infer_manager(); \
  150. auto&& shp = mgr.infer_shape_fallible(opr->output(0)); \
  151. if (!shp) \
  152. return "Unsupported opr, because operator shape cannot be " \
  153. "inferred at compile time."; \
  154. else \
  155. return None; \
  156. }
  157. m_opr_trait[opr::Elemwise::typeinfo()].get_replace_fail_msg =
  158. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  159. bool has_scalar = false;
  160. for (auto&& inp : opr->input()) {
  161. if (inp->shape().is_scalar()) {
  162. has_scalar = true;
  163. break;
  164. }
  165. }
  166. if (has_scalar)
  167. return "Elemwise with scalar input is not supported.";
  168. if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
  169. opr->input(0)->dtype() != dtype::Float32()) {
  170. return "Unsupported data type.";
  171. }
  172. using Mode = opr::Elemwise::Mode;
  173. static const ThinHashSet<Mode> supported_modes {
  174. #if NV_TENSOR_RT_VERSION >= 5105
  175. Mode::SIN, Mode::COS, Mode::ASIN, Mode::ACOS, Mode::CEIL,
  176. Mode::FLOOR,
  177. #endif
  178. Mode::EXP, Mode::LOG, Mode::ABS,
  179. Mode::RELU, Mode::SIGMOID, Mode::TANH, Mode::ADD,
  180. Mode::MUL, Mode::MIN, Mode::MAX, Mode::SUB,
  181. Mode::TRUE_DIV, Mode::POW, Mode::FUSE_ADD_RELU,
  182. Mode::FUSE_ADD_TANH, Mode::FUSE_ADD_SIGMOID
  183. };
  184. auto mode = opr->cast_final_safe<opr::Elemwise>().param().mode;
  185. if (!supported_modes.count(mode)) {
  186. return "Unsupported Elemwise mode.";
  187. }
  188. #if NV_TENSOR_RT_VERSION >= 6001
  189. if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
  190. TensorShapeArray inps;
  191. for (auto&& inp : opr->input()) {
  192. inps.push_back(inp->shape());
  193. }
  194. TensorShape brdcast;
  195. megdnn::Elemwise::deduce_shape(inps, brdcast);
  196. if (brdcast.ndim < 4) {
  197. return "Elemwise with QuantizedS8 data type must have more "
  198. "than 4 dimensions. Less than 3 dimensions is not "
  199. "supported since trt6.0.";
  200. }
  201. }
  202. #endif
  203. REPLACE_FAIL_MSG_EPILOGUE;
  204. };
  205. m_opr_trait[opr::ElemwiseMultiType::typeinfo()].get_replace_fail_msg =
  206. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  207. bool has_scalar = false;
  208. for (auto&& inp : opr->input()) {
  209. if (inp->shape().is_scalar()) {
  210. has_scalar = true;
  211. break;
  212. }
  213. }
  214. if (has_scalar)
  215. return "ElemwiseMultiType with scalar input is not supported.";
  216. for (auto&& inp : opr->input()) {
  217. if (inp->dtype().enumv() != DTypeEnum::QuantizedS8)
  218. return "Unsupported data type.";
  219. }
  220. if (opr->output(0)->dtype().enumv() != DTypeEnum::QuantizedS8)
  221. return "Unsupported data type.";
  222. using Mode = opr::ElemwiseMultiType::Mode;
  223. auto mode =
  224. opr->cast_final_safe<opr::ElemwiseMultiType>().param().mode;
  225. if (mode != Mode::QFUSE_ADD_RELU && mode != Mode::QADD) {
  226. return "Unsupported ElemwiseMultiType mode.";
  227. }
  228. REPLACE_FAIL_MSG_EPILOGUE;
  229. };
  230. m_opr_trait[opr::Convolution::typeinfo()].get_replace_fail_msg =
  231. [this](OperatorNodeBase* opr) -> Maybe<std::string> {
  232. if (opr->input(0)->dtype() != dtype::Float32())
  233. return "Non-Float32 convolution is not supported.";
  234. if (!m_const_var_propogate->is_const(opr->input(1)))
  235. return "Weights not constant. Not replaceable in TRT.";
  236. auto&& param = opr->cast_final_safe<opr::Convolution>().param();
  237. if (param.format != ConvFormat::NCHW)
  238. return "TensorRT replace pass only support NCHW format "
  239. "convolution.";
  240. if (param.mode == opr::Convolution::Param::Mode::CONVOLUTION)
  241. return "TensorRT does not support non cross correlation "
  242. "convolution.";
  243. REPLACE_FAIL_MSG_EPILOGUE;
  244. };
  245. m_opr_trait[opr::ConvBias::typeinfo()].get_replace_fail_msg =
  246. [this](OperatorNodeBase* opr) -> Maybe<std::string> {
  247. if (opr->input(0)->dtype() != dtype::Float32() &&
  248. opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8)
  249. return "Convolution is only supported for float32 or qint8.";
  250. if (!m_const_var_propogate->is_const(opr->input(1)))
  251. return "Weights not constant. Not replaceable in TRT.";
  252. if (opr->input().size() >= 3) {
  253. if (!m_const_var_propogate->is_const(opr->input(2)))
  254. return "Bias not constant. Not replaceable in TRT.";
  255. }
  256. auto&& param = opr->cast_final_safe<opr::ConvBias>().param();
  257. if (param.format != ConvFormat::NCHW &&
  258. param.format != ConvFormat::NCHW4)
  259. return "TensorRT replace pass only support NCHW format "
  260. "convolution.";
  261. if (param.mode == opr::ConvBias::Param::Mode::CONVOLUTION)
  262. return "TensorRT does not support non cross correlation "
  263. "convolution.";
  264. REPLACE_FAIL_MSG_EPILOGUE;
  265. };
  266. m_opr_trait[opr::ConvolutionBackwardData::typeinfo()]
  267. .get_replace_fail_msg =
  268. [this](OperatorNodeBase* opr) -> Maybe<std::string> {
  269. if (opr->input(0)->dtype() != dtype::Float32())
  270. return "Non-Float32 Deconvolution is not supported.";
  271. if (!m_const_var_propogate->is_const(opr->input(0)))
  272. return "Weights not constant. Not replaceable in TRT.";
  273. auto&& param = opr->cast_final_safe<opr::ConvolutionBackwardData>().param();
  274. if (param.dilate_h != 1 || param.dilate_w != 1)
  275. return "TensorRT does not support dilation deconvolution.";
  276. if (param.format != ConvFormat::NCHW)
  277. return "TensorRT replace pass only support NCHW format deconv.";
  278. if (param.mode == opr::ConvBias::Param::Mode::CONVOLUTION)
  279. return "TensorRT does not support non cross correlation "
  280. "deconvolution.";
  281. REPLACE_FAIL_MSG_EPILOGUE;
  282. };
  283. m_opr_trait[opr::Pooling::typeinfo()].get_replace_fail_msg =
  284. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  285. auto pool = opr->try_cast_final<opr::Pooling>();
  286. auto&& param = pool->param();
  287. if (param.format != opr::Pooling::Param::Format::NCHW &&
  288. param.format != opr::Pooling::Param::Format::NCHW4)
  289. return "Pooling is only supported for NCHW and NCHW4";
  290. REPLACE_FAIL_MSG_EPILOGUE;
  291. };
  292. m_opr_trait[opr::Concat::typeinfo()].get_replace_fail_msg =
  293. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  294. if (opr->input(0)->dtype() != dtype::Float32() &&
  295. opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
  296. return "Concat only support float32 and quantized int8.";
  297. }
  298. // TODO: TensorRT only supports concat on channel dimension,
  299. // we can set nvinfer1::DimensionType to kCHANNEL to support
  300. // concat on other dimension
  301. if (!(opr->input(0)->shape().ndim == 4 &&
  302. opr->cast_final_safe<opr::Concat>().param().axis == 1)) {
  303. return "Concat only support input is NCHW and axis is 1.";
  304. }
  305. REPLACE_FAIL_MSG_EPILOGUE;
  306. };
  307. m_opr_trait[opr::MatrixMul::typeinfo()].get_replace_fail_msg =
  308. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  309. if (opr->input(0)->dtype() != dtype::Float32())
  310. return "Non-Float32 MatrixMul is not supported.";
  311. REPLACE_FAIL_MSG_EPILOGUE;
  312. };
  313. m_opr_trait[opr::BatchedMatrixMul::typeinfo()].get_replace_fail_msg =
  314. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  315. if (opr->input(0)->dtype() != dtype::Float32())
  316. return "Non-Float32 MatrixMul is not supported.";
  317. REPLACE_FAIL_MSG_EPILOGUE;
  318. };
  319. m_opr_trait[opr::PowC::typeinfo()].get_replace_fail_msg =
  320. [](OperatorNodeBase* opr) -> Maybe<std::string> {
  321. if (opr->input(0)->dtype() != dtype::Float32())
  322. return "Non-Float32 PowC is not supported.";
  323. if (opr->input(0)->shape().ndim < 3)
  324. return "Dimensions of input should be greater than or equal to "
  325. "3.";
  326. REPLACE_FAIL_MSG_EPILOGUE;
  327. };
  328. #undef REPLACE_FAIL_MSG_EPILOGUE
  329. // megdnn convolution opr on cuda backend does not support quantized
  330. // dtype, so we assume that megbrain int8 network for converting to fine
  331. // grained TensorRT subgraph does not include convolution operator with
  332. // quantized int8 data type
  333. m_opr_trait[opr::Convolution::typeinfo()]
  334. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  335. OperatorNodeBase* opr) {
  336. auto&& varnode2itensor =
  337. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  338. VarNode* input = opr->input(0);
  339. VarNode* kernel = opr->input(1);
  340. check_input(input, opr);
  341. nvinfer1::Weights wt_kernel{
  342. nvinfer1::DataType::kFLOAT, get_value(kernel).raw_ptr(),
  343. static_cast<int64_t>(kernel->shape().total_nr_elems())};
  344. nvinfer1::Weights wt_bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
  345. auto&& param = opr->cast_final_safe<opr::Convolution>().param();
  346. mgb_assert(
  347. param.format == megdnn::param::Convolution::Format::NCHW &&
  348. param.mode == megdnn::param::Convolution::Mode::
  349. CROSS_CORRELATION,
  350. "conv param is not supported by TensorRT");
  351. size_t group_offset = 0;
  352. if (param.sparse == megdnn::param::Convolution::Sparse::GROUP) {
  353. group_offset = 1;
  354. } else {
  355. mgb_assert(param.sparse ==
  356. megdnn::param::Convolution::Sparse::DENSE,
  357. "param.sparse should be GROUP or DENSE");
  358. }
  359. auto conv = net->addConvolution(
  360. *varnode2itensor[input], opr->output(0)->shape()[1],
  361. nvinfer1::DimsHW{
  362. static_cast<int>(kernel->shape()[group_offset + 2]),
  363. static_cast<int>(
  364. kernel->shape()[group_offset + 3])},
  365. wt_kernel, wt_bias);
  366. mgb_assert(conv, "construct network failed");
  367. std::string layer_name = "TRT_CONV:" + opr->name();
  368. conv->setName(layer_name.c_str());
  369. conv->setStride(nvinfer1::DimsHW{static_cast<int>(param.stride_h),
  370. static_cast<int>(param.stride_w)});
  371. conv->setPadding(nvinfer1::DimsHW{static_cast<int>(param.pad_h),
  372. static_cast<int>(param.pad_w)});
  373. conv->setDilation(
  374. nvinfer1::DimsHW{static_cast<int>(param.dilate_h),
  375. static_cast<int>(param.dilate_w)});
  376. if (group_offset > 0)
  377. conv->setNbGroups(static_cast<int>(kernel->shape()[0]));
  378. m_opr2convlayer[opr] = conv;
  379. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  380. conv->getOutput(0)->setName(output_name.c_str());
  381. varnode2itensor[opr->output(0)] = conv->getOutput(0);
  382. };
  383. // support floating point data type and quantized data type
  384. m_opr_trait[opr::ConvBiasForward::typeinfo()]
  385. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  386. OperatorNodeBase* opr) {
  387. auto&& varnode2itensor =
  388. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  389. using Param = opr::ConvBias::Param;
  390. using NonlineMode = Param::NonlineMode;
  391. using Sparse = Param::Sparse;
  392. using Format = Param::Format;
  393. auto conv_bias = try_cast_as_op<opr::ConvBias>(opr);
  394. auto&& param = conv_bias->param();
  395. mgb_assert(param.mode == Param::Mode::CROSS_CORRELATION,
  396. "Trt only support CROSS_CORRELATION convolution.");
  397. bool is_format_nchw4 = param.format == Format::NCHW4;
  398. bool is_qint8 = is_quantized_int8_operator(opr);
  399. if (is_format_nchw4)
  400. mgb_assert(is_qint8);
  401. // set kernel and bias
  402. VarNode* input = conv_bias->input(0);
  403. VarNode* kernel = conv_bias->input(1);
  404. check_input(input, opr);
  405. nvinfer1::Weights wt_kernel{
  406. nvinfer1::DataType::kFLOAT,
  407. get_value(kernel, param.format).raw_ptr(),
  408. static_cast<int64_t>(kernel->shape().total_nr_elems())};
  409. nvinfer1::Weights wt_bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
  410. if (conv_bias->input().size() >= 3) {
  411. VarNode* bias = conv_bias->input(2);
  412. wt_bias.values = get_value(bias, param.format).raw_ptr();
  413. wt_bias.count =
  414. static_cast<int64_t>(bias->shape().total_nr_elems());
  415. }
  416. // determine conv shape
  417. int co = 0;
  418. int sh = param.stride_h, sw = param.stride_w, ph = param.pad_h,
  419. pw = param.pad_w, dh = param.dilate_h, dw = param.dilate_w;
  420. size_t group_offset = 0;
  421. int groups = 1;
  422. if (param.sparse == Sparse::GROUP) {
  423. groups = kernel->shape()[0];
  424. group_offset = 1;
  425. } else {
  426. mgb_assert(param.sparse == Sparse::DENSE,
  427. "sparse should be GROUP or DENSE");
  428. }
  429. int fh = kernel->shape()[group_offset + 2],
  430. fw = kernel->shape()[group_offset + 3];
  431. if (param.format == Format::NCHW) {
  432. mgb_assert(conv_bias->input(0)->dtype() == dtype::Float32(),
  433. "conv bias only support Float32 with NCHW format");
  434. co = conv_bias->output(0)->shape()[1];
  435. } else if (param.format == Format::NCHW4) {
  436. mgb_assert(
  437. conv_bias->input(0)->dtype().enumv() ==
  438. DTypeEnum::QuantizedS8 &&
  439. conv_bias->output(0)->dtype().enumv() ==
  440. DTypeEnum::QuantizedS8,
  441. "conv bias only support QuantizedS8 with NCHW4 format");
  442. co = conv_bias->output(0)->shape()[1] * 4;
  443. }
  444. mgb_assert(co > 0);
  445. // process conv
  446. auto conv = net->addConvolution(*varnode2itensor[input], co,
  447. nvinfer1::DimsHW{fh, fw}, wt_kernel,
  448. wt_bias);
  449. mgb_assert(conv, "construct network failed");
  450. std::string layer_name = "TRT_CONV:" + conv_bias->name();
  451. conv->setName(layer_name.c_str());
  452. conv->setStride(nvinfer1::DimsHW{sh, sw});
  453. conv->setPadding(nvinfer1::DimsHW{ph, pw});
  454. conv->setDilation(nvinfer1::DimsHW{dh, dw});
  455. if (group_offset > 0)
  456. conv->setNbGroups(groups);
  457. std::string output_name = "TRT_O:" + conv_bias->output(0)->name();
  458. conv->getOutput(0)->setName(output_name.c_str());
  459. varnode2itensor[conv_bias->output(0)] = conv->getOutput(0);
  460. if (is_qint8)
  461. set_itensor_dynamic_range(conv_bias->output(0), conv_bias);
  462. // process short cut add
  463. if (conv_bias->input().size() >= 4) {
  464. check_input(conv_bias->input(3), opr);
  465. auto add = net->addElementWise(
  466. *varnode2itensor[conv_bias->output(0)],
  467. *varnode2itensor[conv_bias->input(3)],
  468. nvinfer1::ElementWiseOperation::kSUM);
  469. mgb_assert(add, "construct network failed");
  470. std::string layer_name = "TRT_ELEM:" + conv_bias->name();
  471. add->setName(layer_name.c_str());
  472. std::string output_name =
  473. "TRT_O:" + conv_bias->output(0)->name() +
  474. "_shortcut_add";
  475. add->getOutput(0)->setName(output_name.c_str());
  476. varnode2itensor[conv_bias->output(0)] = add->getOutput(0);
  477. if (is_qint8)
  478. set_itensor_dynamic_range(conv_bias->output(0), conv_bias);
  479. }
  480. // process activation
  481. if (param.nonlineMode != Param::NonlineMode::IDENTITY) {
  482. nvinfer1::ActivationType act_type =
  483. param.nonlineMode == NonlineMode::RELU
  484. ? nvinfer1::ActivationType::kRELU
  485. : nvinfer1::ActivationType::kSIGMOID;
  486. auto act = net->addActivation(
  487. *varnode2itensor[conv_bias->output(0)], act_type);
  488. mgb_assert(act, "construct network failed");
  489. std::string layer_name =
  490. "TRT_ACTV:" + conv_bias->name();
  491. act->setName(layer_name.c_str());
  492. std::string output_name =
  493. "TRT_O:" + conv_bias->output(0)->name() + "_act";
  494. act->getOutput(0)->setName(output_name.c_str());
  495. varnode2itensor[conv_bias->output(0)] = act->getOutput(0);
  496. if (is_qint8)
  497. set_itensor_dynamic_range(conv_bias->output(0), conv_bias);
  498. }
  499. };
  500. // megbrain deconvolution operator does not support quantized data type
  501. m_opr_trait[opr::ConvolutionBackwardData::typeinfo()]
  502. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  503. OperatorNodeBase* opr) {
  504. auto&& varnode2itensor =
  505. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  506. VarNode* kernel = opr->input(0);
  507. VarNode* input = opr->input(1);
  508. check_input(input, opr);
  509. nvinfer1::Weights wt_kernel{
  510. nvinfer1::DataType::kFLOAT, get_value(kernel).raw_ptr(),
  511. static_cast<int64_t>(kernel->shape().total_nr_elems())};
  512. nvinfer1::Weights wt_bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
  513. auto&& param = opr->cast_final_safe<opr::ConvolutionBackwardData>()
  514. .param();
  515. mgb_assert(
  516. param.format == megdnn::param::Convolution::Format::NCHW &&
  517. param.mode == megdnn::param::Convolution::Mode::
  518. CROSS_CORRELATION &&
  519. param.dilate_h == 1 && param.dilate_w == 1,
  520. "conv param is not supported by TensorRT");
  521. size_t group_offset = 0;
  522. if (param.sparse == megdnn::param::Convolution::Sparse::GROUP) {
  523. group_offset = 1;
  524. } else {
  525. mgb_assert(param.sparse ==
  526. megdnn::param::Convolution::Sparse::DENSE,
  527. "param.sparse should be GROUP or DENSE");
  528. }
  529. auto deconv = net->addDeconvolution(
  530. *varnode2itensor[input], opr->output(0)->shape()[1],
  531. nvinfer1::DimsHW{
  532. static_cast<int>(kernel->shape()[group_offset + 2]),
  533. static_cast<int>(
  534. kernel->shape()[group_offset + 3])},
  535. wt_kernel, wt_bias);
  536. mgb_assert(deconv, "construct network failed");
  537. std::string layer_name = "TRT_DCON:" + opr->name();
  538. deconv->setName(layer_name.c_str());
  539. deconv->setStride(
  540. nvinfer1::DimsHW{static_cast<int>(param.stride_h),
  541. static_cast<int>(param.stride_w)});
  542. deconv->setPadding(nvinfer1::DimsHW{static_cast<int>(param.pad_h),
  543. static_cast<int>(param.pad_w)});
  544. if (group_offset > 0)
  545. deconv->setNbGroups(static_cast<int>(kernel->shape()[0]));
  546. m_opr2deconvlayer[opr] = deconv;
  547. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  548. deconv->getOutput(0)->setName(output_name.c_str());
  549. varnode2itensor[opr->output(0)] = deconv->getOutput(0);
  550. };
  551. // support floating point data type and quantized data type
  552. m_opr_trait[opr::Pooling::typeinfo()]
  553. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  554. OperatorNodeBase* opr) {
  555. auto&& varnode2itensor =
  556. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  557. using Param = opr::Pooling::Param;
  558. using Mode = Param::Mode;
  559. using Format = Param::Format;
  560. static ThinHashMap<Mode, nvinfer1::PoolingType> pooling_type_map = {
  561. {Mode::MAX, nvinfer1::PoolingType::kMAX},
  562. {Mode::AVERAGE, nvinfer1::PoolingType::kAVERAGE},
  563. {Mode::AVERAGE_COUNT_EXCLUDE_PADDING,
  564. nvinfer1::PoolingType::kAVERAGE}};
  565. auto&& param = opr->cast_final_safe<opr::Pooling>().param();
  566. check_input(opr->input(0), opr);
  567. auto pool = net->addPooling(
  568. *varnode2itensor[opr->input(0)],
  569. pooling_type_map.at(param.mode),
  570. nvinfer1::DimsHW{static_cast<int>(param.window_h),
  571. static_cast<int>(param.window_w)});
  572. mgb_assert(pool, "construct network failed");
  573. std::string layer_name = "TRT_POOL:" + opr->name();
  574. pool->setName(layer_name.c_str());
  575. pool->setPadding(nvinfer1::DimsHW{static_cast<int>(param.pad_h),
  576. static_cast<int>(param.pad_w)});
  577. pool->setStride(nvinfer1::DimsHW{static_cast<int>(param.stride_h),
  578. static_cast<int>(param.stride_w)});
  579. //! According to the documentation of TensorRT, the default value of exclusive is true.
  580. //! So we need to set exclusive to false when pooling mode is average
  581. if (param.mode == Mode::AVERAGE_COUNT_EXCLUDE_PADDING)
  582. pool->setAverageCountExcludesPadding(true);
  583. else if (param.mode == Mode::AVERAGE)
  584. pool->setAverageCountExcludesPadding(false);
  585. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  586. pool->getOutput(0)->setName(output_name.c_str());
  587. varnode2itensor[opr->output(0)] = pool->getOutput(0);
  588. if (param.format == Format::NCHW4) {
  589. mgb_assert(opr->input(0)->dtype().enumv() ==
  590. DTypeEnum::QuantizedS8,
  591. "Pooling with NCHW4 format should use quantized "
  592. "int8 data type");
  593. set_itensor_dynamic_range(opr->output(0), opr);
  594. }
  595. };
  596. m_opr_trait[opr::Concat::typeinfo()].add_to_nvinfer =
  597. [this](nvinfer1::INetworkDefinition* net,
  598. OperatorNodeBase* opr) {
  599. auto&& varnode2itensor =
  600. m_tensorrt_graphs[m_graph_map[opr] - 1]
  601. ->varnode2itensor;
  602. size_t input_size = opr->input().size();
  603. std::unique_ptr<nvinfer1::ITensor* []> input_tensors(
  604. new nvinfer1::ITensor*[input_size]);
  605. for (size_t i = 0; i < input_size; ++i) {
  606. check_input(opr->input(i), opr);
  607. input_tensors[i] = varnode2itensor[opr->input(i)];
  608. }
  609. auto concat = net->addConcatenation(
  610. input_tensors.get(), static_cast<int>(input_size));
  611. mgb_assert(concat, "construct Concatenation layer failed!");
  612. std::string layer_name = "TRT_CCAT:" + opr->name();
  613. concat->setName(layer_name.c_str());
  614. int axis = opr->cast_final_safe<opr::Concat>().param().axis;
  615. concat->setAxis(axis);
  616. std::string output_name =
  617. "TRT_O:" + opr->output()[0]->name();
  618. concat->getOutput(0)->setName(output_name.c_str());
  619. varnode2itensor[opr->output(0)] = concat->getOutput(0);
  620. if (is_quantized_int8_operator(opr)) {
  621. set_itensor_dynamic_range(opr->output(0), opr);
  622. }
  623. };
  624. // support floating point data type and quantized data type
  625. m_opr_trait[opr::Elemwise::typeinfo()]
  626. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  627. OperatorNodeBase* opr) {
  628. auto&& varnode2itensor =
  629. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  630. using Mode = opr::Elemwise::Mode;
  631. auto mode = opr->cast_final_safe<opr::Elemwise>().param().mode;
  632. auto get_dimtype = [&](int ndim) {
  633. SmallVector<nvinfer1::DimensionType> dimtypes(ndim);
  634. for (int i = 0; i < ndim; i++) {
  635. dimtypes[i] = nvinfer1::DimensionType::kSPATIAL;
  636. }
  637. return dimtypes;
  638. };
  639. auto on_elemwise_arity_unary =
  640. [this, &varnode2itensor, &net, &opr,
  641. &get_dimtype](nvinfer1::UnaryOperation unary_op) {
  642. size_t tensor_ndim = opr->input(0)->shape().ndim;
  643. check_input(opr->input(0), opr,
  644. get_dimtype(tensor_ndim));
  645. auto unary = net->addUnary(
  646. *varnode2itensor[opr->input(0)], unary_op);
  647. mgb_assert(unary, "construct network failed");
  648. std::string layer_name = "TRT_UNARY:" + opr->name();
  649. unary->setName(layer_name.c_str());
  650. std::string output_name =
  651. "TRT_O:" + opr->output()[0]->name();
  652. unary->getOutput(0)->setName(output_name.c_str());
  653. varnode2itensor[opr->output(0)] = unary->getOutput(0);
  654. };
  655. auto on_elemwise_arity_activation =
  656. [this, &varnode2itensor, &net, &opr,
  657. &get_dimtype](nvinfer1::ActivationType act_type) {
  658. size_t tensor_ndim = opr->input(0)->shape().ndim;
  659. check_input(opr->input(0), opr,
  660. get_dimtype(tensor_ndim));
  661. auto act = net->addActivation(
  662. *varnode2itensor[opr->input(0)], act_type);
  663. mgb_assert(act, "construct network failed");
  664. std::string layer_name = "TRT_ACTV:" + opr->name();
  665. act->setName(layer_name.c_str());
  666. std::string output_name =
  667. "TRT_O:" + opr->output()[0]->name();
  668. act->getOutput(0)->setName(output_name.c_str());
  669. varnode2itensor[opr->output(0)] = act->getOutput(0);
  670. };
  671. auto on_elemwise_arity_binary = [this, &varnode2itensor, &net, &opr,
  672. &get_dimtype](
  673. nvinfer1::
  674. ElementWiseOperation
  675. elem_op) {
  676. size_t ndim0 = opr->input(0)->shape().ndim,
  677. ndim1 = opr->input(1)->shape().ndim;
  678. mgb_assert(ndim0 == ndim1);
  679. size_t tensor_ndim = ndim0;
  680. bool inp0_new = check_input(opr->input(0), opr,
  681. get_dimtype(tensor_ndim));
  682. bool inp1_new = check_input(opr->input(1), opr,
  683. get_dimtype(tensor_ndim));
  684. if (inp0_new && inp1_new) {
  685. mgb_log_warn(
  686. "Both operands of Elemwise are newly prepared. "
  687. "This is rare. "
  688. "Please check. opr=%s inputs=%s",
  689. opr->cname(),
  690. cg::dump_var_info(opr->input()).c_str());
  691. }
  692. auto dims0 = varnode2itensor[opr->input(0)]->getDimensions(),
  693. dims1 = varnode2itensor[opr->input(1)]->getDimensions();
  694. mgb_throw_if(dims0.nbDims != dims1.nbDims, AssertionError,
  695. "Input dimensions of two input tensors must be "
  696. "equal (got: %d, %d).",
  697. dims0.nbDims, dims1.nbDims);
  698. auto elem = net->addElementWise(*varnode2itensor[opr->input(0)],
  699. *varnode2itensor[opr->input(1)],
  700. elem_op);
  701. mgb_assert(elem, "construct network failed");
  702. std::string layer_name = "TRT_ELEM:" + opr->name();
  703. elem->setName(layer_name.c_str());
  704. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  705. elem->getOutput(0)->setName(output_name.c_str());
  706. varnode2itensor[opr->output(0)] = elem->getOutput(0);
  707. };
  708. switch (mode) {
  709. #define cb(mode) \
  710. case Mode::mode: \
  711. on_elemwise_arity_unary(nvinfer1::UnaryOperation::k##mode); \
  712. break;
  713. #if NV_TENSOR_RT_VERSION >= 5105
  714. #define MGB_FOREACH_UNARY_OPERATION(cb) \
  715. cb(EXP) cb(LOG) cb(ABS) cb(SIN) cb(COS) cb(ASIN) cb(ACOS) cb(CEIL) cb(FLOOR)
  716. #else
  717. #define MGB_FOREACH_UNARY_OPERATION(cb) cb(EXP) cb(LOG) cb(ABS)
  718. #endif
  719. MGB_FOREACH_UNARY_OPERATION(cb)
  720. #undef cb
  721. #undef MGB_FOREACH_UNARY_OPERATION
  722. #define cb(mode) \
  723. case Mode::mode: \
  724. on_elemwise_arity_activation(nvinfer1::ActivationType::k##mode); \
  725. break;
  726. #define MGB_FOREACH_ACTIVATION_TYPE(cb) cb(RELU) cb(SIGMOID) cb(TANH)
  727. MGB_FOREACH_ACTIVATION_TYPE(cb)
  728. #undef cb
  729. #undef MGB_FOREACH_ACTIVATION_TYPE
  730. case Mode::ADD: {
  731. VarNode *opr_var, *bias_var;
  732. VarNodeArray result = find_parent_conv(opr);
  733. if (result.size() > 0) {
  734. opr_var = result[0];
  735. bias_var = result[1];
  736. nvinfer1::Weights wt_bias{
  737. nvinfer1::DataType::kFLOAT,
  738. get_value(bias_var).raw_ptr(),
  739. static_cast<int64_t>(
  740. bias_var->shape().total_nr_elems())};
  741. if (opr_var->owner_opr()
  742. ->same_type<opr::Convolution>()) {
  743. m_opr2convlayer[opr_var->owner_opr()]
  744. ->setBiasWeights(wt_bias);
  745. } else if (
  746. opr_var->owner_opr()
  747. ->same_type<
  748. opr::ConvolutionBackwardData>()) {
  749. m_opr2deconvlayer[opr_var->owner_opr()]
  750. ->setBiasWeights(wt_bias);
  751. }
  752. varnode2itensor[opr->output(0)] =
  753. varnode2itensor[result[2]];
  754. break;
  755. }
  756. on_elemwise_arity_binary(
  757. nvinfer1::ElementWiseOperation::kSUM);
  758. break;
  759. }
  760. case Mode::MUL:
  761. on_elemwise_arity_binary(
  762. nvinfer1::ElementWiseOperation::kPROD);
  763. break;
  764. case Mode::MIN:
  765. on_elemwise_arity_binary(
  766. nvinfer1::ElementWiseOperation::kMIN);
  767. break;
  768. case Mode::MAX:
  769. on_elemwise_arity_binary(
  770. nvinfer1::ElementWiseOperation::kMAX);
  771. break;
  772. case Mode::SUB:
  773. on_elemwise_arity_binary(
  774. nvinfer1::ElementWiseOperation::kSUB);
  775. break;
  776. case Mode::TRUE_DIV:
  777. on_elemwise_arity_binary(
  778. nvinfer1::ElementWiseOperation::kDIV);
  779. break;
  780. case Mode::POW:
  781. on_elemwise_arity_binary(
  782. nvinfer1::ElementWiseOperation::kPOW);
  783. break;
  784. case Mode::FUSE_ADD_RELU: {
  785. on_elemwise_arity_binary(
  786. nvinfer1::ElementWiseOperation::kSUM);
  787. if (is_quantized_int8_operator(opr))
  788. set_itensor_dynamic_range(opr->output(0), opr);
  789. auto act =
  790. net->addActivation(*varnode2itensor[opr->output(0)],
  791. nvinfer1::ActivationType::kRELU);
  792. mgb_assert(act, "construct network failed");
  793. std::string layer_name = "TRT_ACTV:" + opr->name();
  794. act->setName(layer_name.c_str());
  795. std::string output_name =
  796. "TRT_O:" + opr->output()[0]->name();
  797. act->getOutput(0)->setName(output_name.c_str());
  798. varnode2itensor[opr->output(0)] = act->getOutput(0);
  799. break;
  800. }
  801. case Mode::FUSE_ADD_SIGMOID: {
  802. on_elemwise_arity_binary(
  803. nvinfer1::ElementWiseOperation::kSUM);
  804. if (is_quantized_int8_operator(opr))
  805. set_itensor_dynamic_range(opr->output(0), opr);
  806. auto act = net->addActivation(
  807. *varnode2itensor[opr->output(0)],
  808. nvinfer1::ActivationType::kSIGMOID);
  809. mgb_assert(act, "construct network failed");
  810. std::string layer_name = "TRT_ACTV:" + opr->name();
  811. act->setName(layer_name.c_str());
  812. std::string output_name =
  813. "TRT_O:" + opr->output()[0]->name();
  814. act->getOutput(0)->setName(output_name.c_str());
  815. varnode2itensor[opr->output(0)] = act->getOutput(0);
  816. break;
  817. }
  818. case Mode::FUSE_ADD_TANH: {
  819. on_elemwise_arity_binary(
  820. nvinfer1::ElementWiseOperation::kSUM);
  821. if (is_quantized_int8_operator(opr))
  822. set_itensor_dynamic_range(opr->output(0), opr);
  823. auto act =
  824. net->addActivation(*varnode2itensor[opr->output(0)],
  825. nvinfer1::ActivationType::kTANH);
  826. mgb_assert(act, "construct network failed");
  827. std::string layer_name = "TRT_ACTV:" + opr->name();
  828. act->setName(layer_name.c_str());
  829. std::string output_name =
  830. "TRT_O:" + opr->output()[0]->name();
  831. act->getOutput(0)->setName(output_name.c_str());
  832. varnode2itensor[opr->output(0)] = act->getOutput(0);
  833. break;
  834. }
  835. default:
  836. mgb_assert(false, "Unsupported elemwise mode.");
  837. }
  838. if (is_quantized_int8_operator(opr))
  839. set_itensor_dynamic_range(opr->output(0), opr);
  840. };
  841. m_opr_trait[opr::ElemwiseMultiType::typeinfo()]
  842. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  843. OperatorNodeBase* opr) {
  844. auto&& varnode2itensor =
  845. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  846. size_t ndim0 = opr->input(0)->shape().ndim,
  847. ndim1 = opr->input(1)->shape().ndim;
  848. mgb_assert(ndim0 == ndim1);
  849. size_t tensor_ndim = ndim0;
  850. using Mode = opr::ElemwiseMultiType::Mode;
  851. SmallVector<nvinfer1::DimensionType> dimtypes(tensor_ndim);
  852. for (size_t i = 0; i < tensor_ndim; i++) {
  853. dimtypes[i] = nvinfer1::DimensionType::kSPATIAL;
  854. }
  855. auto mode =
  856. opr->cast_final_safe<opr::ElemwiseMultiType>().param().mode;
  857. mgb_assert(mode == Mode::QADD || mode == Mode::QFUSE_ADD_RELU,
  858. "Only QADD and QFUSE_ADD_RELU are supported on CUDA.");
  859. mgb_assert(
  860. opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8,
  861. "output data type %s is not supported",
  862. opr->output(0)->dtype().name());
  863. check_input(opr->input(0), opr, dimtypes);
  864. check_input(opr->input(1), opr, dimtypes);
  865. auto dims0 = varnode2itensor[opr->input(0)]->getDimensions(),
  866. dims1 = varnode2itensor[opr->input(1)]->getDimensions();
  867. mgb_throw_if(dims0.nbDims != dims1.nbDims, AssertionError,
  868. "Input dimensions of two input tensors must be "
  869. "equal (got: %d, %d).",
  870. dims0.nbDims, dims1.nbDims);
  871. auto elem =
  872. net->addElementWise(*varnode2itensor[opr->input(0)],
  873. *varnode2itensor[opr->input(1)],
  874. nvinfer1::ElementWiseOperation::kSUM);
  875. mgb_assert(elem, "construct network failed");
  876. std::string layer_name = "TRT_ELEM:" + opr->name();
  877. elem->setName(layer_name.c_str());
  878. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  879. elem->getOutput(0)->setName(output_name.c_str());
  880. varnode2itensor[opr->output(0)] = elem->getOutput(0);
  881. set_itensor_dynamic_range(opr->output(0), opr);
  882. if (mode == Mode::QFUSE_ADD_RELU) {
  883. auto act =
  884. net->addActivation(*varnode2itensor[opr->output(0)],
  885. nvinfer1::ActivationType::kRELU);
  886. mgb_assert(act, "construct network failed");
  887. std::string layer_name = "TRT_ACTV:" + opr->name();
  888. act->setName(layer_name.c_str());
  889. std::string output_name = "TRT_O:" + opr->output()[0]->name() + "_act";
  890. act->getOutput(0)->setName(output_name.c_str());
  891. varnode2itensor[opr->output(0)] = act->getOutput(0);
  892. set_itensor_dynamic_range(opr->output(0), opr);
  893. }
  894. };
  895. auto replace_matmul_opr = [this](nvinfer1::INetworkDefinition* net,
  896. OperatorNodeBase* opr) {
  897. auto&& varnode2itensor =
  898. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  899. SmallVector<nvinfer1::DimensionType> dimtypes;
  900. bool transposeA = false, transposeB = false;
  901. if (opr->same_type<opr::MatrixMul>()) {
  902. dimtypes = {nvinfer1::DimensionType::kSPATIAL,
  903. nvinfer1::DimensionType::kSPATIAL};
  904. transposeA = opr->cast_final_safe<opr::MatrixMul>()
  905. .param()
  906. .transposeA;
  907. transposeB = opr->cast_final_safe<opr::MatrixMul>()
  908. .param()
  909. .transposeB;
  910. } else {
  911. mgb_assert(opr->same_type<opr::BatchedMatrixMul>());
  912. dimtypes = {nvinfer1::DimensionType::kINDEX,
  913. nvinfer1::DimensionType::kSPATIAL,
  914. nvinfer1::DimensionType::kSPATIAL};
  915. transposeA = opr->cast_final_safe<opr::BatchedMatrixMul>()
  916. .param()
  917. .transposeA;
  918. transposeB = opr->cast_final_safe<opr::BatchedMatrixMul>()
  919. .param()
  920. .transposeB;
  921. }
  922. check_input(opr->input(0), opr, dimtypes);
  923. check_input(opr->input(1), opr, dimtypes);
  924. #if NV_TENSOR_RT_VERSION >= 6001
  925. nvinfer1::MatrixOperation
  926. opA = transposeA ? nvinfer1::MatrixOperation::kTRANSPOSE
  927. : nvinfer1::MatrixOperation::kNONE,
  928. opB = transposeB ? nvinfer1::MatrixOperation::kTRANSPOSE
  929. : nvinfer1::MatrixOperation::kNONE;
  930. auto matmul = net->addMatrixMultiply(
  931. *varnode2itensor[opr->input(0)], opA,
  932. *varnode2itensor[opr->input(1)], opB);
  933. #else
  934. auto matmul = net->addMatrixMultiply(
  935. *varnode2itensor[opr->input(0)], transposeA,
  936. *varnode2itensor[opr->input(1)], transposeB);
  937. #endif
  938. std::string layer_name = "TRT_MATMUL:" + opr->name();
  939. matmul->setName(layer_name.c_str());
  940. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  941. matmul->getOutput(0)->setName(output_name.c_str());
  942. varnode2itensor[opr->output(0)] = matmul->getOutput(0);
  943. };
  944. // megdnn matrix mul operator on cuda backend does not support quantized
  945. // data type
  946. m_opr_trait[opr::MatrixMul::typeinfo()].add_to_nvinfer = replace_matmul_opr;
  947. m_opr_trait[opr::BatchedMatrixMul::typeinfo()].add_to_nvinfer = replace_matmul_opr;
  948. // powc only support float32
  949. m_opr_trait[opr::PowC::typeinfo()]
  950. .add_to_nvinfer = [this](nvinfer1::INetworkDefinition* net,
  951. OperatorNodeBase* opr) {
  952. auto&& varnode2itensor =
  953. m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
  954. size_t tensor_ndim = opr->input(0)->shape().ndim;
  955. SmallVector<nvinfer1::DimensionType> dimtypes(tensor_ndim);
  956. for (size_t i = 0; i < tensor_ndim; i++) {
  957. dimtypes[i] = nvinfer1::DimensionType::kSPATIAL;
  958. }
  959. check_input(opr->input(0), opr, dimtypes);
  960. auto host_one = HostTensorND(opr->output(0)->comp_node(), {1},
  961. dtype::Float32()),
  962. host_zero = HostTensorND(opr->output(0)->comp_node(), {1},
  963. dtype::Float32()),
  964. host_exp = HostTensorND(opr->output(0)->comp_node(), {1},
  965. dtype::Float32());
  966. *(reinterpret_cast<float*>(host_one.raw_ptr())) = 1;
  967. *(reinterpret_cast<float*>(host_zero.raw_ptr())) = 0;
  968. *(reinterpret_cast<float*>(host_exp.raw_ptr())) =
  969. opr->cast_final_safe<opr::PowC>().param().exp;
  970. auto ptr = opr->owner_graph()
  971. ->options()
  972. .user_data
  973. .get_user_data_or_create<HostTensorKeeper>();
  974. ptr->htr.push_back(host_one);
  975. ptr->htr.push_back(host_zero);
  976. ptr->htr.push_back(host_exp);
  977. auto scale =
  978. net->addScale(*varnode2itensor[opr->input(0)],
  979. nvinfer1::ScaleMode::kUNIFORM,
  980. nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
  981. host_zero.raw_ptr(), 1},
  982. nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
  983. host_one.raw_ptr(), 1},
  984. nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
  985. host_exp.raw_ptr(), 1});
  986. std::string layer_name = "TRT_SCALE:" + opr->name();
  987. scale->setName(layer_name.c_str());
  988. std::string output_name = "TRT_O:" + opr->output()[0]->name();
  989. scale->getOutput(0)->setName(output_name.c_str());
  990. varnode2itensor[opr->output(0)] = scale->getOutput(0);
  991. };
  992. m_opr_num = 0;
  993. m_opr_fail_num = 0;
  994. detect_replace();
  995. mark_varnode_format_nchw4();
  996. update_graph();
  997. if (!m_opr_fail.empty()) {
  998. std::string msg{"TRT replace summary:\n"};
  999. msg += ssprintf(" number of oprs: %zu\n", m_opr_num);
  1000. msg += ssprintf(" number of unsupported oprs: %zu\n",
  1001. m_opr_fail_num);
  1002. msg += ssprintf(" first %zu unsupported oprs:\n",
  1003. m_opr_fail.size());
  1004. for (size_t i = 0; i < m_opr_fail.size(); ++i) {
  1005. msg += ssprintf(" %s {%s}: %s\n", m_opr_fail[i].opr->cname(),
  1006. m_opr_fail[i].opr->dyn_typeinfo()->name,
  1007. m_opr_fail[i].fail_msg.c_str());
  1008. }
  1009. msg.pop_back();
  1010. mgb_log("%s", msg.c_str());
  1011. }
  1012. }
  1013. };
  1014. MGB_TYPEINFO_OBJ_IMPL(TensorRTReplacePass::Impl::HostTensorKeeper);
  1015. Maybe<std::string> TensorRTReplacePass::Impl::has_fail_msg(
  1016. OperatorNodeBase* opr) {
  1017. auto iter = m_opr_trait.find(opr->dyn_typeinfo());
  1018. if (iter != m_opr_trait.end()) {
  1019. if (iter->second.get_replace_fail_msg) {
  1020. return iter->second.get_replace_fail_msg(opr);
  1021. }
  1022. return None;
  1023. }
  1024. return "Opr not supported.";
  1025. }
  1026. VarNodeArray TensorRTReplacePass::Impl::find_parent_conv(
  1027. OperatorNodeBase* inp_opr) {
  1028. OperatorNodeBase* owner_opr;
  1029. VarNodeArray vars_to_check, new_vars, rst;
  1030. bool conv_output_found = false;
  1031. VarNode* conv_output_var = nullptr;
  1032. VarNode* bias_var = nullptr;
  1033. VarNode* new_output_var = nullptr;
  1034. if (m_const_var_propogate->is_const(inp_opr->input(0))) {
  1035. vars_to_check.push_back(inp_opr->input(1));
  1036. new_output_var = inp_opr->input(1);
  1037. bias_var = inp_opr->input(0);
  1038. } else if (m_const_var_propogate->is_const(inp_opr->input(1))) {
  1039. vars_to_check.push_back(inp_opr->input(0));
  1040. new_output_var = inp_opr->input(0);
  1041. bias_var = inp_opr->input(1);
  1042. } else {
  1043. // No const input. return empty rst.
  1044. return rst;
  1045. }
  1046. while (vars_to_check.size() != 0) {
  1047. for (size_t i = 0; i < vars_to_check.size(); ++i) {
  1048. owner_opr = vars_to_check[i]->owner_opr();
  1049. if (owner_opr->same_type<opr::Convolution>() ||
  1050. owner_opr->same_type<opr::ConvolutionBackwardData>()) {
  1051. conv_output_found = true;
  1052. conv_output_var = vars_to_check[i];
  1053. break;
  1054. }
  1055. if (owner_opr->same_type<opr::Elemwise>() &&
  1056. owner_opr->cast_final<opr::Elemwise>().param().mode ==
  1057. opr::Elemwise::Mode::ADD) {
  1058. for (auto var2chk : owner_opr->input()) {
  1059. new_vars.push_back(var2chk);
  1060. }
  1061. }
  1062. }
  1063. vars_to_check.clear();
  1064. if (conv_output_found)
  1065. break;
  1066. if (new_vars.size() != 0) {
  1067. vars_to_check.insert(vars_to_check.end(), new_vars.begin(),
  1068. new_vars.end());
  1069. new_vars.clear();
  1070. }
  1071. }
  1072. if (conv_output_found) {
  1073. conv_output_found &= m_graph_map[inp_opr] ==
  1074. m_graph_map[conv_output_var->owner_opr()];
  1075. auto&& trt_graph = m_tensorrt_graphs[m_graph_map[inp_opr] - 1];
  1076. conv_output_found &= trt_graph->outputs.count(conv_output_var) == 0;
  1077. }
  1078. if (conv_output_found) {
  1079. rst.push_back(conv_output_var);
  1080. rst.push_back(bias_var);
  1081. rst.push_back(new_output_var);
  1082. }
  1083. return rst;
  1084. }
  1085. bool TensorRTReplacePass::Impl::check_input(
  1086. VarNode* var, OperatorNodeBase* opr,
  1087. SmallVector<nvinfer1::DimensionType> dimtypes) {
  1088. auto trt_graph = m_tensorrt_graphs[m_graph_map[opr] - 1];
  1089. auto&& varnode2itensor = trt_graph->varnode2itensor;
  1090. auto iter = trt_graph->inputs.find(var);
    if (iter == trt_graph->inputs.end())  // not an input of trt graph
        return false;
    for (auto i : trt_graph->trt_inputs)
        if (i == var)  // already added to input
            return false;
    trt_graph->trt_inputs.push_back(var);
    nvinfer1::ITensor* itensor;
    MGB_MARK_USED_VAR(mgb_dtype_to_trt_dtype);
    if (dimtypes.size() == 0) {
#if NV_TENSOR_RT_VERSION >= 6001
        mgb_assert(var->shape().ndim == 4 ||
                   (var->shape().ndim == 5 && var->shape()[4] == 4));
        nvinfer1::Dims4 dims{static_cast<int>(var->shape()[0]),
                             static_cast<int>(var->shape()[1]),
                             static_cast<int>(var->shape()[2]),
                             static_cast<int>(var->shape()[3])};
        if (var->shape().ndim == 5) {
            mgb_assert(var->shape()[4] == 4);
            dims.d[1] *= 4;
        }
        itensor = trt_graph->network->addInput(
                var->cname(), mgb_dtype_to_trt_dtype(var->dtype()), dims);
        if (trt_graph->mark_input_varnode_nchw4.count(var)) {
            itensor->setAllowedFormats(
                    1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4));
        } else {
            itensor->setAllowedFormats(
                    1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
        }
#else
        if (var->shape().ndim == 4) {
            // the default input tensor is a NCHW tensor
            mgb_assert(var->shape().ndim == 4,
                       "Default input tensor should be NCHW or NCHW4 format.");
            itensor = trt_graph->network->addInput(
                    var->cname(), nvinfer1::DataType::kFLOAT,
                    nvinfer1::DimsNCHW{static_cast<int>(var->shape()[0]),
                                       static_cast<int>(var->shape()[1]),
                                       static_cast<int>(var->shape()[2]),
                                       static_cast<int>(var->shape()[3])});
        } else {
            mgb_assert(var->shape().ndim == 5 && var->shape()[4] == 4,
                       "Input tensor format is not NCHW4 (got %s)",
                       var->shape().to_string().c_str());
            itensor = trt_graph->network->addInput(
                    var->cname(), nvinfer1::DataType::kFLOAT,
                    nvinfer1::DimsNCHW{static_cast<int>(var->shape()[0]),
                                       static_cast<int>(var->shape()[1] * 4),
                                       static_cast<int>(var->shape()[2]),
                                       static_cast<int>(var->shape()[3])});
        }
#endif
    } else {
        nvinfer1::Dims dims;
        // process var node that is marked as nchw4 format
        if (trt_graph->mark_input_varnode_nchw4.count(var)) {
            mgb_assert(var->shape().ndim == 5 && var->shape()[4] == 4,
                       "Input tensor format is not NCHW4 (got %s)",
                       var->shape().to_string().c_str());
            dims.nbDims = var->shape().ndim - 1;
            for (size_t i = 0; i < var->shape().ndim - 1; i++) {
                dims.d[i] = var->shape()[i];
#if NV_TENSOR_RT_VERSION < 6001
                dims.type[i] = dimtypes[i];
#endif
            }
            dims.d[1] *= 4;
            // process conventional var node
        } else {
            mgb_assert(var->shape().ndim == dimtypes.size());
            mgb_assert(var->shape().ndim <= nvinfer1::Dims::MAX_DIMS);
            dims.nbDims = var->shape().ndim;
            for (size_t i = 0; i < var->shape().ndim; i++) {
                dims.d[i] = var->shape()[i];
#if NV_TENSOR_RT_VERSION < 6001
                dims.type[i] = dimtypes[i];
#endif
            }
        }
#if NV_TENSOR_RT_VERSION >= 6001
        itensor = trt_graph->network->addInput(
                var->cname(), mgb_dtype_to_trt_dtype(var->dtype()), dims);
        if (trt_graph->mark_input_varnode_nchw4.count(var)) {
            itensor->setAllowedFormats(
                    1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4));
        } else {
            itensor->setAllowedFormats(
                    1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
        }
#else
        itensor = trt_graph->network->addInput(
                var->cname(), nvinfer1::DataType::kFLOAT, dims);
#endif
    }
    varnode2itensor[var] = itensor;
    if (trt_graph->feature_bits == TensorRTGraphFeatureBits::NCHW4_QINT8)
        set_itensor_dynamic_range(var, opr);
    return true;
}
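// Set the TensorRT dynamic range of the ITensor associated with `var` from its
// QuantizedS8 scale; this is only compiled in for TensorRT >= 5.0.2.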
void TensorRTReplacePass::Impl::set_itensor_dynamic_range(
        VarNode* var, OperatorNodeBase* opr) {
    MGB_MARK_USED_VAR(var);
    MGB_MARK_USED_VAR(opr);
#if NV_TENSOR_RT_VERSION >= 5020
    auto&& varnode2itensor =
            m_tensorrt_graphs[m_graph_map[opr] - 1]->varnode2itensor;
    auto&& tensor = varnode2itensor[var];
    auto&& data_type = var->dtype();
    mgb_assert(data_type.enumv() == DTypeEnum::QuantizedS8);
    float scale = get_scale(data_type);
    tensor->setDynamicRange(-i8_max * scale, i8_max * scale);
#endif
}
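// Statically evaluate `var` on host as a float32 tensor. For NCHW4-style
// values (ndim 5 or 6 with the innermost axis of size 4) the x4 channel axis
// is folded back into the channel dimension and quantized data is converted
// to float32 before evaluation; the result is kept alive via HostTensorKeeper.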
HostTensorND TensorRTReplacePass::Impl::get_value(
        VarNode* var, ConvFormat format) {
    auto cg = m_opt_state.graph().comp_graph();
    auto inferred_val = HostTensorND(var->comp_node(), dtype::Float32());
    auto cb = [&](DeviceTensorND& val) { inferred_val.copy_from(val); };
    if (format == ConvFormat::NCHW) {
        mgb_assert(var->dtype() == dtype::Float32());
        auto orig_level = cg->options().log_level;
        cg->options().log_level = 0;
        MGB_TRY { cg->compile({{var, cb}})->execute(); }
        MGB_FINALLY(cg->options().log_level = orig_level);
    } else {
        mgb_assert(format == ConvFormat::NCHW4);
        if (var->shape().ndim == 5) {
            // assume nchw4 layout
            mgb_assert(var->shape()[4] == 4);
            auto x = SymbolVar(var);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            if (var->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                var->dtype().enumv() == DTypeEnum::QuantizedS32) {
                y1 = opr::TypeCvt::make(y1, dtype::Float32());
            }
            auto orig_level = cg->options().log_level;
            cg->options().log_level = 0;
            cg->options().graph_opt.tensorrt = false;
            MGB_TRY { cg->compile({{y1.node(), cb}})->execute(); }
            MGB_FINALLY({
                cg->options().log_level = orig_level;
                cg->options().graph_opt.tensorrt = true;
            });
        } else if (var->shape().ndim == 6) {
            // assume nchw4 layout
            mgb_assert(var->shape()[5] == 4);
            mgb_assert(var->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                       var->dtype() == dtype::Float32());
            auto x = SymbolVar(var);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1), sub(2) * 4, sub(3), sub(4)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 2, 5, 3, 4});
            auto y1 = opr::Reshape::make(y0, tshp);
            if (var->dtype().enumv() == DTypeEnum::QuantizedS8) {
                y1 = opr::TypeCvt::make(y1, dtype::Float32());
            }
            auto orig_level = cg->options().log_level;
            cg->options().log_level = 0;
            cg->options().graph_opt.tensorrt = false;
            MGB_TRY { cg->compile({{y1.node(), cb}})->execute(); }
            MGB_FINALLY({
                cg->options().log_level = orig_level;
                cg->options().graph_opt.tensorrt = true;
            });
        }
    }
    auto ptr = var->owner_graph()
                       ->options()
                       .user_data.get_user_data_or_create<HostTensorKeeper>();
    ptr->htr.push_back(inferred_val);
    return inferred_val;
}
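// Extract the quantization scale from a quantized DType; throws InternalError
// for non-quantized types.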
float TensorRTReplacePass::Impl::get_scale(DType data_type) {
    float scale = 1.f;
#define cb(_dt)                               \
    case DTypeTrait<_dt>::enumv:              \
        scale = data_type.param<_dt>().scale; \
        break;
    switch (data_type.enumv()) {
        MEGDNN_FOREACH_QUANTIZED_DTYPE(cb);
        default:
            mgb_throw(InternalError, "invalid quantized data type: %s",
                      data_type.name());
    }
    return scale;
#undef cb
}
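// An operator is treated as quantized int8 if all of its inputs and its first
// output are QuantizedS8; for ConvBias only input(0) and output(0) are checked.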
bool TensorRTReplacePass::Impl::is_quantized_int8_operator(
        OperatorNodeBase* opr) {
    bool is_quantized = true;
    if (opr->same_type<opr::ConvBias>()) {
        is_quantized = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        mgb_assert(!is_quantized ||
                   opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8);
        return is_quantized;
    }
    for (auto&& inp : opr->input()) {
        if (inp->dtype().enumv() != DTypeEnum::QuantizedS8) {
            is_quantized = false;
            break;
        }
    }
    // assume every operator has only one output
    auto&& out = opr->output(0);
    if (out->dtype().enumv() != DTypeEnum::QuantizedS8) {
        is_quantized = false;
    }
    return is_quantized;
}
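// Partition the computing graph into TensorRT subgraphs: each replaceable
// operator is assigned (via m_graph_map) to a subgraph whose feature bits
// (NCHW4_QINT8 vs NCHW_FLOAT) match, creating a new TensorRTGraph when needed,
// and subgraph input/output var nodes as well as failure statistics are
// recorded along the way.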
void TensorRTReplacePass::Impl::detect_replace() {
    auto cb = [this](OperatorNodeBase* opr) {
        m_const_var_propogate->add_opr(opr);
    };
    m_opt_state.graph().iter(cb);
    auto on_opr = [this](OperatorNodeBase* opr) {
        ++m_opr_num;
        Maybe<std::string> irreplaceable_msg = has_fail_msg(opr);
        TensorRTGraphFeatureBits feature_bits =
                is_quantized_int8_operator(opr)
                        ? TensorRTGraphFeatureBits::NCHW4_QINT8
                        : TensorRTGraphFeatureBits::NCHW_FLOAT;
        if (!irreplaceable_msg.valid()) {
            size_t max = 1;
            for (auto i : opr->input()) {
                if (!has_fail_msg(i->owner_opr()).valid())
                    update_max(max, m_graph_map[i->owner_opr()]);
                else
                    update_max(max, m_graph_map[i->owner_opr()] + 1);
            }
            size_t max_update = max;
            for (; max_update <= m_tensorrt_graphs.size(); max_update++) {
                TensorRTGraphFeatureBits trt_graph_feature_bits =
                        m_tensorrt_graphs[max_update - 1]->feature_bits;
                if (trt_graph_feature_bits == feature_bits)
                    break;
            }
            max = max_update;
            m_graph_map[opr] = max;
            if (max > m_tensorrt_graphs.size()) {
                opr->output(0)->comp_node().activate();
                m_tensorrt_graphs.push_back(
                        std::make_shared<TensorRTGraph>(feature_bits));
            }
            for (auto i : opr->input()) {
                if (m_graph_map[i->owner_opr()] != max) {
                    m_tensorrt_graphs[max - 1]->inputs.insert(i);
                    if (!has_fail_msg(i->owner_opr()).valid()) {
                        //! TODO: check
                        m_tensorrt_graphs[m_graph_map[i->owner_opr()] - 1]
                                ->outputs.insert(i);
                    }
                }
            }
        } else {
            static const ThinHashSet<Typeinfo*> ignore_types{
                    opr::SharedDeviceTensor::typeinfo(),
                    opr::ImmutableTensor::typeinfo(),
                    opr::Host2DeviceCopy::typeinfo(),
                    opr::MultipleDeviceTensorHolder::typeinfo()};
            if (!ignore_types.count(opr->dyn_typeinfo())) {
                ++m_opr_fail_num;
                if (m_opr_fail.size() < OPR_FAIL_LOG_NUM) {
                    FailInfo fail_info;
                    fail_info.opr = opr;
                    fail_info.fail_msg = irreplaceable_msg.val();
                    m_opr_fail.push_back(fail_info);
                }
            }
            size_t max = 0;
            for (auto i : opr->input()) {
                if (m_graph_map[i->owner_opr()] > max)
                    max = m_graph_map[i->owner_opr()];
                if (!has_fail_msg(i->owner_opr()).valid()) {
                    //! TODO: check
                    m_tensorrt_graphs[m_graph_map[i->owner_opr()] - 1]
                            ->outputs.insert(i);
                }
            }
            m_graph_map[opr] = max;
        }
    };
    m_opt_state.graph().iter(on_opr);
    for (auto i : m_opt_state.graph().endpoint_vars()) {
        auto var_node = i.node();
        if (!has_fail_msg(var_node->owner_opr()).valid()) {
            //! TODO: check
            m_tensorrt_graphs[m_graph_map[var_node->owner_opr()] - 1]
                    ->outputs.insert(var_node);
        }
    }
}
void TensorRTReplacePass::Impl::mark_varnode_format_nchw4() {
    for (auto trt_graph : m_tensorrt_graphs) {
        trt_graph->mark_varnode_format_nchw4();
    }
}
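// Build a TensorRT network for each detected subgraph, mark its outputs and
// replace the covered MegBrain operators with TensorRTOpr instances. For
// TensorRT < 6.0.1, extra Reshape/Dimshuffle/TypeCvt operators are inserted to
// convert between quantized NCHW4 tensors and the float NCHW tensors expected
// by the TensorRT network.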
void TensorRTReplacePass::Impl::update_graph() {
    using GpuAllocator = opr::TensorRTOpr::GpuAllocator;
    using TensorRTOpr = opr::TensorRTOpr;
    std::shared_ptr<GpuAllocator> gpu_allocator;
    std::shared_ptr<ExtraDep> extra_dep = std::make_shared<ExtraDep>();
    // construct trt network
    auto construct_network = [this, &gpu_allocator,
                              &extra_dep](OperatorNodeBase* opr) {
        if (!has_fail_msg(opr).valid()) {
            auto cn = opr->output(0)->comp_node();
            auto trt_graph = m_tensorrt_graphs[m_graph_map[opr] - 1];
            auto b = trt_graph->builder;
            mgb_assert(b != nullptr);
            if (!gpu_allocator) {
                gpu_allocator = std::make_shared<GpuAllocator>(cn);
                b->setGpuAllocator(gpu_allocator.get());
            } else {
                auto cn0 = gpu_allocator->comp_node();
                mgb_assert(cn0 == cn,
                           "multiple comp nodes for trt graph are not "
                           "supported: %s %s",
                           cn0.to_string().c_str(), cn.to_string().c_str());
            }
            if (!trt_graph->network) {
#if NV_TENSOR_RT_VERSION >= 6001
                nvinfer1::NetworkDefinitionCreationFlags flags;
                flags = 1 << static_cast<int>(
                                nvinfer1::NetworkDefinitionCreationFlag::
                                        kEXPLICIT_BATCH);
                trt_graph->network = b->createNetworkV2(flags);
#else
                trt_graph->network = b->createNetwork();
#endif
            }
            // make extra dep
            for (auto&& inp : trt_graph->inputs) {
                extra_dep->operator[](opr).push_back(inp);
            }
            auto iter = m_opr_trait.find(opr->dyn_typeinfo());
            if (iter != m_opr_trait.end()) {
                if (iter->second.add_to_nvinfer) {
                    iter->second.add_to_nvinfer(trt_graph->network, opr);
                }
            }
        }
    };
    m_opt_state.graph().iter(construct_network);
    // trt network markOutput
    for (auto trt_graph : m_tensorrt_graphs) {
        // record traverse order
        size_t idx = 0;
        auto&& varnode2itensor = trt_graph->varnode2itensor;
        for (auto output : trt_graph->outputs) {
            trt_graph->output2idx[output] = idx++;
            trt_graph->network->markOutput(*varnode2itensor[output]);
#if NV_TENSOR_RT_VERSION >= 6001
            if (output->dtype().enumv() == DTypeEnum::QuantizedS8) {
                varnode2itensor[output]->setType(nvinfer1::DataType::kINT8);
            }
            if (trt_graph->mark_output_varnode_nchw4.count(output)) {
                mgb_assert(output->dtype().enumv() == DTypeEnum::QuantizedS8);
                varnode2itensor[output]->setAllowedFormats(
                        1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4));
            } else {
                varnode2itensor[output]->setAllowedFormats(
                        1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
            }
#endif
        }
    }
    ThinHashSet<OperatorNodeBase*> visited;
    // replace opr by trt
    auto update_opr = [this, &gpu_allocator,
                       &visited](OperatorNodeBase* opr) {
        if (!has_fail_msg(opr).valid()) {
            mgb_assert(gpu_allocator);
            auto trt_graph = m_tensorrt_graphs[m_graph_map[opr] - 1];
            for (auto&& inp : trt_graph->trt_inputs) {
                mgb_assert(visited.count(inp->owner_opr()));
            }
            if (trt_graph->trt_outputs.empty()) {
                // use updated varnodes instead of the old ones
                auto inps = trt_graph->trt_inputs;
                VarNodeArray new_inps{inps.size()};
                for (size_t i = 0; i < inps.size(); i++) {
                    new_inps[i] = m_rewriter.get_var(inps[i]);
#if NV_TENSOR_RT_VERSION < 6001
                    if (trt_graph->mark_input_varnode_nchw4.count(inps[i])) {
                        auto x = SymbolVar(new_inps[i]);
                        auto xshp = opr::GetVarShape::make(x);
                        auto cv = [&x](int v) { return x.make_scalar(v); };
                        auto sub = [&xshp, &cv](int idx) {
                            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
                        };
                        auto tshp = opr::Concat::make(
                                {sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
                        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
                        auto y1 = opr::Reshape::make(y0, tshp);
                        new_inps[i] = y1.node();
                    }
                    if (inps[i]->dtype().enumv() == DTypeEnum::QuantizedS8) {
                        new_inps[i] = opr::TypeCvt::make(new_inps[i],
                                                         dtype::Float32())
                                              .node();
                    }
#endif
                }
                // now trt_graph does not own the unique_ptr of the infer builder
                m_opt_state.call_with_opr(opr, [&] {
                    trt_graph->trt_outputs =
                            cg::to_var_node_array(TensorRTOpr::make(
                                    TensorRTOpr::to_shared_ptr_builder(
                                            trt_graph->builder),
                                    TensorRTOpr::to_shared_ptr_network(
                                            trt_graph->network),
                                    trt_graph->feature_bits, gpu_allocator,
                                    cg::to_symbol_var_array(new_inps)));
                });
                mgb_assert(trt_graph->trt_outputs.size() ==
                                   trt_graph->outputs.size(),
                           "mgb outputs number != tensorrt outputs number");
            }
            for (auto&& output : opr->output()) {
                if (trt_graph->outputs.count(output)) {
                    size_t output_idx = trt_graph->output2idx[output];
                    VarNode* output_var = trt_graph->trt_outputs[output_idx];
#if NV_TENSOR_RT_VERSION < 6001
                    if (trt_graph->mark_output_varnode_nchw4.count(output)) {
                        auto x = SymbolVar(output_var);
                        auto xshp = opr::GetVarShape::make(x);
                        auto cv = [&x](int v) { return x.make_scalar(v); };
                        auto sub = [&xshp, &cv](int idx) {
                            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
                        };
                        auto tshp = opr::Concat::make(
                                {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
                        auto y0 = opr::Reshape::make(x, tshp);
                        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
                        output_var = y1.node();
                    }
                    if (output->dtype().enumv() == DTypeEnum::QuantizedS8) {
                        float scale = get_scale(output->dtype());
                        output_var =
                                opr::TypeCvt::make(output_var,
                                                   dtype::QuantizedS8{scale})
                                        .node();
                    }
#endif
                    m_rewriter.replace_var(
                            output, output_var,
                            mgb_ssprintf_log("replace opr: %s",
                                             output->owner_opr()->cname())
                                    .c_str());
                }
            }
            visited.insert(opr);
        } else {
            for (auto&& inp : opr->input()) {
                mgb_assert(visited.count(inp->owner_opr()));
            }
            visited.insert(opr);
            m_rewriter.auto_replace_outputs(opr);
        }
    };
    m_opt_state.graph().iter(update_opr, std::move(extra_dep));
    m_rewriter.apply_inplace();
}
const char* TensorRTReplacePass::name() const {
    return mgb_cstr_log("tensorrt_replace");
}
void TensorRTReplacePass::apply(OptState& opt) const {
    if (CompNode::get_device_count(CompNode::DeviceType::CUDA)) {
        opt.set_var_replace_check_flag(gopt::VarReplaceCheckFlag::CHECK_SHAPE |
                                       gopt::VarReplaceCheckFlag::CHECK_DTYPE);
        Impl(*this, opt);
    } else {
        mgb_log_debug("cuda is not available; TensorRTReplacePass is ignored");
    }
}
// ===================== TensorRTGraph =================
void TensorRTReplacePass::Impl::TensorRTGraph::mark_varnode_format_nchw4() {
    // Consider the TensorRT subgraph as a bi-directed graph and divide it into
    // multiple connected components; an input or output varnode of the
    // subgraph is marked as nchw4 format iff it belongs to a connected
    // component that contains at least one NCHW4 operator (e.g. ConvBias,
    // Pooling).
    // p (parent) array used for the disjoint set (union-find)
    ThinHashMap<OperatorNodeBase*, OperatorNodeBase*> p;
    ThinHashSet<OperatorNodeBase*> outsides;
    thin_function<OperatorNodeBase*(OperatorNodeBase*)> get_root;
    get_root = [&](OperatorNodeBase* opr) -> OperatorNodeBase* {
        mgb_assert(p.count(opr));
        return p[opr] == opr ? opr : p[opr] = get_root(p[opr]);
    };
    auto is_format_nchw4 = [&](OperatorNodeBase* opr) {
        if (outsides.count(opr)) {
            return false;
        }
        if (opr->same_type<opr::ConvBias>()) {
            auto&& param = opr->cast_final_safe<opr::ConvBias>().param();
            if (param.format == opr::ConvBias::Param::Format::NCHW4)
                return true;
        }
        if (opr->same_type<opr::Pooling>()) {
            auto&& param = opr->cast_final_safe<opr::Pooling>().param();
            if (param.format == opr::Pooling::Param::Format::NCHW4)
                return true;
        }
        return false;
    };
    auto cb = [&](OperatorNodeBase* opr) {
        mgb_assert(!p.count(opr));
        p[opr] = opr;
        for (auto&& inp : opr->input()) {
            auto root = get_root(inp->owner_opr());
            // ensure that if one of the oprs in a tree is nchw4,
            // the root of that tree must be nchw4
            if (is_format_nchw4(root)) {
                p[get_root(opr)] = root;
            } else {
                p[root] = get_root(opr);
            }
        }
    };
    DepOprIter iter{cb};
    for (auto&& inp : inputs) {
        p[inp->owner_opr()] = inp->owner_opr();
        iter.set_visited(inp->owner_opr());
        outsides.insert(inp->owner_opr());
    }
    for (auto&& out : outputs) {
        iter.add(out->owner_opr());
    }
    for (auto&& inp : inputs) {
        if (is_format_nchw4(get_root(inp->owner_opr()))) {
            mark_input_varnode_nchw4.insert(inp);
        }
    }
    for (auto&& out : outputs) {
        if (is_format_nchw4(get_root(out->owner_opr()))) {
            mark_output_varnode_nchw4.insert(out);
        }
    }
}
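// Public entry point: run the graph optimizations required for TensorRT
// replacement on `dest_vars` (optionally converting the graph to NCHW4 first),
// then clean up the helper shuffle/typecvt operators inserted by the pass for
// older TensorRT versions.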
void mgb::tensorrt::transform_dest_vars_inplace(
        mgb::cg::VarNodeArray& dest_vars,
        cg::GraphCommonOptimizeOptions& options) {
    gopt::GraphOptimizer optimizer;
    //! In MegEngine the default layout is NCHW, while the TensorRT pass
    //! currently only supports NCHW4 (int8), so we transform the layout to
    //! nchw4 first.
    if (options.has_set_nchw4()) {
        options.disable_nchw4();
        optimizer.add_pass<FuseConvBiasNonlinPass>();
        optimizer.add_pass(EnableNCHW4Pass::make_nchw4_converter());
    }
    optimizer.add_pass<ExpandFusedArithPass>();
    optimizer.add_pass<gopt::TensorRTReplacePass>();
    optimizer.add_pass<ArithFusePass>();
#if NV_TENSOR_RT_VERSION < 6001
    optimizer.add_pass<ShuffleShuffleRemovePass>();
    optimizer.add_pass<RemoveRedundantTypeCvtPass>();
#endif
    optimizer.apply_inplace(dest_vars);
}
#pragma GCC diagnostic pop
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
