You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

dnn.sereg.h 25 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. /**
  2. * \file src/opr/impl/dnn/dnn.sereg.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/opr/dnn/adaptive_pooling.h"
  13. #include "megbrain/opr/dnn/batch_norm.h"
  14. #include "megbrain/opr/dnn/convolution.h"
  15. #include "megbrain/opr/dnn/correlation.h"
  16. #include "megbrain/opr/dnn/fake_quant.h"
  17. #include "megbrain/opr/dnn/images2neibs.h"
  18. #include "megbrain/opr/dnn/local.h"
  19. #include "megbrain/opr/dnn/lrn.h"
  20. #include "megbrain/opr/dnn/lsq.h"
  21. #include "megbrain/opr/dnn/pooling.h"
  22. #include "megbrain/opr/dnn/roi_align.h"
  23. #include "megbrain/opr/dnn/roi_pooling.h"
  24. #include "megbrain/opr/dnn/sliding_window_transpose.h"
  25. #include "megbrain/opr/dnn/tqt.h"
  26. #include "megbrain/serialization/sereg.h"
  27. #include "megdnn/opr_param_defs.h"
  28. #include "megdnn/oprs/nn.h"
  29. namespace mgb {
  30. namespace serialization {
  31. template <class MegDNNPooling = megdnn::Pooling>
  32. struct MakePoolingCaller1 {
  33. template <typename Opr>
  34. static VarNode* make(
  35. const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
  36. const megdnn::param::ExecutionPolicy& execution_policy,
  37. const OperatorNodeConfig& config) {
  38. if (inputs.size() == 1) {
  39. return Opr::make(inputs[0], param, execution_policy, config).node();
  40. }
  41. return nullptr;
  42. }
  43. };
  44. template <class MegDNNROIALIGN = megdnn::ROIAlign>
  45. struct MakeROIAlignCaller1 {
  46. template <typename Opr>
  47. static VarNode* make(
  48. const cg::VarNodeArray& inputs, const typename MegDNNROIALIGN::Param& param,
  49. const OperatorNodeConfig& config) {
  50. if (inputs.size() == 2) {
  51. return Opr::make(inputs[0], inputs[1], param, config).node();
  52. } else {
  53. return nullptr;
  54. }
  55. }
  56. };
  57. template <class MegDNNROIALIGN = megdnn::ROIAlignBackward>
  58. struct MakeROIAlignCaller4 {
  59. template <typename Opr>
  60. static VarNode* make(
  61. const cg::VarNodeArray& inputs, const typename MegDNNROIALIGN::Param& param,
  62. const OperatorNodeConfig& config) {
  63. if (inputs.size() == 4) {
  64. return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
  65. .node();
  66. } else {
  67. return nullptr;
  68. }
  69. }
  70. };
  71. template <class MegDNNPooling = megdnn::PoolingBackward>
  72. struct MakePoolingBackwardCaller3 {
  73. template <typename Opr>
  74. static VarNode* make(
  75. const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
  76. const megdnn::param::ExecutionPolicy& execution_policy,
  77. const OperatorNodeConfig& config) {
  78. if (inputs.size() == 3) {
  79. return Opr::make(
  80. inputs[0], inputs[1], inputs[2], param, execution_policy,
  81. config)
  82. .node();
  83. }
  84. return nullptr;
  85. }
  86. };
  87. template <class MegDNNPooling = megdnn::AdaptivePoolingBackward>
  88. struct MakeAdaptivePoolingBackwardCaller3 {
  89. template <typename Opr>
  90. static VarNode* make(
  91. const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
  92. const OperatorNodeConfig& config) {
  93. if (inputs.size() == 4) {
  94. return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
  95. .node();
  96. }
  97. return nullptr;
  98. }
  99. };
  100. template <class MegDNNConv = megdnn::Convolution>
  101. struct MakeConvCaller2 {
  102. template <typename Opr>
  103. static VarNode* make(
  104. const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
  105. const megdnn::param::ExecutionPolicy& execution_policy,
  106. const OperatorNodeConfig& config) {
  107. if (inputs.size() == 2) {
  108. return Opr::make(inputs[0], inputs[1], param, execution_policy, config)
  109. .node();
  110. }
  111. return nullptr;
  112. }
  113. };
  114. template <class MegDNNConv = megdnn::Convolution>
  115. struct MakeConvCaller3 {
  116. template <typename Opr>
  117. static VarNode* make(
  118. const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
  119. const megdnn::param::ExecutionPolicy& execution_policy,
  120. const OperatorNodeConfig& config) {
  121. if (inputs.size() == 3) {
  122. return Opr::make(
  123. inputs[0], inputs[1], inputs[2], param, execution_policy,
  124. config)
  125. .node();
  126. }
  127. return nullptr;
  128. }
  129. };
  130. template <class MegDNNConv = megdnn::Convolution>
  131. struct MakeConvCaller4 {
  132. template <typename Opr>
  133. static VarNode* make(
  134. const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
  135. const megdnn::param::ExecutionPolicy& execution_policy,
  136. const OperatorNodeConfig& config) {
  137. if (inputs.size() == 4) {
  138. return Opr::make(
  139. inputs[0], inputs[1], inputs[2], inputs[3], param,
  140. execution_policy, config)
  141. .node();
  142. }
  143. return nullptr;
  144. }
  145. };
  146. template <class MegDNNConv = megdnn::Convolution>
  147. struct MakeConvCaller5 {
  148. template <typename Opr>
  149. static VarNode* make(
  150. const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
  151. const megdnn::param::ExecutionPolicy& execution_policy,
  152. const OperatorNodeConfig& config) {
  153. if (inputs.size() == 5) {
  154. return Opr::make(
  155. inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param,
  156. execution_policy, config)
  157. .node();
  158. }
  159. return nullptr;
  160. }
  161. };
  162. template <class MegDNNConv = megdnn::Convolution>
  163. struct MakeConvCallerEmpty {
  164. template <typename Opr>
  165. static VarNode* make(
  166. const cg::VarNodeArray&, const typename MegDNNConv::Param&,
  167. const megdnn::param::ExecutionPolicy&, const OperatorNodeConfig&) {
  168. return nullptr;
  169. }
  170. };
  171. template <
  172. class Opr, class Maker0, class MegDNNConv,
  173. class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
  174. class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
  175. typename ConvParam = megdnn::param::Convolution>
  176. struct ConvLoadDumpImpl {
  177. static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
  178. auto&& opr = opr_.cast_final_safe<Opr>();
  179. ctx.write_param<ConvParam>(opr.param());
  180. ctx.write_param<megdnn::param::ExecutionPolicy>(
  181. opr.execution_policy_transient());
  182. }
  183. static VarNode* make(
  184. const cg::VarNodeArray& inputs, const ConvParam& param,
  185. const megdnn::param::ExecutionPolicy& execution_policy,
  186. const OperatorNodeConfig& config) {
  187. VarNode* ret =
  188. Maker0::template make<Opr>(inputs, param, execution_policy, config);
  189. if (!ret) {
  190. ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
  191. }
  192. if (!ret) {
  193. ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
  194. }
  195. mgb_assert(ret);
  196. return ret;
  197. }
  198. static cg::OperatorNodeBase* load(
  199. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  200. const OperatorNodeConfig& config) {
  201. auto param = ctx.read_param<ConvParam>();
  202. auto execution_policy = ctx.read_param<megdnn::param::ExecutionPolicy>();
  203. return make(inputs, param, execution_policy, config)->owner_opr();
  204. }
  205. };
  206. template <class Opr, class Maker0, typename PoolingParam = megdnn::param::Pooling>
  207. struct PoolingLoadDumpImpl {
  208. static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
  209. auto&& opr = opr_.cast_final_safe<Opr>();
  210. ctx.write_param<PoolingParam>(opr.param());
  211. }
  212. static VarNode* make(
  213. const cg::VarNodeArray& inputs, const PoolingParam& param,
  214. const megdnn::param::ExecutionPolicy& execution_policy,
  215. const OperatorNodeConfig& config) {
  216. VarNode* ret =
  217. Maker0::template make<Opr>(inputs, param, execution_policy, config);
  218. mgb_assert(ret);
  219. return ret;
  220. }
  221. static cg::OperatorNodeBase* load(
  222. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  223. const OperatorNodeConfig& config) {
  224. auto param = ctx.read_param<PoolingParam>();
  225. return make(inputs, param, {}, config)->owner_opr();
  226. }
  227. };
  228. template <class Opr, class Maker0, typename GeneralOprParam = megdnn::param::ROIAlign>
  229. struct GeneralOprLoadDumpImpl {
  230. static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
  231. auto&& opr = opr_.cast_final_safe<Opr>();
  232. ctx.write_param<GeneralOprParam>(opr.param());
  233. }
  234. static VarNode* make(
  235. const cg::VarNodeArray& inputs, const GeneralOprParam& param,
  236. const OperatorNodeConfig& config) {
  237. VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
  238. mgb_assert(ret);
  239. return ret;
  240. }
  241. static cg::OperatorNodeBase* load(
  242. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  243. const OperatorNodeConfig& config) {
  244. auto param = ctx.read_param<GeneralOprParam>();
  245. return make(inputs, param, config)->owner_opr();
  246. }
  247. };
  248. template <>
  249. struct OprMaker<opr::TQTBackward, 3> {
  250. using Param = opr::TQTBackward::Param;
  251. static cg::OperatorNodeBase* make(
  252. const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
  253. const OperatorNodeConfig& config) {
  254. MGB_MARK_USED_VAR(graph);
  255. return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0]
  256. .node()
  257. ->owner_opr();
  258. }
  259. };
  260. template <>
  261. struct OprMaker<opr::LSQBackward, 5> {
  262. using Param = opr::LSQBackward::Param;
  263. static cg::OperatorNodeBase* make(
  264. const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
  265. const OperatorNodeConfig& config) {
  266. MGB_MARK_USED_VAR(graph);
  267. return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param, config)[0]
  268. .node()
  269. ->owner_opr();
  270. }
  271. };
  272. template <>
  273. struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
  274. : public GeneralOprLoadDumpImpl<
  275. opr::AdaptivePoolingBackward,
  276. MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>,
  277. megdnn::param::AdaptivePooling> {};
  278. template <>
  279. struct OprLoadDumpImpl<opr::AdaptivePooling, 0>
  280. : public GeneralOprLoadDumpImpl<
  281. opr::AdaptivePooling, MakeROIAlignCaller1<megdnn::AdaptivePooling>,
  282. megdnn::param::AdaptivePooling> {};
  283. template <>
  284. struct OprLoadDumpImpl<opr::ROIAlign, 0>
  285. : public GeneralOprLoadDumpImpl<
  286. opr::ROIAlign, MakeROIAlignCaller1<megdnn::ROIAlign>,
  287. megdnn::param::ROIAlign> {};
  288. template <>
  289. struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>
  290. : public GeneralOprLoadDumpImpl<
  291. opr::ROIAlignBackward, MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
  292. megdnn::param::ROIAlign> {};
  293. template <>
  294. struct OprLoadDumpImpl<opr::Pooling, 0>
  295. : public PoolingLoadDumpImpl<
  296. opr::Pooling, MakePoolingCaller1<megdnn::Pooling>,
  297. megdnn::param::Pooling> {};
  298. template <>
  299. struct OprLoadDumpImpl<opr::PoolingBackward, 0>
  300. : public PoolingLoadDumpImpl<
  301. opr::PoolingBackward,
  302. MakePoolingBackwardCaller3<megdnn::PoolingBackward>,
  303. megdnn::param::Pooling> {};
  304. template <>
  305. struct OprLoadDumpImpl<opr::Convolution, 0>
  306. : public ConvLoadDumpImpl<
  307. opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
  308. megdnn::Convolution> {};
  309. template <>
  310. struct OprLoadDumpImpl<opr::ConvolutionBackwardData, 0>
  311. : public ConvLoadDumpImpl<
  312. opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
  313. megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
  314. template <>
  315. struct OprLoadDumpImpl<opr::ConvolutionBackwardFilter, 0>
  316. : public ConvLoadDumpImpl<
  317. opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
  318. megdnn::Convolution> {};
  319. template <>
  320. struct OprLoadDumpImpl<opr::Convolution3D, 0>
  321. : public ConvLoadDumpImpl<
  322. opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
  323. megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
  324. MakeConvCallerEmpty<megdnn::Convolution3D>,
  325. megdnn::param::Convolution3D> {};
  326. template <>
  327. struct OprLoadDumpImpl<opr::Convolution3DBackwardData, 0>
  328. : public ConvLoadDumpImpl<
  329. opr::Convolution3DBackwardData,
  330. MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
  331. MakeConvCaller3<megdnn::Convolution3D>,
  332. MakeConvCallerEmpty<megdnn::Convolution3D>,
  333. megdnn::param::Convolution3D> {};
  334. template <>
  335. struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
  336. : public ConvLoadDumpImpl<
  337. opr::Convolution3DBackwardFilter,
  338. MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
  339. MakeConvCallerEmpty<megdnn::Convolution3D>,
  340. MakeConvCallerEmpty<megdnn::Convolution3D>,
  341. megdnn::param::Convolution3D> {};
  342. template <>
  343. struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
  344. : public ConvLoadDumpImpl<
  345. opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
  346. megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
  347. MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
  348. template <>
  349. struct OprLoadDumpImpl<opr::BatchConvBiasForward, 0>
  350. : public ConvLoadDumpImpl<
  351. opr::BatchConvBiasForward,
  352. MakeConvCaller2<megdnn::BatchConvBiasForward>,
  353. megdnn::BatchConvBiasForward,
  354. MakeConvCaller3<megdnn::BatchConvBiasForward>,
  355. MakeConvCaller4<megdnn::BatchConvBiasForward>,
  356. megdnn::param::BatchConvBias> {};
  357. template <>
  358. struct OprMaker<opr::BatchNorm, 0> {
  359. using Param = opr::BatchNorm::Param;
  360. static cg::OperatorNodeBase* make(
  361. const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
  362. const OperatorNodeConfig& config) {
  363. MGB_MARK_USED_VAR(graph);
  364. if (i.size() == 3) {
  365. return opr::BatchNorm::make(i[0], i[1], i[2], param, config)[0]
  366. .node()
  367. ->owner_opr();
  368. } else {
  369. mgb_assert(i.size() == 5);
  370. return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4], param, config)[0]
  371. .node()
  372. ->owner_opr();
  373. }
  374. }
  375. };
  376. // OprMaker in MGB_SEREG_OPR only support unique output opr
  377. template <>
  378. struct OprMaker<opr::BatchNormBackward, 6> {
  379. using Param = opr::BatchNormBackward::Param;
  380. static cg::OperatorNodeBase* make(
  381. const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
  382. const OperatorNodeConfig& config) {
  383. MGB_MARK_USED_VAR(graph);
  384. return opr::BatchNormBackward::make(
  385. i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
  386. .node()
  387. ->owner_opr();
  388. }
  389. };
  390. template <class MegDNNConv = megdnn::LocalShare>
  391. struct MakeLocalShareCaller2 {
  392. template <typename Opr>
  393. static VarNode* make(
  394. const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
  395. const megdnn::param::ExecutionPolicy& execution_policy,
  396. const OperatorNodeConfig& config) {
  397. if (inputs.size() == 2) {
  398. return Opr::make(inputs[0], inputs[1], param, execution_policy, config)
  399. .node();
  400. }
  401. return nullptr;
  402. }
  403. };
  404. template <class MegDNNConv = megdnn::LocalShare>
  405. struct MakeLocalShareCaller3 {
  406. template <typename Opr>
  407. static VarNode* make(
  408. const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
  409. const megdnn::param::ExecutionPolicy& execution_policy,
  410. const OperatorNodeConfig& config) {
  411. if (inputs.size() == 3) {
  412. return Opr::make(
  413. inputs[0], inputs[1], inputs[2], param, execution_policy,
  414. config)
  415. .node();
  416. }
  417. return nullptr;
  418. }
  419. };
  420. template <class MegDNNConv = megdnn::LocalShare>
  421. struct MakeLocalShareCallerEmpty {
  422. template <typename Opr>
  423. static VarNode* make(
  424. const cg::VarNodeArray&, const typename MegDNNConv::Param&,
  425. const megdnn::param::ExecutionPolicy&, const OperatorNodeConfig&) {
  426. return nullptr;
  427. }
  428. };
  429. template <
  430. class Opr, class Maker0, class MegDNNConv,
  431. class Maker1 = MakeLocalShareCallerEmpty<MegDNNConv>,
  432. class Maker2 = MakeLocalShareCallerEmpty<MegDNNConv>,
  433. typename LocalShareParam = megdnn::param::LocalShare>
  434. struct LocalShareLoadDumpImpl {
  435. static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
  436. auto&& opr = opr_.cast_final_safe<Opr>();
  437. ctx.write_param<LocalShareParam>(opr.param());
  438. ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
  439. }
  440. static VarNode* make(
  441. const cg::VarNodeArray& inputs, const LocalShareParam& param,
  442. const megdnn::param::ExecutionPolicy& execution_policy,
  443. const OperatorNodeConfig& config) {
  444. VarNode* ret =
  445. Maker0::template make<Opr>(inputs, param, execution_policy, config);
  446. if (!ret) {
  447. ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
  448. }
  449. if (!ret) {
  450. ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
  451. }
  452. mgb_assert(ret);
  453. return ret;
  454. }
  455. static cg::OperatorNodeBase* load(
  456. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  457. const OperatorNodeConfig& config) {
  458. auto param = ctx.read_param<LocalShareParam>();
  459. auto execution_policy = ctx.read_param<megdnn::param::ExecutionPolicy>();
  460. return make(inputs, param, execution_policy, config)->owner_opr();
  461. }
  462. };
  463. template <>
  464. struct OprLoadDumpImpl<opr::LocalShare, 0>
  465. : public LocalShareLoadDumpImpl<
  466. opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
  467. megdnn::LocalShare> {};
  468. template <>
  469. struct OprLoadDumpImpl<opr::LocalShareBackwardData, 0>
  470. : public LocalShareLoadDumpImpl<
  471. opr::LocalShareBackwardData,
  472. MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare> {};
  473. template <>
  474. struct OprLoadDumpImpl<opr::LocalShareBackwardFilter, 0>
  475. : public LocalShareLoadDumpImpl<
  476. opr::LocalShareBackwardFilter,
  477. MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare> {};
  478. template <>
  479. struct OprLoadDumpImpl<opr::DeformableConvForward, 0>
  480. : public ConvLoadDumpImpl<
  481. opr::DeformableConvForward,
  482. MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
  483. };
  484. template <>
  485. struct OprLoadDumpImpl<opr::DeformableConvBackwardData, 0>
  486. : public ConvLoadDumpImpl<
  487. opr::DeformableConvBackwardData,
  488. MakeConvCaller5<megdnn::DeformableConvBackwardData>,
  489. megdnn::Convolution> {};
  490. template <>
  491. struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
  492. : public ConvLoadDumpImpl<
  493. opr::DeformableConvBackwardFilter,
  494. MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
  495. megdnn::Convolution> {};
  496. template <typename Opr>
  497. cg::OperatorNodeBase* opr_shallow_copy_conv(
  498. const serialization::OprShallowCopyContext& ctx,
  499. const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
  500. const OperatorNodeConfig& config) {
  501. MGB_MARK_USED_VAR(ctx);
  502. auto&& opr = opr_.cast_final_safe<Opr>();
  503. return OprLoadDumpImpl<Opr, 0>::make(
  504. inputs, opr.param(), opr.execution_policy_transient(), config)
  505. ->owner_opr();
  506. }
  507. } // namespace serialization
  508. namespace opr {
  509. using ConvolutionV2 = Convolution;
  510. using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
  511. using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
  512. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionV2, 0, opr_shallow_copy_conv);
  513. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionBackwardDataV2, 0, opr_shallow_copy_conv);
  514. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
  515. ConvolutionBackwardFilterV2, 0, opr_shallow_copy_conv);
  516. MGB_SEREG_OPR(Images2Neibs, 1);
  517. MGB_SEREG_OPR(Images2NeibsBackward, 2);
  518. MGB_SEREG_OPR(SlidingWindowTranspose, 1);
  519. MGB_SEREG_OPR(SlidingWindowTransposeBackward, 2);
  520. using LocalV2 = Local;
  521. using LocalBackwardDataV2 = LocalBackwardData;
  522. using LocalBackwardFilterV2 = LocalBackwardFilter;
  523. MGB_SEREG_OPR(LocalV2, 2);
  524. MGB_SEREG_OPR(LocalBackwardDataV2, 3);
  525. MGB_SEREG_OPR(LocalBackwardFilterV2, 3);
  526. using GroupLocalV2 = GroupLocal;
  527. using GroupLocalBackwardDataV2 = GroupLocalBackwardData;
  528. using GroupLocalBackwardFilterV2 = GroupLocalBackwardFilter;
  529. MGB_SEREG_OPR(GroupLocalV2, 2);
  530. MGB_SEREG_OPR(GroupLocalBackwardDataV2, 3);
  531. MGB_SEREG_OPR(GroupLocalBackwardFilterV2, 3);
  532. MGB_SEREG_OPR(LRN, 1);
  533. MGB_SEREG_OPR(LRNBackward, 3);
  534. using PoolingV1 = Pooling;
  535. using PoolingBackwardV1 = PoolingBackward;
  536. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingV1, 0, opr_shallow_copy_conv);
  537. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingBackwardV1, 0, opr_shallow_copy_conv);
  538. using AdaptivePoolingV1 = AdaptivePooling;
  539. using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
  540. MGB_SEREG_OPR(AdaptivePoolingV1, 2);
  541. MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);
  542. MGB_SEREG_OPR(ROIPooling, 3);
  543. MGB_SEREG_OPR(ROIPoolingBackward, 4);
  544. using MaskConvolutionV2 = MaskConvolution;
  545. MGB_SEREG_OPR(MaskConvolutionV2, 3);
  546. MGB_SEREG_OPR(MaskPropagate, 1);
  547. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3D, 0, opr_shallow_copy_conv);
  548. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3DBackwardData, 0, opr_shallow_copy_conv);
  549. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
  550. Convolution3DBackwardFilter, 0, opr_shallow_copy_conv);
  551. using ConvBiasForwardV4 = ConvBiasForward;
  552. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvBiasForwardV4, 0, opr_shallow_copy_conv);
  553. using BatchNormV1 = BatchNorm;
  554. using BatchNormBackwardV1 = BatchNormBackward;
  555. MGB_SEREG_OPR(BatchNormV1, 0);
  556. MGB_SEREG_OPR(BatchNormBackwardV1, 6);
  557. using LocalShareForwardV1 = LocalShareForward;
  558. using LocalShareBackwardDataV1 = LocalShareBackwardData;
  559. using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
  560. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareForwardV1, 0, opr_shallow_copy_conv);
  561. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareBackwardDataV1, 0, opr_shallow_copy_conv);
  562. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
  563. LocalShareBackwardFilterV1, 0, opr_shallow_copy_conv);
  564. using ROIAlignV1 = ROIAlign;
  565. using ROIAlignBackwardV1 = ROIAlignBackward;
  566. MGB_SEREG_OPR(ROIAlignV1, 2);
  567. MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
  568. using DeformableConvForwardV1 = DeformableConvForward;
  569. using DeformableConvBackwardDataV1 = DeformableConvBackwardData;
  570. using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter;
  571. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(DeformableConvForwardV1, 0, opr_shallow_copy_conv);
  572. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
  573. DeformableConvBackwardDataV1, 0, opr_shallow_copy_conv);
  574. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
  575. DeformableConvBackwardFilterV1, 0, opr_shallow_copy_conv);
  576. MGB_SEREG_OPR(CorrelationForward, 2);
  577. MGB_SEREG_OPR(CorrelationBackwardData1, 3);
  578. MGB_SEREG_OPR(CorrelationBackwardData2, 3);
  579. MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
  580. MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);
  581. using BatchConvBiasForwardV1 = BatchConvBiasForward;
  582. MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchConvBiasForwardV1, 0, opr_shallow_copy_conv);
  583. MGB_SEREG_OPR(FakeQuant, 3);
  584. MGB_SEREG_OPR(FakeQuantBackward, 4);
  585. MGB_SEREG_OPR(TQT, 2);
  586. MGB_SEREG_OPR(TQTBackward, 3);
  587. MGB_SEREG_OPR(LSQ, 4);
  588. MGB_SEREG_OPR(LSQBackward, 5);
  589. } // namespace opr
  590. } // namespace mgb
  591. // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}