
dnn.sereg.h 25 kB

/**
 * \file src/opr/impl/dnn/dnn.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"
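// Note: this header registers the DNN operators with MegBrain's serialization
// framework. The Make*Caller helpers below rebuild an operator from a saved
// graph by dispatching on the number of input VarNodes, the *LoadDumpImpl
// structs read and write the operator parameters, and the MGB_SEREG_OPR
// macros at the end of the file perform the actual registration.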
namespace mgb {
namespace serialization {

template <class MegDNNPooling = megdnn::Pooling>
struct MakePoolingCaller1 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 1) {
            return Opr::make(inputs[0], param, config).node();
        }
        return nullptr;
    }
};

template <class MegDNNROIALIGN = megdnn::ROIAlign>
struct MakeROIAlignCaller1 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNROIALIGN::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, config).node();
        } else {
            return nullptr;
        }
    }
};

template <class MegDNNROIALIGN = megdnn::ROIAlignBackward>
struct MakeROIAlignCaller4 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNROIALIGN::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
                             config)
                    .node();
        } else {
            return nullptr;
        }
    }
};

template <class MegDNNPooling = megdnn::PoolingBackward>
struct MakePoolingBackwardCaller3 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNPooling = megdnn::AdaptivePoolingBackward>
struct MakeAdaptivePoolingBackwardCaller3 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
                             config)
                    .node();
        }
        return nullptr;
    }
};
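// The MakeConvCaller{2,3,4,5} helpers cover convolution-style operators whose
// constructors additionally take an ExecutionPolicy. Each helper only fires
// when the operator being loaded has the matching number of inputs; otherwise
// it returns nullptr so the next candidate in the chain can be tried.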
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller2 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, execution_policy,
                             config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller3 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(inputs[0], inputs[1], inputs[2], param,
                             execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller4 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
                             execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller5 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 5) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
                             inputs[4], param, execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCallerEmpty {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray&,
                         const typename MegDNNConv::Param&,
                         const megdnn::param::ExecutionPolicy&,
                         const OperatorNodeConfig&) {
        return nullptr;
    }
};
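// ConvLoadDumpImpl serializes a convolution-like operator as its param plus
// its ExecutionPolicy. On load it tries Maker0, Maker1 and Maker2 in turn
// until one accepts the input count; MakeConvCallerEmpty acts as the
// "no such variant" placeholder for the unused maker slots.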
template <class Opr, class Maker0, class MegDNNConv,
          class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
          class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
          typename ConvParam = megdnn::param::Convolution>
struct ConvLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<ConvParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(
                opr.execution_policy_transient());
    }
    static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        VarNode* ret = Maker0::template make<Opr>(inputs, param,
                                                  execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<Opr>(inputs, param, execution_policy,
                                             config);
        }
        if (!ret) {
            ret = Maker2::template make<Opr>(inputs, param, execution_policy,
                                             config);
        }
        mgb_assert(ret);
        return ret;
    }
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<ConvParam>();
        auto execution_policy =
                ctx.read_param<megdnn::param::ExecutionPolicy>();
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

template <class Opr, class Maker0,
          typename PoolingParam = megdnn::param::Pooling>
struct PoolingLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<PoolingParam>(opr.param());
    }
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const PoolingParam& param,
                         const OperatorNodeConfig& config) {
        VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
        mgb_assert(ret);
        return ret;
    }
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<PoolingParam>();
        return make(inputs, param, config)->owner_opr();
    }
};
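// The OprMaker specializations below rebuild the TQT and LSQ gradient
// operators directly from their three and five inputs; make() returns an
// array of outputs, so the owning operator is taken from the first output
// node.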
template <>
struct OprMaker<opr::TQTBackward, 3> {
    using Param = opr::TQTBackward::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0]
                .node()
                ->owner_opr();
    }
};

template <>
struct OprMaker<opr::LSQBackward, 5> {
    using Param = opr::LSQBackward::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param,
                                      config)[0]
                .node()
                ->owner_opr();
    }
};
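// Each OprLoadDumpImpl specialization below pairs a graph operator with the
// caller whose input count matches how that operator is constructed, and with
// the megdnn param type that is written to and read from the model file.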
template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
        : public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward,
                                     MakeAdaptivePoolingBackwardCaller3<
                                             megdnn::AdaptivePoolingBackward>,
                                     megdnn::param::AdaptivePooling> {};

template <>
struct OprLoadDumpImpl<opr::AdaptivePooling, 0>
        : public PoolingLoadDumpImpl<
                  opr::AdaptivePooling,
                  MakeROIAlignCaller1<megdnn::AdaptivePooling>,
                  megdnn::param::AdaptivePooling> {};

template <>
struct OprLoadDumpImpl<opr::ROIAlign, 0>
        : public PoolingLoadDumpImpl<opr::ROIAlign,
                                     MakeROIAlignCaller1<megdnn::ROIAlign>,
                                     megdnn::param::ROIAlign> {};

template <>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>
        : public PoolingLoadDumpImpl<
                  opr::ROIAlignBackward,
                  MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
                  megdnn::param::ROIAlign> {};

template <>
struct OprLoadDumpImpl<opr::Pooling, 0>
        : public PoolingLoadDumpImpl<opr::Pooling,
                                     MakePoolingCaller1<megdnn::Pooling>,
                                     megdnn::param::Pooling> {};

template <>
struct OprLoadDumpImpl<opr::PoolingBackward, 0>
        : public PoolingLoadDumpImpl<
                  opr::PoolingBackward,
                  MakePoolingBackwardCaller3<megdnn::PoolingBackward>,
                  megdnn::param::Pooling> {};

template <>
struct OprLoadDumpImpl<opr::Convolution, 0>
        : public ConvLoadDumpImpl<opr::Convolution,
                                  MakeConvCaller2<megdnn::Convolution>,
                                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardData, 0>
        : public ConvLoadDumpImpl<opr::ConvolutionBackwardData,
                                  MakeConvCaller2<megdnn::Convolution>,
                                  megdnn::Convolution,
                                  MakeConvCaller3<megdnn::Convolution> > {};

template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardFilter, 0>
        : public ConvLoadDumpImpl<opr::ConvolutionBackwardFilter,
                                  MakeConvCaller3<megdnn::Convolution>,
                                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::Convolution3D, 0>
        : public ConvLoadDumpImpl<opr::Convolution3D,
                                  MakeConvCaller2<megdnn::Convolution3D>,
                                  megdnn::Convolution3D,
                                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                                  megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardData, 0>
        : public ConvLoadDumpImpl<opr::Convolution3DBackwardData,
                                  MakeConvCaller2<megdnn::Convolution3D>,
                                  megdnn::Convolution3D,
                                  MakeConvCaller3<megdnn::Convolution3D>,
                                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                                  megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
        : public ConvLoadDumpImpl<opr::Convolution3DBackwardFilter,
                                  MakeConvCaller3<megdnn::Convolution3D>,
                                  megdnn::Convolution3D,
                                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                                  megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
        : public ConvLoadDumpImpl<opr::ConvBiasForward,
                                  MakeConvCaller2<megdnn::ConvBiasForward>,
                                  megdnn::ConvBiasForward,
                                  MakeConvCaller3<megdnn::ConvBiasForward>,
                                  MakeConvCaller4<megdnn::ConvBiasForward>,
                                  megdnn::param::ConvBias> {};

template <>
struct OprLoadDumpImpl<opr::BatchConvBiasForward, 0>
        : public ConvLoadDumpImpl<opr::BatchConvBiasForward,
                                  MakeConvCaller2<megdnn::BatchConvBiasForward>,
                                  megdnn::BatchConvBiasForward,
                                  MakeConvCaller3<megdnn::BatchConvBiasForward>,
                                  MakeConvCaller4<megdnn::BatchConvBiasForward>,
                                  megdnn::param::BatchConvBias> {};
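// BatchNorm is serialized with a variable number of inputs: three when the
// running mean and variance are absent and five when they are present, so its
// OprMaker branches on the input count instead of fixing an arity.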
template <>
struct OprMaker<opr::BatchNorm, 0> {
    using Param = opr::BatchNorm::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (i.size() == 3) {
            return opr::BatchNorm::make(i[0], i[1], i[2], param, config)[0]
                    .node()
                    ->owner_opr();
        } else {
            mgb_assert(i.size() == 5);
            return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4], param,
                                        config)[0]
                    .node()
                    ->owner_opr();
        }
    }
};

template <>
struct OprMaker<opr::BatchNormBackward, 6> {
    using Param = opr::BatchNormBackward::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5],
                                            param, config)[0]
                .node()
                ->owner_opr();
    }
};
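// LocalShare follows the same pattern as the convolution operators: dedicated
// caller helpers plus a LoadDumpImpl that also stores the ExecutionPolicy.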
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, execution_policy,
                             config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller3 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(inputs[0], inputs[1], inputs[2], param,
                             execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCallerEmpty {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray&,
                         const typename MegDNNConv::Param&,
                         const megdnn::param::ExecutionPolicy&,
                         const OperatorNodeConfig&) {
        return nullptr;
    }
};

template <class Opr, class Maker0, class MegDNNConv,
          class Maker1 = MakeLocalShareCallerEmpty<MegDNNConv>,
          class Maker2 = MakeLocalShareCallerEmpty<MegDNNConv>,
          typename LocalShareParam = megdnn::param::LocalShare>
struct LocalShareLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<LocalShareParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
    }
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const LocalShareParam& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        VarNode* ret = Maker0::template make<Opr>(inputs, param,
                                                  execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<Opr>(inputs, param, execution_policy,
                                             config);
        }
        if (!ret) {
            ret = Maker2::template make<Opr>(inputs, param, execution_policy,
                                             config);
        }
        mgb_assert(ret);
        return ret;
    }
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<LocalShareParam>();
        auto execution_policy =
                ctx.read_param<megdnn::param::ExecutionPolicy>();
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};
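// The remaining specializations wire up LocalShare and DeformableConv; the
// deformable convolution operators reuse the generic conv callers, since
// their make() overloads take the same param / execution-policy / config
// trailing arguments.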
template <>
struct OprLoadDumpImpl<opr::LocalShare, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
                  megdnn::LocalShare> {};

template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardData, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardData,
                  MakeLocalShareCaller3<megdnn::LocalShare>,
                  megdnn::LocalShare> {};

template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardFilter, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardFilter,
                  MakeLocalShareCaller3<megdnn::LocalShare>,
                  megdnn::LocalShare> {};

template <>
struct OprLoadDumpImpl<opr::DeformableConvForward, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvForward,
                  MakeConvCaller4<megdnn::DeformableConvForward>,
                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardData,
                  MakeConvCaller5<megdnn::DeformableConvBackwardData>,
                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardFilter,
                  MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
                  megdnn::Convolution> {};
} // namespace serialization
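// The MGB_SEREG_OPR entries below register each operator under the name used
// in serialized models; the V1/V2/V4 type aliases keep a version suffix in the
// registered name for compatibility with older model files. The second macro
// argument appears to be the expected input count (0 for operators whose
// input count is variable or handled by the specializations above).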
namespace opr {

using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
MGB_SEREG_OPR(ConvolutionV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0);

MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);

MGB_SEREG_OPR(SlidingWindowTranspose, 1);
MGB_SEREG_OPR(SlidingWindowTransposeBackward, 2);

using LocalV2 = Local;
using LocalBackwardDataV2 = LocalBackwardData;
using LocalBackwardFilterV2 = LocalBackwardFilter;
MGB_SEREG_OPR(LocalV2, 2);
MGB_SEREG_OPR(LocalBackwardDataV2, 3);
MGB_SEREG_OPR(LocalBackwardFilterV2, 3);

using GroupLocalV2 = GroupLocal;
using GroupLocalBackwardDataV2 = GroupLocalBackwardData;
using GroupLocalBackwardFilterV2 = GroupLocalBackwardFilter;
MGB_SEREG_OPR(GroupLocalV2, 2);
MGB_SEREG_OPR(GroupLocalBackwardDataV2, 3);
MGB_SEREG_OPR(GroupLocalBackwardFilterV2, 3);

MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);

using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
MGB_SEREG_OPR(PoolingV1, 1);
MGB_SEREG_OPR(PoolingBackwardV1, 3);

using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);

MGB_SEREG_OPR(ROIPooling, 3);
MGB_SEREG_OPR(ROIPoolingBackward, 4);

using MaskConvolutionV2 = MaskConvolution;
MGB_SEREG_OPR(MaskConvolutionV2, 3);
MGB_SEREG_OPR(MaskPropagate, 1);

MGB_SEREG_OPR(Convolution3D, 0);
MGB_SEREG_OPR(Convolution3DBackwardData, 0);
MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);

using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);

MGB_SEREG_OPR(BatchNorm, 0);
MGB_SEREG_OPR(BatchNormBackward, 6);

using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
MGB_SEREG_OPR(LocalShareForwardV1, 0);
MGB_SEREG_OPR(LocalShareBackwardDataV1, 0);
MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0);

using ROIAlignV1 = ROIAlign;
using ROIAlignBackwardV1 = ROIAlignBackward;
MGB_SEREG_OPR(ROIAlignV1, 2);
MGB_SEREG_OPR(ROIAlignBackwardV1, 4);

using DeformableConvForwardV1 = DeformableConvForward;
using DeformableConvBackwardDataV1 = DeformableConvBackwardData;
using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter;
MGB_SEREG_OPR(DeformableConvForwardV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0);

MGB_SEREG_OPR(CorrelationForward, 2);
MGB_SEREG_OPR(CorrelationBackwardData1, 3);
MGB_SEREG_OPR(CorrelationBackwardData2, 3);

MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);

using BatchConvBiasForwardV1 = BatchConvBiasForward;
MGB_SEREG_OPR(BatchConvBiasForwardV1, 0);

MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
} // namespace opr
} // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
