You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

specializations.cpp 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. /**
  2. * \file imperative/src/impl/ops/specialzations.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. // FIXME: split this file into separate files for each specialized op
  13. #include "megbrain/imperative/ops/autogen.h"
  14. #include "megbrain/opr/basic_arith.h"
  15. #include "megbrain/opr/blas.h"
  16. #include "megbrain/opr/dnn/adaptive_pooling.h"
  17. #include "megbrain/opr/dnn/convolution.h"
  18. #include "megbrain/opr/dnn/correlation.h"
  19. #include "megbrain/opr/dnn/fake_quant.h"
  20. #include "megbrain/opr/dnn/images2neibs.h"
  21. #include "megbrain/opr/dnn/local.h"
  22. #include "megbrain/opr/dnn/lsq.h"
  23. #include "megbrain/opr/dnn/pooling.h"
  24. #include "megbrain/opr/dnn/roi_align.h"
  25. #include "megbrain/opr/dnn/roi_pooling.h"
  26. #include "megbrain/opr/dnn/tqt.h"
  27. #include "megbrain/opr/imgproc.h"
  28. #include "megbrain/opr/indexing.h"
  29. #include "megbrain/opr/io.h"
  30. #include "megbrain/opr/misc.h"
  31. #include "megbrain/opr/nn_int.h"
  32. #include "megbrain/opr/rand.h"
  33. #include "megbrain/opr/tensor_gen.h"
  34. #include "megbrain/opr/tensor_manip.h"
  35. #include "megbrain/opr/utility.h"
  36. #include "megbrain/opr/dnn/images2neibs.h"
  37. #include "megbrain/opr/dnn/sliding_window_transpose.h"
  38. #include "../op_trait.h"
  39. namespace mgb::imperative {
  40. namespace {
  41. namespace dimshuffle {
  42. std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
  43. auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
  44. std::vector<int> pattern(node->param().pattern_len);
  45. for (size_t i = 0; i < node->param().pattern_len; ++i) {
  46. pattern[i] = node->param().pattern[i];
  47. }
  48. return Dimshuffle::make(pattern);
  49. }
  50. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  51. auto&& ds = static_cast<const Dimshuffle&>(def);
  52. OperatorNodeConfig config{ds.make_name()};
  53. return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
  54. }
  55. OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
  56. .make_from_op_node(make_from_op_node)
  57. .apply_on_var_node(apply_on_var_node)
  58. .fallback();
  59. } // namespace dimshuffle
  60. } // namespace
  61. namespace {
  62. namespace add_axis {
  63. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  64. auto&& add_axis = static_cast<const AddAxis&>(def);
  65. using Desc = opr::AxisAddRemove::AxisDesc;
  66. std::vector<Desc> param;
  67. for (auto&& i : add_axis.axis) {
  68. param.push_back(Desc::make_add(i));
  69. }
  70. OperatorNodeConfig config{add_axis.make_name()};
  71. return opr::AxisAddRemove::make(inputs[0], param, config);
  72. }
  73. OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback();
  74. } // namespace add_axis
  75. } // namespace
  76. namespace {
  77. namespace remove_axis {
  78. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  79. auto&& remove_axis = static_cast<const RemoveAxis&>(def);
  80. using Desc = opr::AxisAddRemove::AxisDesc;
  81. std::vector<Desc> param;
  82. for (auto&& i : remove_axis.axis) {
  83. param.push_back(Desc::make_remove(i));
  84. }
  85. OperatorNodeConfig config{remove_axis.make_name()};
  86. return opr::AxisAddRemove::make(inputs[0], param, config);
  87. }
  88. OP_TRAIT_REG(RemoveAxis, RemoveAxis)
  89. .apply_on_var_node(apply_on_var_node)
  90. .fallback();
  91. } // namespace remove_axis
  92. } // namespace
  93. namespace {
  94. namespace top_k {
  95. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  96. auto&& topk = static_cast<const TopK&>(def);
  97. OperatorNodeConfig config{topk.make_name()};
  98. return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
  99. .node()
  100. ->owner_opr();
  101. }
  102. OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
  103. } // namespace top_k
  104. } // namespace
  105. namespace {
  106. namespace reduce {
  107. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  108. auto&& reduce = static_cast<const Reduce&>(def);
  109. OperatorNodeConfig config{reduce.make_name()};
  110. if (inputs.size() > 1) {
  111. return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
  112. } else {
  113. return opr::Reduce::make(inputs[0], reduce.param(),
  114. (cg::VarNode*)nullptr, config);
  115. }
  116. }
  117. std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
  118. auto* node = &node_->cast_final_safe<opr::Reduce>();
  119. return Reduce::make(node->param());
  120. }
  121. OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
  122. .make_from_op_node(make_from_op_node)
  123. .apply_on_var_node(apply_on_var_node)
  124. .fallback();
  125. } // namespace reduce
  126. } // namespace
  127. namespace {
  128. namespace adaptive_pooling {
  129. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  130. auto&& pool = static_cast<const AdaptivePooling&>(def);
  131. OperatorNodeConfig config{pool.make_name()};
  132. return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(),
  133. config);
  134. }
  135. OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
  136. .apply_on_var_node(apply_on_var_node)
  137. .fallback();
  138. } // namespace adaptive_pooling
  139. } // namespace
  140. namespace {
  141. namespace conv_bias {
  142. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  143. auto&& conv = static_cast<const ConvBias&>(def);
  144. cg::OperatorNodeConfig config{conv.dtype};
  145. config.name(conv.make_name());
  146. if (inputs.size() == 2) {
  147. return opr::ConvBias::make(inputs[0], inputs[1], conv.param(),
  148. conv.policy(), config);
  149. } else if (inputs.size() == 3) {
  150. return opr::ConvBias::make(inputs[0], inputs[1], inputs[2],
  151. conv.param(), conv.policy(), config);
  152. } else if (inputs.size() == 4) {
  153. return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3],
  154. conv.param(), conv.policy(), config);
  155. }
  156. mgb_assert(0);
  157. }
  158. OP_TRAIT_REG(ConvBias, ConvBias)
  159. .apply_on_var_node(apply_on_var_node)
  160. .fallback();
  161. } // namespace conv_bias
  162. } // namespace
  163. namespace {
  164. namespace batch_conv_bias {
  165. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  166. auto&& conv = static_cast<const BatchConvBias&>(def);
  167. cg::OperatorNodeConfig config{conv.dtype};
  168. config.name(conv.make_name());
  169. if (inputs.size() == 2) {
  170. return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(),
  171. conv.policy(), config);
  172. } else if (inputs.size() == 3) {
  173. return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
  174. conv.param(), conv.policy(), config);
  175. } else if (inputs.size() == 4) {
  176. return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
  177. inputs[3], conv.param(), conv.policy(),
  178. config);
  179. }
  180. mgb_assert(0);
  181. }
  182. OP_TRAIT_REG(BatchConvBias, BatchConvBias)
  183. .apply_on_var_node(apply_on_var_node)
  184. .fallback();
  185. } // namespace batch_conv_bias
  186. } // namespace
  187. namespace {
  188. namespace pooling {
  189. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  190. auto&& pool = static_cast<const Pooling&>(def);
  191. OperatorNodeConfig config{pool.make_name()};
  192. return opr::Pooling::make(inputs[0], pool.param(), config);
  193. }
  194. OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
  195. } // namespace pooling
  196. } // namespace
  197. namespace {
  198. namespace matrix_mul {
  199. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  200. auto&& matmul = static_cast<const MatrixMul&>(def);
  201. mgb_assert(inputs.size() == 2);
  202. OperatorNodeConfig config{matmul.make_name()};
  203. return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
  204. matmul.policy(), config);
  205. }
  206. OP_TRAIT_REG(MatrixMul, MatrixMul)
  207. .apply_on_var_node(apply_on_var_node)
  208. .fallback();
  209. } // namespace matrix_mul
  210. } // namespace
  211. namespace {
  212. namespace batched_matrix_mul {
  213. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  214. auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
  215. mgb_assert(inputs.size() == 2);
  216. OperatorNodeConfig config{matmul.make_name()};
  217. return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
  218. matmul.policy(), config);
  219. }
  220. OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
  221. .apply_on_var_node(apply_on_var_node)
  222. .fallback();
  223. } // namespace batched_matrix_mul
  224. } // namespace
  225. namespace {
  226. namespace dot {
  227. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  228. auto&& op = def.cast_final_safe<Dot>();
  229. mgb_assert(inputs.size() == 2);
  230. OperatorNodeConfig config{op.make_name()};
  231. return opr::Dot::make(inputs[0], inputs[1], config);
  232. }
  233. OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();
  234. } // namespace dot
  235. } // namespace
  236. namespace {
  237. namespace argsort {
  238. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  239. auto&& argsort = static_cast<const Argsort&>(def);
  240. OperatorNodeConfig config{argsort.make_name()};
  241. return opr::Argsort::make(inputs[0], argsort.param(), config);
  242. }
  243. OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
  244. } // namespace argsort
  245. } // namespace
  246. namespace {
  247. namespace argmax {
  248. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  249. auto&& argmax = static_cast<const Argmax&>(def);
  250. OperatorNodeConfig config{argmax.make_name()};
  251. return opr::Argmax::make(inputs[0], argmax.param(), config);
  252. }
  253. OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
  254. } // namespace argmax
  255. } // namespace
  256. namespace {
  257. namespace argmin {
  258. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  259. auto&& argmin = static_cast<const Argmin&>(def);
  260. OperatorNodeConfig config{argmin.make_name()};
  261. return opr::Argmin::make(inputs[0], argmin.param(), config);
  262. }
  263. OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
  264. } // namespace argmin
  265. } // namespace
  266. namespace {
  267. namespace warp_perspective {
  268. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  269. auto&& warp = static_cast<const WarpPerspective&>(def);
  270. OperatorNodeConfig config{warp.make_name()};
  271. if (inputs.size() == 3) {
  272. return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
  273. warp.param(), config);
  274. } else {
  275. mgb_assert(inputs.size() == 4);
  276. return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
  277. inputs[3], warp.param(), config);
  278. }
  279. }
  280. OP_TRAIT_REG(WarpPerspective, WarpPerspective)
  281. .apply_on_var_node(apply_on_var_node)
  282. .fallback();
  283. } // namespace warp_perspective
  284. } // namespace
  285. namespace {
  286. namespace group_local {
  287. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  288. auto&& local = static_cast<const GroupLocal&>(def);
  289. mgb_assert(inputs.size() == 2);
  290. OperatorNodeConfig config{local.make_name()};
  291. return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
  292. }
  293. OP_TRAIT_REG(GroupLocal, GroupLocal)
  294. .apply_on_var_node(apply_on_var_node)
  295. .fallback();
  296. } // namespace group_local
  297. } // namespace
  298. namespace {
  299. namespace indexing_one_hot {
  300. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  301. auto&& op = static_cast<const IndexingOneHot&>(def);
  302. mgb_assert(inputs.size() == 2);
  303. OperatorNodeConfig config{op.make_name()};
  304. return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
  305. }
  306. OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
  307. .apply_on_var_node(apply_on_var_node)
  308. .fallback();
  309. } // namespace indexing_one_hot
  310. } // namespace
  311. namespace {
  312. namespace indexing_set_one_hot {
  313. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  314. auto&& op = static_cast<const IndexingSetOneHot&>(def);
  315. mgb_assert(inputs.size() == 3);
  316. OperatorNodeConfig config{op.make_name()};
  317. return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2],
  318. op.param(), config);
  319. }
  320. OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
  321. .apply_on_var_node(apply_on_var_node)
  322. .fallback();
  323. } // namespace indexing_set_one_hot
  324. } // namespace
  325. namespace {
  326. namespace typecvt {
  327. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  328. auto&& op = static_cast<const TypeCvt&>(def);
  329. mgb_assert(inputs.size() == 1);
  330. OperatorNodeConfig config{op.make_name()};
  331. return opr::TypeCvt::make(inputs[0], op.dtype, config);
  332. }
  333. OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
  334. } // namespace typecvt
  335. } // namespace
  336. namespace {
  337. namespace concat {
  338. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  339. auto&& op = static_cast<const Concat&>(def);
  340. cg::OperatorNodeConfig config{op.comp_node};
  341. config.name(op.make_name());
  342. return opr::Concat::make(inputs, op.axis, config);
  343. }
  344. OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
  345. } // namespace concat
  346. } // namespace
  347. namespace {
  348. namespace copy {
  349. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  350. auto&& op = static_cast<const Copy&>(def);
  351. mgb_assert(inputs.size() == 1);
  352. cg::OperatorNodeConfig config{op.comp_node};
  353. config.name(op.make_name());
  354. return opr::Copy::make(inputs[0], config);
  355. }
  356. OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
  357. } // namespace copy
  358. } // namespace
  359. namespace { namespace assert_equal {
  360. auto apply_on_var_node(
  361. const OpDef& def,
  362. const VarNodeArray& inputs) {
  363. auto&& op = def.cast_final<AssertEqual>();
  364. if (inputs.size() == 2) {
  365. return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
  366. } else {
  367. // workaround for MiniGraph, which only allow one opr in the graph
  368. mgb_assert(inputs.size() == 3);
  369. return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2],
  370. op.param(), {});
  371. }
  372. }
  373. OP_TRAIT_REG(AssertEqual, AssertEqual)
  374. .apply_on_var_node(apply_on_var_node)
  375. .fallback();
  376. } // namespace assert_equal
  377. } // namespace
  378. namespace {
  379. namespace roi_align {
  380. VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  381. auto&& op = static_cast<const ROIAlign&>(def);
  382. mgb_assert(inputs.size() == 2);
  383. OperatorNodeConfig config{op.make_name()};
  384. auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
  385. .node()
  386. ->owner_opr();
  387. return {opr->output(0), opr->output(1)};
  388. }
  389. OP_TRAIT_REG(ROIAlign, ROIAlign)
  390. .apply_on_var_node(apply_on_var_node)
  391. .fallback();
  392. } // namespace roi_align
  393. } // namespace
  394. namespace {
  395. namespace correlation {
  396. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  397. auto&& op = static_cast<const Correlation&>(def);
  398. mgb_assert(inputs.size() == 2);
  399. OperatorNodeConfig config{op.make_name()};
  400. return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
  401. }
  402. OP_TRAIT_REG(Correlation, Correlation)
  403. .apply_on_var_node(apply_on_var_node)
  404. .fallback();
  405. } // namespace correlation
  406. } // namespace
  407. #if MGB_CUDA
  408. namespace {
  409. namespace nvof {
  410. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  411. auto&& op = static_cast<const NvOf&>(def);
  412. mgb_assert(inputs.size() == 1);
  413. OperatorNodeConfig config{op.make_name()};
  414. return opr::NvOf::make(inputs[0], op.param(), config);
  415. }
  416. OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
  417. } // namespace nvof
  418. } // namespace
  419. #endif
  420. namespace {
  421. namespace linspace {
  422. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  423. auto&& op = static_cast<const Linspace&>(def);
  424. mgb_assert(inputs.size() == 3);
  425. cg::OperatorNodeConfig config{op.comp_node};
  426. config.name(op.make_name());
  427. return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(),
  428. config);
  429. }
  430. OP_TRAIT_REG(Linspace, Linspace)
  431. .apply_on_var_node(apply_on_var_node)
  432. .fallback();
  433. } // namespace linspace
  434. } // namespace
  435. namespace {
  436. namespace eye {
  437. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  438. auto&& op = static_cast<const Eye&>(def);
  439. mgb_assert(inputs.size() == 1);
  440. cg::OperatorNodeConfig config{op.comp_node};
  441. config.name(op.make_name());
  442. opr::Eye::Param param{op.k, op.dtype.enumv()};
  443. return opr::Eye::make(inputs[0], param, config);
  444. }
  445. OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
  446. } // namespace eye
  447. } // namespace
  448. namespace {
  449. namespace roi_pooling {
  450. VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  451. auto&& op = static_cast<const ROIPooling&>(def);
  452. mgb_assert(inputs.size() == 3);
  453. OperatorNodeConfig config{op.make_name()};
  454. auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2],
  455. op.param(), config)
  456. .node()
  457. ->owner_opr();
  458. return {opr->output(0), opr->output(1)};
  459. }
  460. OP_TRAIT_REG(ROIPooling, ROIPooling)
  461. .apply_on_var_node(apply_on_var_node)
  462. .fallback();
  463. } // namespace roi_pooling
  464. } // namespace
  465. namespace {
  466. namespace remap {
  467. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  468. auto&& op = static_cast<const Remap&>(def);
  469. mgb_assert(inputs.size() == 2);
  470. OperatorNodeConfig config{op.make_name()};
  471. return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
  472. }
  473. OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
  474. } // namespace remap
  475. } // namespace
  476. namespace {
  477. auto get_index(
  478. const VarNodeArray& inputs, size_t vidx,
  479. const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
  480. size_t length = mask.size();
  481. opr::Subtensor::IndexDesc ret(length);
  482. for (size_t i = 0; i < length; ++i) {
  483. auto&& [axis, begin, end, step, idx] = mask[i];
  484. ret[i].axis = axis;
  485. if (idx) {
  486. ret[i].idx = inputs[vidx++];
  487. } else {
  488. mgb_assert(begin || end || step);
  489. if (begin)
  490. ret[i].begin = inputs[vidx++];
  491. if (end)
  492. ret[i].end = inputs[vidx++];
  493. if (step)
  494. ret[i].step = inputs[vidx++];
  495. }
  496. }
  497. mgb_assert(vidx == inputs.size());
  498. return ret;
  499. }
  500. #define IN1 inputs[0]
  501. #define IN2 inputs[0], inputs[1]
  502. #define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
  503. namespace NAME##_impl { \
  504. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \
  505. auto&& op = static_cast<const NAME&>(def); \
  506. OperatorNodeConfig config{op.make_name()}; \
  507. return opr::NAME::make(IN##NR_INPUT, \
  508. get_index(inputs, NR_INPUT, op.items), \
  509. config); \
  510. } \
  511. OP_TRAIT_REG(NAME, NAME) \
  512. .apply_on_var_node(apply_on_var_node) \
  513. .fallback(); \
  514. }
  515. FANCY_INDEXING_IMPL(Subtensor, 1)
  516. FANCY_INDEXING_IMPL(SetSubtensor, 2)
  517. FANCY_INDEXING_IMPL(IncrSubtensor, 2)
  518. FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
  519. FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
  520. FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
  521. FANCY_INDEXING_IMPL(MeshIndexing, 1)
  522. FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
  523. FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
  524. FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
  525. FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
  526. FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
  527. #undef FANCY_INDEXING_IMPL
  528. #undef IN1
  529. #undef IN2
  530. } // anonymous namespace
  531. namespace {
  532. namespace fake_quant {
  533. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  534. auto&& op = static_cast<const FakeQuant&>(def);
  535. mgb_assert(inputs.size() == 3);
  536. OperatorNodeConfig config{op.make_name()};
  537. return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(),
  538. config);
  539. }
  540. OP_TRAIT_REG(FakeQuant, FakeQuant)
  541. .apply_on_var_node(apply_on_var_node)
  542. .fallback();
  543. } // namespace fake_quant
  544. } // namespace
  545. namespace {
  546. namespace tqt {
  547. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  548. auto&& op = static_cast<const TQT&>(def);
  549. mgb_assert(inputs.size() == 2);
  550. OperatorNodeConfig config{op.make_name()};
  551. return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
  552. }
  553. OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
  554. } // namespace tqt
  555. } // namespace
  556. namespace {
  557. namespace elemwise_multi_type {
  558. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  559. auto&& op = static_cast<const ElemwiseMultiType&>(def);
  560. OperatorNodeConfig config{op.dtype};
  561. config.name(op.make_name());
  562. return opr::ElemwiseMultiType::make(inputs, op.param(), config);
  563. }
  564. OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
  565. .apply_on_var_node(apply_on_var_node)
  566. .fallback();
  567. } // namespace elemwise_multi_type
  568. } // namespace
  569. namespace {
  570. namespace svd {
  571. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  572. auto&& op = static_cast<const SVD&>(def);
  573. mgb_assert(inputs.size() == 1);
  574. OperatorNodeConfig config{op.make_name()};
  575. return opr::SVD::make(inputs[0], op.param(), config)[0]
  576. .node()
  577. ->owner_opr()
  578. ->usable_output();
  579. }
  580. OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
  581. } // namespace svd
  582. } // namespace
  583. namespace {
  584. namespace images2neibs {
  585. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  586. auto&& op = static_cast<const Images2Neibs&>(def);
  587. OperatorNodeConfig config{op.make_name()};
  588. return opr::Images2Neibs::make(inputs[0], op.param(), config);
  589. }
  590. OP_TRAIT_REG(Images2Neibs, Images2Neibs)
  591. .apply_on_var_node(apply_on_var_node)
  592. .fallback();
  593. } // namespace images2neibs
  594. } // namespace
  595. namespace {
  596. namespace lsq {
  597. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  598. auto&& op = static_cast<const LSQ&>(def);
  599. mgb_assert(inputs.size() == 4);
  600. OperatorNodeConfig config{op.make_name()};
  601. return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3],
  602. op.param(), config);
  603. }
  604. OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
  605. } // namespace lsq
  606. } // namespace
  607. namespace { namespace sliding_window_transpose {
  608. auto apply_on_var_node(
  609. const OpDef& def,
  610. const VarNodeArray& inputs) {
  611. auto&& op = static_cast<const SlidingWindowTranspose&>(def);
  612. OperatorNodeConfig config{op.make_name()};
  613. return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
  614. }
  615. OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
  616. .apply_on_var_node(apply_on_var_node)
  617. .fallback();
  618. }} // sliding_window_transpose
  619. } // namespace mgb::imperative

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台