specializations.cpp 23 kB

/**
 * \file imperative/src/impl/ops/specializations.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../op_trait.h"

namespace mgb::imperative {
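
// Each anonymous namespace below registers one op via OP_TRAIT_REG:
// make_from_op_node() rebuilds an imperative OpDef from an existing
// computing-graph operator node, and apply_on_var_node() expands an
// OpDef into graph operators applied to the given input var nodes.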
namespace { namespace convolution {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution>();
    return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution&>(def);
    OperatorNodeConfig config{conv.make_name()};
    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // convolution

namespace { namespace convolution_backward_data {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
    if (inputs.size() == 2) {
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else {
        mgb_assert(inputs.size() == 3);
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // convolution_backward_data

namespace { namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    std::vector<int> pattern(node->param().pattern_len);
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    OperatorNodeConfig config{ds.make_name()};
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // dimshuffle

namespace { namespace add_axis {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& add_axis = static_cast<const AddAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : add_axis.axis) {
        param.push_back(Desc::make_add(i));
    }
    OperatorNodeConfig config{add_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

OP_TRAIT_REG(AddAxis, AddAxis)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // add_axis

namespace { namespace remove_axis {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
    OperatorNodeConfig config{remove_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // remove_axis

namespace { namespace top_k {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& topk = static_cast<const TopK&>(def);
    OperatorNodeConfig config{topk.make_name()};
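    // TopK has multiple outputs; return the owner operator so all of
    // them are exposed, not just the first symbol var.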
    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
            .node()->owner_opr();
}

OP_TRAIT_REG(TopK, TopK)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // top_k

namespace { namespace reduce {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
    OperatorNodeConfig config{reduce.make_name()};
    if (inputs.size() > 1) {
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
    } else {
        return opr::Reduce::make(
                inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
    }
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    return Reduce::make(node->param());
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // reduce

namespace { namespace adaptive_pooling {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& pool = static_cast<const AdaptivePooling&>(def);
    OperatorNodeConfig config{pool.make_name()};
    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
}

OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // adaptive_pooling

namespace { namespace conv_bias {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvBias&>(def);
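    // ConvBias carries an explicit output dtype on the OpDef; passing it
    // to the config requests that dtype for the operator's output.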
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    if (inputs.size() == 2) {
        return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
    }
    mgb_assert(0);
}

OP_TRAIT_REG(ConvBias, ConvBias)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // conv_bias

namespace { namespace batch_conv_bias {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const BatchConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    if (inputs.size() == 2) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
    }
    mgb_assert(0);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // batch_conv_bias
namespace { namespace pooling {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& pool = static_cast<const Pooling&>(def);
    OperatorNodeConfig config{pool.make_name()};
    return opr::Pooling::make(inputs[0], pool.param(), config);
}

OP_TRAIT_REG(Pooling, Pooling)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // pooling

namespace { namespace matrix_mul {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const MatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
            matmul.policy(), config);
}

OP_TRAIT_REG(MatrixMul, MatrixMul)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // matrix_mul

namespace { namespace batched_matrix_mul {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
            matmul.policy(), config);
}

OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // batched_matrix_mul
namespace { namespace dot {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Dot>();
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Dot::make(inputs[0], inputs[1], config);
}

OP_TRAIT_REG(Dot, Dot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // dot

namespace { namespace argsort {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& argsort = static_cast<const Argsort&>(def);
    OperatorNodeConfig config{argsort.make_name()};
    return opr::Argsort::make(inputs[0], argsort.param(), config);
}

OP_TRAIT_REG(Argsort, Argsort)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // argsort

namespace { namespace argmax {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& argmax = static_cast<const Argmax&>(def);
    OperatorNodeConfig config{argmax.make_name()};
    return opr::Argmax::make(inputs[0], argmax.param(), config);
}

OP_TRAIT_REG(Argmax, Argmax)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // argmax

namespace { namespace argmin {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& argmin = static_cast<const Argmin&>(def);
    OperatorNodeConfig config{argmin.make_name()};
    return opr::Argmin::make(inputs[0], argmin.param(), config);
}

OP_TRAIT_REG(Argmin, Argmin)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // argmin
namespace { namespace warp_perspective {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& warp = static_cast<const WarpPerspective&>(def);
    OperatorNodeConfig config{warp.make_name()};
    if (inputs.size() == 3) {
        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config);
    } else {
        mgb_assert(inputs.size() == 4);
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
    }
}

OP_TRAIT_REG(WarpPerspective, WarpPerspective)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // warp_perspective

namespace { namespace group_local {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& local = static_cast<const GroupLocal&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{local.make_name()};
    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
}

OP_TRAIT_REG(GroupLocal, GroupLocal)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // group_local

namespace { namespace indexing_one_hot {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingOneHot&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // indexing_one_hot

namespace { namespace indexing_set_one_hot {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingSetOneHot&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // indexing_set_one_hot
namespace { namespace typecvt {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const TypeCvt&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::TypeCvt::make(inputs[0], op.dtype, config);
}

OP_TRAIT_REG(TypeCvt, TypeCvt)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // typecvt

namespace { namespace concat {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Concat::make(inputs, op.axis, config);
}

OP_TRAIT_REG(Concat, Concat)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // concat

namespace { namespace copy {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Copy::make(inputs[0], config);
}

OP_TRAIT_REG(Copy, Copy)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // copy

namespace { namespace identity {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Identity>();
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::Identity::make(inputs[0], config);
}

OP_TRAIT_REG(Identity, Identity)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // identity
namespace { namespace assert_equal {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = def.cast_final<AssertEqual>();
    if (inputs.size() == 2) {
        return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
    } else {
        // workaround for MiniGraph, which only allows one opr in the graph
        mgb_assert(inputs.size() == 3);
        return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
    }
}

OP_TRAIT_REG(AssertEqual, AssertEqual)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // assert_equal
namespace { namespace uniform_rng {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const UniformRNG&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::UniformRNG::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(UniformRNG, UniformRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // uniform_rng

namespace { namespace gaussian_rng {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const GaussianRNG&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::GaussianRNG::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(GaussianRNG, GaussianRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // gaussian_rng

namespace { namespace roi_align {
VarNodeArray apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIAlign&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
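    // ROIAlign produces two outputs (the pooled result and an index output);
    // return both var nodes explicitly.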
    auto* opr = opr::ROIAlign::make(
            inputs[0], inputs[1], op.param(), config).node()->owner_opr();
    return {opr->output(0), opr->output(1)};
}

OP_TRAIT_REG(ROIAlign, ROIAlign)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // roi_align

#if MGB_CUDA
namespace { namespace nvof {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const NvOf&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::NvOf::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(NvOf, NvOf)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // nvof
#endif
namespace { namespace linspace {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}

OP_TRAIT_REG(Linspace, Linspace)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // linspace

namespace { namespace eye {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Eye&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    opr::Eye::Param param{op.k, op.dtype.enumv()};
    return opr::Eye::make(inputs[0], param, config);
}

OP_TRAIT_REG(Eye, Eye)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // eye

namespace { namespace roi_pooling {
VarNodeArray apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIPooling&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    auto* opr = opr::ROIPooling::make(
            inputs[0], inputs[1], inputs[2], op.param(), config
    ).node()->owner_opr();
    return {opr->output(0), opr->output(1)};
}

OP_TRAIT_REG(ROIPooling, ROIPooling)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // roi_pooling
namespace { namespace remap {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Remap&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(Remap, Remap)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // remap

namespace {
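// get_index() rebuilds an opr::Subtensor::IndexDesc from the flattened
// index description stored on an OpDef: each mask entry is a tuple
// (axis, has_begin, has_end, has_step, is_index), and the corresponding
// index/slice var nodes are consumed from `inputs` starting at `vidx`.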
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
            if (begin) ret[i].begin = inputs[vidx++];
            if (end) ret[i].end = inputs[vidx++];
            if (step) ret[i].step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
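
// FANCY_INDEXING_IMPL(NAME, NR_INPUT) stamps out apply_on_var_node() and the
// OP_TRAIT_REG registration for one fancy-indexing op: the first NR_INPUT
// vars are tensor inputs (the source, plus the value for Set/Incr variants),
// and the remaining vars are the index items decoded by get_index().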
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node( \
        const OpDef& def, \
        const VarNodeArray& inputs) { \
    auto&& op = static_cast<const NAME&>(def); \
    OperatorNodeConfig config{op.make_name()}; \
    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
} \
OP_TRAIT_REG(NAME, NAME) \
    .apply_on_var_node(apply_on_var_node) \
    .fallback(); \
}

FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)

#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
} // anonymous namespace
namespace { namespace fake_quant {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const FakeQuant&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}

OP_TRAIT_REG(FakeQuant, FakeQuant)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // fake_quant

namespace { namespace tqt {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const TQT&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(TQT, TQT)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // tqt

namespace { namespace elemwise_multi_type {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
    config.name(op.make_name());
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}

OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // elemwise_multi_type
namespace { namespace svd {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const SVD&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
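    // SVD likewise has multiple outputs; usable_output() collects the
    // owner operator's usable (non-workspace) outputs.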
    return opr::SVD::make(inputs[0], op.param(), config)[0]
            .node()->owner_opr()->usable_output();
}

OP_TRAIT_REG(SVD, SVD)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // svd

} // namespace mgb::imperative
