You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

specializations.cpp 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. /**
  2. * \file imperative/src/impl/ops/specialzations.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. // FIXME: split this file into separate files for each specialized op
  13. #include "megbrain/imperative/ops/autogen.h"
  14. #include "megbrain/opr/basic_arith.h"
  15. #include "megbrain/opr/blas.h"
  16. #include "megbrain/opr/dnn/adaptive_pooling.h"
  17. #include "megbrain/opr/dnn/convolution.h"
  18. #include "megbrain/opr/dnn/correlation.h"
  19. #include "megbrain/opr/dnn/fake_quant.h"
  20. #include "megbrain/opr/dnn/images2neibs.h"
  21. #include "megbrain/opr/dnn/local.h"
  22. #include "megbrain/opr/dnn/lsq.h"
  23. #include "megbrain/opr/dnn/pooling.h"
  24. #include "megbrain/opr/dnn/roi_align.h"
  25. #include "megbrain/opr/dnn/roi_pooling.h"
  26. #include "megbrain/opr/dnn/tqt.h"
  27. #include "megbrain/opr/imgproc.h"
  28. #include "megbrain/opr/indexing.h"
  29. #include "megbrain/opr/io.h"
  30. #include "megbrain/opr/misc.h"
  31. #include "megbrain/opr/nn_int.h"
  32. #include "megbrain/opr/rand.h"
  33. #include "megbrain/opr/tensor_gen.h"
  34. #include "megbrain/opr/tensor_manip.h"
  35. #include "megbrain/opr/utility.h"
  36. #include "megbrain/opr/dnn/images2neibs.h"
  37. #include "megbrain/opr/dnn/sliding_window_transpose.h"
  38. #include "../op_trait.h"
  39. namespace mgb::imperative {
  40. namespace {
  41. namespace dimshuffle {
  42. std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
  43. auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
  44. std::vector<int> pattern(node->param().pattern_len);
  45. for (size_t i = 0; i < node->param().pattern_len; ++i) {
  46. pattern[i] = node->param().pattern[i];
  47. }
  48. return Dimshuffle::make(pattern);
  49. }
  50. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  51. auto&& ds = static_cast<const Dimshuffle&>(def);
  52. OperatorNodeConfig config{ds.make_name()};
  53. return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
  54. }
  55. OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
  56. .make_from_op_node(make_from_op_node)
  57. .apply_on_var_node(apply_on_var_node)
  58. .fallback();
  59. } // namespace dimshuffle
  60. } // namespace
  61. namespace {
  62. namespace add_axis {
  63. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  64. auto&& add_axis = static_cast<const AddAxis&>(def);
  65. using Desc = opr::AxisAddRemove::AxisDesc;
  66. std::vector<Desc> param;
  67. for (auto&& i : add_axis.axis) {
  68. param.push_back(Desc::make_add(i));
  69. }
  70. OperatorNodeConfig config{add_axis.make_name()};
  71. return opr::AxisAddRemove::make(inputs[0], param, config);
  72. }
  73. OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback();
  74. } // namespace add_axis
  75. } // namespace
  76. namespace {
  77. namespace remove_axis {
  78. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  79. auto&& remove_axis = static_cast<const RemoveAxis&>(def);
  80. using Desc = opr::AxisAddRemove::AxisDesc;
  81. std::vector<Desc> param;
  82. for (auto&& i : remove_axis.axis) {
  83. param.push_back(Desc::make_remove(i));
  84. }
  85. OperatorNodeConfig config{remove_axis.make_name()};
  86. return opr::AxisAddRemove::make(inputs[0], param, config);
  87. }
  88. OP_TRAIT_REG(RemoveAxis, RemoveAxis)
  89. .apply_on_var_node(apply_on_var_node)
  90. .fallback();
  91. } // namespace remove_axis
  92. } // namespace
  93. namespace {
  94. namespace top_k {
  95. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  96. auto&& topk = static_cast<const TopK&>(def);
  97. OperatorNodeConfig config{topk.make_name()};
  98. return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
  99. .node()
  100. ->owner_opr();
  101. }
  102. OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
  103. } // namespace top_k
  104. } // namespace
  105. namespace {
  106. namespace adaptive_pooling {
  107. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  108. auto&& pool = static_cast<const AdaptivePooling&>(def);
  109. OperatorNodeConfig config{pool.make_name()};
  110. return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(),
  111. config);
  112. }
  113. OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
  114. .apply_on_var_node(apply_on_var_node)
  115. .fallback();
  116. } // namespace adaptive_pooling
  117. } // namespace
  118. namespace {
  119. namespace conv_bias {
  120. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  121. auto&& conv = static_cast<const ConvBias&>(def);
  122. cg::OperatorNodeConfig config{conv.dtype};
  123. config.name(conv.make_name());
  124. if (inputs.size() == 2) {
  125. return opr::ConvBias::make(inputs[0], inputs[1], conv.param(),
  126. conv.policy(), config);
  127. } else if (inputs.size() == 3) {
  128. return opr::ConvBias::make(inputs[0], inputs[1], inputs[2],
  129. conv.param(), conv.policy(), config);
  130. } else if (inputs.size() == 4) {
  131. return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3],
  132. conv.param(), conv.policy(), config);
  133. }
  134. mgb_assert(0);
  135. }
  136. OP_TRAIT_REG(ConvBias, ConvBias)
  137. .apply_on_var_node(apply_on_var_node)
  138. .fallback();
  139. } // namespace conv_bias
  140. } // namespace
  141. namespace {
  142. namespace batch_conv_bias {
  143. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  144. auto&& conv = static_cast<const BatchConvBias&>(def);
  145. cg::OperatorNodeConfig config{conv.dtype};
  146. config.name(conv.make_name());
  147. if (inputs.size() == 2) {
  148. return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(),
  149. conv.policy(), config);
  150. } else if (inputs.size() == 3) {
  151. return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
  152. conv.param(), conv.policy(), config);
  153. } else if (inputs.size() == 4) {
  154. return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
  155. inputs[3], conv.param(), conv.policy(),
  156. config);
  157. }
  158. mgb_assert(0);
  159. }
  160. OP_TRAIT_REG(BatchConvBias, BatchConvBias)
  161. .apply_on_var_node(apply_on_var_node)
  162. .fallback();
  163. } // namespace batch_conv_bias
  164. } // namespace
  165. namespace {
  166. namespace pooling {
  167. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  168. auto&& pool = static_cast<const Pooling&>(def);
  169. OperatorNodeConfig config{pool.make_name()};
  170. return opr::Pooling::make(inputs[0], pool.param(), config);
  171. }
  172. OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
  173. } // namespace pooling
  174. } // namespace
  175. namespace {
  176. namespace matrix_mul {
  177. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  178. auto&& matmul = static_cast<const MatrixMul&>(def);
  179. mgb_assert(inputs.size() == 2);
  180. OperatorNodeConfig config{matmul.make_name()};
  181. return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
  182. matmul.policy(), config);
  183. }
  184. OP_TRAIT_REG(MatrixMul, MatrixMul)
  185. .apply_on_var_node(apply_on_var_node)
  186. .fallback();
  187. } // namespace matrix_mul
  188. } // namespace
  189. namespace {
  190. namespace batched_matrix_mul {
  191. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  192. auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
  193. mgb_assert(inputs.size() == 2);
  194. OperatorNodeConfig config{matmul.make_name()};
  195. return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
  196. matmul.policy(), config);
  197. }
  198. OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
  199. .apply_on_var_node(apply_on_var_node)
  200. .fallback();
  201. } // namespace batched_matrix_mul
  202. } // namespace
  203. namespace {
  204. namespace dot {
  205. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  206. auto&& op = def.cast_final_safe<Dot>();
  207. mgb_assert(inputs.size() == 2);
  208. OperatorNodeConfig config{op.make_name()};
  209. return opr::Dot::make(inputs[0], inputs[1], config);
  210. }
  211. OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();
  212. } // namespace dot
  213. } // namespace
  214. namespace {
  215. namespace argsort {
  216. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  217. auto&& argsort = static_cast<const Argsort&>(def);
  218. OperatorNodeConfig config{argsort.make_name()};
  219. return opr::Argsort::make(inputs[0], argsort.param(), config);
  220. }
  221. OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
  222. } // namespace argsort
  223. } // namespace
  224. namespace {
  225. namespace argmax {
  226. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  227. auto&& argmax = static_cast<const Argmax&>(def);
  228. OperatorNodeConfig config{argmax.make_name()};
  229. return opr::Argmax::make(inputs[0], argmax.param(), config);
  230. }
  231. OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
  232. } // namespace argmax
  233. } // namespace
  234. namespace {
  235. namespace argmin {
  236. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  237. auto&& argmin = static_cast<const Argmin&>(def);
  238. OperatorNodeConfig config{argmin.make_name()};
  239. return opr::Argmin::make(inputs[0], argmin.param(), config);
  240. }
  241. OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
  242. } // namespace argmin
  243. } // namespace
  244. namespace {
  245. namespace warp_perspective {
  246. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  247. auto&& warp = static_cast<const WarpPerspective&>(def);
  248. OperatorNodeConfig config{warp.make_name()};
  249. if (inputs.size() == 3) {
  250. return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
  251. warp.param(), config);
  252. } else {
  253. mgb_assert(inputs.size() == 4);
  254. return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
  255. inputs[3], warp.param(), config);
  256. }
  257. }
  258. OP_TRAIT_REG(WarpPerspective, WarpPerspective)
  259. .apply_on_var_node(apply_on_var_node)
  260. .fallback();
  261. } // namespace warp_perspective
  262. } // namespace
  263. namespace {
  264. namespace group_local {
  265. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  266. auto&& local = static_cast<const GroupLocal&>(def);
  267. mgb_assert(inputs.size() == 2);
  268. OperatorNodeConfig config{local.make_name()};
  269. return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
  270. }
  271. OP_TRAIT_REG(GroupLocal, GroupLocal)
  272. .apply_on_var_node(apply_on_var_node)
  273. .fallback();
  274. } // namespace group_local
  275. } // namespace
  276. namespace {
  277. namespace indexing_one_hot {
  278. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  279. auto&& op = static_cast<const IndexingOneHot&>(def);
  280. mgb_assert(inputs.size() == 2);
  281. OperatorNodeConfig config{op.make_name()};
  282. return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
  283. }
  284. OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
  285. .apply_on_var_node(apply_on_var_node)
  286. .fallback();
  287. } // namespace indexing_one_hot
  288. } // namespace
  289. namespace {
  290. namespace indexing_set_one_hot {
  291. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  292. auto&& op = static_cast<const IndexingSetOneHot&>(def);
  293. mgb_assert(inputs.size() == 3);
  294. OperatorNodeConfig config{op.make_name()};
  295. return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2],
  296. op.param(), config);
  297. }
  298. OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
  299. .apply_on_var_node(apply_on_var_node)
  300. .fallback();
  301. } // namespace indexing_set_one_hot
  302. } // namespace
  303. namespace {
  304. namespace typecvt {
  305. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  306. auto&& op = static_cast<const TypeCvt&>(def);
  307. mgb_assert(inputs.size() == 1);
  308. OperatorNodeConfig config{op.make_name()};
  309. return opr::TypeCvt::make(inputs[0], op.dtype, config);
  310. }
  311. OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
  312. } // namespace typecvt
  313. } // namespace
  314. namespace {
  315. namespace concat {
  316. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  317. auto&& op = static_cast<const Concat&>(def);
  318. cg::OperatorNodeConfig config{op.comp_node};
  319. config.name(op.make_name());
  320. return opr::Concat::make(inputs, op.axis, config);
  321. }
  322. OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
  323. } // namespace concat
  324. } // namespace
  325. namespace {
  326. namespace copy {
  327. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  328. auto&& op = static_cast<const Copy&>(def);
  329. mgb_assert(inputs.size() == 1);
  330. cg::OperatorNodeConfig config{op.comp_node};
  331. config.name(op.make_name());
  332. return opr::Copy::make(inputs[0], config);
  333. }
  334. OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
  335. } // namespace copy
  336. } // namespace
  337. namespace { namespace assert_equal {
  338. auto apply_on_var_node(
  339. const OpDef& def,
  340. const VarNodeArray& inputs) {
  341. auto&& op = def.cast_final<AssertEqual>();
  342. if (inputs.size() == 2) {
  343. return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
  344. } else {
  345. // workaround for MiniGraph, which only allow one opr in the graph
  346. mgb_assert(inputs.size() == 3);
  347. return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2],
  348. op.param(), {});
  349. }
  350. }
  351. OP_TRAIT_REG(AssertEqual, AssertEqual)
  352. .apply_on_var_node(apply_on_var_node)
  353. .fallback();
  354. } // namespace assert_equal
  355. } // namespace
  356. namespace {
  357. namespace roi_align {
  358. VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  359. auto&& op = static_cast<const ROIAlign&>(def);
  360. mgb_assert(inputs.size() == 2);
  361. OperatorNodeConfig config{op.make_name()};
  362. auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
  363. .node()
  364. ->owner_opr();
  365. return {opr->output(0), opr->output(1)};
  366. }
  367. OP_TRAIT_REG(ROIAlign, ROIAlign)
  368. .apply_on_var_node(apply_on_var_node)
  369. .fallback();
  370. } // namespace roi_align
  371. } // namespace
  372. namespace {
  373. namespace correlation {
  374. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  375. auto&& op = static_cast<const Correlation&>(def);
  376. mgb_assert(inputs.size() == 2);
  377. OperatorNodeConfig config{op.make_name()};
  378. return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
  379. }
  380. OP_TRAIT_REG(Correlation, Correlation)
  381. .apply_on_var_node(apply_on_var_node)
  382. .fallback();
  383. } // namespace correlation
  384. } // namespace
  385. #if MGB_CUDA
  386. namespace {
  387. namespace nvof {
  388. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  389. auto&& op = static_cast<const NvOf&>(def);
  390. mgb_assert(inputs.size() == 1);
  391. OperatorNodeConfig config{op.make_name()};
  392. return opr::NvOf::make(inputs[0], op.param(), config);
  393. }
  394. OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
  395. } // namespace nvof
  396. } // namespace
  397. #endif
  398. namespace {
  399. namespace linspace {
  400. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  401. auto&& op = static_cast<const Linspace&>(def);
  402. mgb_assert(inputs.size() == 3);
  403. cg::OperatorNodeConfig config{op.comp_node};
  404. config.name(op.make_name());
  405. return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(),
  406. config);
  407. }
  408. OP_TRAIT_REG(Linspace, Linspace)
  409. .apply_on_var_node(apply_on_var_node)
  410. .fallback();
  411. } // namespace linspace
  412. } // namespace
  413. namespace {
  414. namespace eye {
  415. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  416. auto&& op = static_cast<const Eye&>(def);
  417. mgb_assert(inputs.size() == 1);
  418. cg::OperatorNodeConfig config{op.comp_node};
  419. config.name(op.make_name());
  420. opr::Eye::Param param{op.k, op.dtype.enumv()};
  421. return opr::Eye::make(inputs[0], param, config);
  422. }
  423. OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
  424. } // namespace eye
  425. } // namespace
  426. namespace {
  427. namespace roi_pooling {
  428. VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  429. auto&& op = static_cast<const ROIPooling&>(def);
  430. mgb_assert(inputs.size() == 3);
  431. OperatorNodeConfig config{op.make_name()};
  432. auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2],
  433. op.param(), config)
  434. .node()
  435. ->owner_opr();
  436. return {opr->output(0), opr->output(1)};
  437. }
  438. OP_TRAIT_REG(ROIPooling, ROIPooling)
  439. .apply_on_var_node(apply_on_var_node)
  440. .fallback();
  441. } // namespace roi_pooling
  442. } // namespace
  443. namespace {
  444. namespace remap {
  445. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  446. auto&& op = static_cast<const Remap&>(def);
  447. mgb_assert(inputs.size() == 2);
  448. OperatorNodeConfig config{op.make_name()};
  449. return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
  450. }
  451. OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
  452. } // namespace remap
  453. } // namespace
  454. namespace {
  455. auto get_index(
  456. const VarNodeArray& inputs, size_t vidx,
  457. const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
  458. size_t length = mask.size();
  459. opr::Subtensor::IndexDesc ret(length);
  460. for (size_t i = 0; i < length; ++i) {
  461. auto&& [axis, begin, end, step, idx] = mask[i];
  462. ret[i].axis = axis;
  463. if (idx) {
  464. ret[i].idx = inputs[vidx++];
  465. } else {
  466. mgb_assert(begin || end || step);
  467. if (begin)
  468. ret[i].begin = inputs[vidx++];
  469. if (end)
  470. ret[i].end = inputs[vidx++];
  471. if (step)
  472. ret[i].step = inputs[vidx++];
  473. }
  474. }
  475. mgb_assert(vidx == inputs.size());
  476. return ret;
  477. }
  478. #define IN1 inputs[0]
  479. #define IN2 inputs[0], inputs[1]
  480. #define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
  481. namespace NAME##_impl { \
  482. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \
  483. auto&& op = static_cast<const NAME&>(def); \
  484. OperatorNodeConfig config{op.make_name()}; \
  485. return opr::NAME::make(IN##NR_INPUT, \
  486. get_index(inputs, NR_INPUT, op.items), \
  487. config); \
  488. } \
  489. OP_TRAIT_REG(NAME, NAME) \
  490. .apply_on_var_node(apply_on_var_node) \
  491. .fallback(); \
  492. }
  493. FANCY_INDEXING_IMPL(Subtensor, 1)
  494. FANCY_INDEXING_IMPL(SetSubtensor, 2)
  495. FANCY_INDEXING_IMPL(IncrSubtensor, 2)
  496. FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
  497. FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
  498. FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
  499. FANCY_INDEXING_IMPL(MeshIndexing, 1)
  500. FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
  501. FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
  502. FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
  503. FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
  504. FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
  505. #undef FANCY_INDEXING_IMPL
  506. #undef IN1
  507. #undef IN2
  508. } // anonymous namespace
  509. namespace {
  510. namespace fake_quant {
  511. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  512. auto&& op = static_cast<const FakeQuant&>(def);
  513. mgb_assert(inputs.size() == 3);
  514. OperatorNodeConfig config{op.make_name()};
  515. return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(),
  516. config);
  517. }
  518. OP_TRAIT_REG(FakeQuant, FakeQuant)
  519. .apply_on_var_node(apply_on_var_node)
  520. .fallback();
  521. } // namespace fake_quant
  522. } // namespace
  523. namespace {
  524. namespace tqt {
  525. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  526. auto&& op = static_cast<const TQT&>(def);
  527. mgb_assert(inputs.size() == 2);
  528. OperatorNodeConfig config{op.make_name()};
  529. return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
  530. }
  531. OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
  532. } // namespace tqt
  533. } // namespace
  534. namespace {
  535. namespace elemwise_multi_type {
  536. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  537. auto&& op = static_cast<const ElemwiseMultiType&>(def);
  538. OperatorNodeConfig config{op.dtype};
  539. config.name(op.make_name());
  540. return opr::ElemwiseMultiType::make(inputs, op.param(), config);
  541. }
  542. OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
  543. .apply_on_var_node(apply_on_var_node)
  544. .fallback();
  545. } // namespace elemwise_multi_type
  546. } // namespace
  547. namespace {
  548. namespace svd {
  549. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  550. auto&& op = static_cast<const SVD&>(def);
  551. mgb_assert(inputs.size() == 1);
  552. OperatorNodeConfig config{op.make_name()};
  553. return opr::SVD::make(inputs[0], op.param(), config)[0]
  554. .node()
  555. ->owner_opr()
  556. ->usable_output();
  557. }
  558. OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
  559. } // namespace svd
  560. } // namespace
  561. namespace {
  562. namespace images2neibs {
  563. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  564. auto&& op = static_cast<const Images2Neibs&>(def);
  565. OperatorNodeConfig config{op.make_name()};
  566. return opr::Images2Neibs::make(inputs[0], op.param(), config);
  567. }
  568. OP_TRAIT_REG(Images2Neibs, Images2Neibs)
  569. .apply_on_var_node(apply_on_var_node)
  570. .fallback();
  571. } // namespace images2neibs
  572. } // namespace
  573. namespace {
  574. namespace lsq {
  575. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  576. auto&& op = static_cast<const LSQ&>(def);
  577. mgb_assert(inputs.size() == 4);
  578. OperatorNodeConfig config{op.make_name()};
  579. return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3],
  580. op.param(), config);
  581. }
  582. OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
  583. } // namespace lsq
  584. } // namespace
  585. namespace { namespace sliding_window_transpose {
  586. auto apply_on_var_node(
  587. const OpDef& def,
  588. const VarNodeArray& inputs) {
  589. auto&& op = static_cast<const SlidingWindowTranspose&>(def);
  590. OperatorNodeConfig config{op.make_name()};
  591. return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
  592. }
  593. OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
  594. .apply_on_var_node(apply_on_var_node)
  595. .fallback();
  596. }} // sliding_window_transpose
  597. namespace {
  598. namespace cumsum {
  599. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  600. auto&& op = static_cast<const Cumsum&>(def);
  601. OperatorNodeConfig config{op.make_name()};
  602. return opr::Cumsum::make(inputs[0], op.param(), config);
  603. }
  604. OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
  605. } // namespace cumsum
  606. } // namespace
  607. } // namespace mgb::imperative

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台