You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tensor_reformat.cpp 101 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230
  1. /**
  2. * \file src/gopt/impl/tensor_reformat.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/gopt/inference.h"
  13. #include "megbrain/gopt/gtrans.h"
  14. #include "megbrain/gopt/basic_arith.h"
  15. #include "megbrain/graph/event.h"
  16. #include "megbrain/opr/dnn/batch_norm.h"
  17. #include "megbrain/opr/dnn/local.h"
  18. #include "megbrain/utils/shared_set.h"
  19. #include "megbrain/serialization/opr_shallow_copy.h"
  20. #include "megbrain/opr/basic_arith.h"
  21. #include "megbrain/opr/dnn/convolution.h"
  22. #include "megbrain/opr/blas.h"
  23. #include "megbrain/opr/misc.h"
  24. #include "megbrain/opr/utility.h"
  25. #include "megbrain/opr/dnn/pooling.h"
  26. #include "megbrain/opr/tensor_manip.h"
  27. #include "megbrain/opr/imgproc.h"
  28. #include "megbrain/opr/nn_int.h"
  29. #include "megdnn/tensor_format.h"
  30. #if MGB_ENABLE_TENSOR_RT
  31. #include "megbrain/tensorrt/tensorrt_opr.h"
  32. #endif
  33. #include "megbrain/gopt/misc.h"
  34. using namespace mgb;
  35. using namespace gopt;
  36. /* ================ TensorReformatPass =============== */
  37. /*!
  38. * \brief relayout placeholder opr
  39. *
  40. * RelayoutPlaceholder oprs act as the placeholders of the ComputingGraph
  41. * during graph opt pass `TensorReformatPass`. These oprs are introduced
  42. * into a ComputingGraph for conveniently discovering further optimize
  43. * opportunities (such as fuse consecutive relayouts, translate into
  44. * optimized implementations). They are canonized to have a shape infer, so
* the output's shape can be correctly deduced during the opt pass.
  46. *
  47. * Note that the oprs in the ComputingGraph are only used as intermediate
  48. * representations before being translated to MegBrain oprs, so the
  49. * oprs should not get involved in any actual computing.
  50. */
MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
                     cg::SingleCNOperatorNodeBase) // {
public:
    //! relayout type of this opr
    enum class LayoutType {
        NCHW4_TO_NCHW32,  //!< from nchw4 layout to nchw32 layout
        NCHW32_TO_NCHW4,  //!< from nchw32 layout to nchw4 layout
        NCHW4_TO_CHWN4,   //!< from nchw4 layout to chwn4 layout
        CHWN4_TO_NCHW4,   //!< from chwn4 layout to nchw4 layout
        NCHW_TO_NCHW88,   //!< from nchw layout to nchw88 layout
        NCHW_TO_NCHW44,   //!< from nchw layout to nchw44 layout
        NCHW88_TO_NCHW,   //!< from nchw88 layout to nchw layout
        NCHW44_TO_NCHW,   //!< from nchw44 layout to nchw layout
        WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                      //!< layout
        WEIGHT_NCHW_TO_NCHW88_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw88 layout
        WEIGHT_NCHW_TO_NCHW88_CHAN,   //!< channel wise weight from nchw layout
                                      //!< to nchw88 layout
        //!< the weight layout of input is nchw output is nchw88, special for
        //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
        WEIGHT_HYBIRD_NCHW_NCHW88,
        WEIGHT_NCHW_TO_NCHW44_DENSE,  //!< weight from nchw layout to nchw44
                                      //!< layout
        WEIGHT_NCHW_TO_NCHW44_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw44 layout
        WEIGHT_NCHW_TO_NCHW44_CHAN,   //!< channel wise weight from nchw layout
                                      //!< to nchw44 layout
        //!< the weight layout of input is nchw output is nchw44, special for
        //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
        WEIGHT_HYBIRD_NCHW_NCHW44,
    };

    RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);

    /*!
     * \brief create a RelayoutPlaceholder opr and return its single output
     *
     * \param src_var the input var
     * \param layout_type tensor layout transform type of this relayout
     * placeholder as described in LayoutType
     */
    static SymbolVar make(VarNode* src_var, LayoutType layout_type);

    //! the relayout transform this placeholder stands for
    LayoutType layout_type() const { return m_layout_type; }

private:
    //! deduce the output shape from the input shape per m_layout_type
    void init_output_static_infer_desc() override;
    //! placeholders must never run; this override only throws
    void scn_do_execute() override;
    //! output lives on the same comp node as the input
    void init_output_comp_node() override;
    const LayoutType m_layout_type;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);

TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder(
        VarNode* src_var, LayoutType layout_type)
        : Super(src_var->owner_graph(), {}, "RelayoutPlaceholder", {src_var}),
          m_layout_type{layout_type} {
    add_input({src_var});
    // make the layout type part of the opr's equivalence key, so two
    // placeholders on the same input with different layout types are not
    // deduplicated into one node
    add_equivalence_component<ScalarHash<LayoutType>>(m_layout_type);
    // relayout never changes element dtype; propagate it from the input
    add_output(None)->dtype(src_var->dtype());
}
// A placeholder is an intermediate representation only (see the class doc):
// it must be translated away by translate_pass() before any execution, so
// reaching this method is an internal error.
void TensorReformatPass::RelayoutPlaceholder::scn_do_execute() {
    mgb_throw(InternalError, "RelayoutPlaceholder opr can not be executed");
}
// The placeholder performs no cross-device transfer: its output stays on the
// comp node of its (single) input.
void TensorReformatPass::RelayoutPlaceholder::init_output_comp_node() {
    output(0)->comp_node(input(0)->comp_node());
}
// Register a shape-infer rule that maps the input shape to the output shape
// implied by m_layout_type. This is what allows the placeholder to take part
// in static shape inference without ever being executed (see class doc).
void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    DepVal deps;
    for (auto i : input())
        deps.push_back({i, DepType::SHAPE});
    auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
        TensorShape inp_shape = inp.val[0].shape();
        dst = inp_shape;
        if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32) {
            // {n, c/4, h, w, 4} -> {n, c/32, h, w, 32}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] * 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) {
            // {n, c/32, h, w, 32} -> {n, c/4, h, w, 4}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4] / 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4) {
            // {n, c/4, h, w, 4} -> {c/4, h, w, n, 4}: pure dim permutation
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[1];
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[0];
            dst[4] = inp_shape[4];
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) {
            // {c/4, h, w, n, 4} -> {n, c/4, h, w, 4}: inverse permutation
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst[0] = inp_shape[3];
            dst[1] = inp_shape[0];
            dst[2] = inp_shape[1];
            dst[3] = inp_shape[2];
            dst[4] = inp_shape[4];
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
            // {n, c, h, w} -> {n, c/8, h, w, 8}; requires c % 8 == 0
            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) {
            // {n, c/8, h, w, 8} -> {n, c, h, w}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 8);
            dst.ndim = 4;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_DENSE) {
            // dense conv weight {oc, ic, fh, fw} -> {oc/8, ic/8, fh, fw, 8, 8}
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0 &&
                       inp_shape[1] % 8 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 8;
            dst[5] = 8;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_GROUP) {
            // group conv weight {g, ocpg, icpg, fh, fw} ->
            // {g, ocpg/8, icpg/8, fh, fw, 8, 8}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 8 == 0 &&
                       inp_shape[2] % 8 == 0);
            dst.ndim = 7;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 8;
            dst[2] = inp_shape[2] / 8;
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 8;
            dst[6] = 8;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW88_CHAN) {
            // channel-wise weight {g, 1, 1, fh, fw} -> {g/8, 1, 1, fh, fw, 8}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
                       inp_shape[2] == 1 && inp_shape[0] % 8 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[1];
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) {
            // hybrid first-conv weight {oc, ic, fh, fw} ->
            // {oc/8, fh, fw, ic, 8} (see LayoutType doc)
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0] / 8;
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[1];
            dst[4] = 8;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
            // {n, c, h, w} -> {n, c/4, h, w, 4}; requires c % 4 == 0
            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 4;
        } else if (layout_type() ==
                   RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
            // {n, c/4, h, w, 4} -> {n, c, h, w}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
            dst.ndim = 4;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] * 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_DENSE) {
            // dense conv weight {oc, ic, fh, fw} -> {oc/4, ic/4, fh, fw, 4, 4}
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
                       inp_shape[1] % 4 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 4;
            dst[1] = inp_shape[1] / 4;
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = 4;
            dst[5] = 4;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_GROUP) {
            // group conv weight {g, ocpg, icpg, fh, fw} ->
            // {g, ocpg/4, icpg/4, fh, fw, 4, 4}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
                       inp_shape[2] % 4 == 0);
            dst.ndim = 7;
            dst[0] = inp_shape[0];
            dst[1] = inp_shape[1] / 4;
            dst[2] = inp_shape[2] / 4;
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 4;
            dst[6] = 4;
        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                            WEIGHT_NCHW_TO_NCHW44_CHAN) {
            // channel-wise weight {g, 1, 1, fh, fw} -> {g/4, 1, 1, fh, fw, 4}
            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
                       inp_shape[2] == 1 && inp_shape[0] % 4 == 0);
            dst.ndim = 6;
            dst[0] = inp_shape[0] / 4;
            dst[1] = inp_shape[1];
            dst[2] = inp_shape[2];
            dst[3] = inp_shape[3];
            dst[4] = inp_shape[4];
            dst[5] = 4;
        } else {
            // last remaining case; assert rather than silently fall through
            mgb_assert(
                    layout_type() ==
                    RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44);
            // hybrid first-conv weight {oc, ic, fh, fw} ->
            // {oc/4, fh, fw, ic, 4} (see LayoutType doc)
            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0);
            dst.ndim = 5;
            dst[0] = inp_shape[0] / 4;
            dst[1] = inp_shape[2];
            dst[2] = inp_shape[3];
            dst[3] = inp_shape[1];
            dst[4] = 4;
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
}
  279. SymbolVar TensorReformatPass::RelayoutPlaceholder::make(
  280. VarNode* src_var, LayoutType layout_type) {
  281. return src_var->owner_graph()
  282. ->insert_opr(
  283. std::make_unique<RelayoutPlaceholder>(src_var, layout_type))
  284. ->output(0);
  285. }
// First stage of the pass: walk the graph and, for every opr type that has a
// registered replace function in m_opr_replace_func, build a replacement opr
// (usually surrounded by RelayoutPlaceholder oprs) and substitute its outputs
// for the originals. Oprs without a rule are copied through unchanged.
void TensorReformatPass::insert_pass(OptState& opt) const {
    opt.set_var_replace_check_flag(m_var_replace_check_flag);
    auto rewriter = opt.graph().make_rewriter();
    // reuse one vector across all oprs to avoid repeated allocation
    VarNodeArray new_inp_cache;
    auto on_opr = [this, &opt, &rewriter,
                   &new_inp_cache](OperatorNodeBase* opr) {
        auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
        if (it != m_opr_replace_func.end()) {
            // collect the (possibly already rewritten) inputs
            auto& new_inp = new_inp_cache;
            new_inp.clear();
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size(),
                       "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
                       "dst.size=%zu",
                       opr->cname(), opr->dyn_typeinfo()->name,
                       new_opr->cname(), new_opr->dyn_typeinfo()->name,
                       out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                // volatile outputs (e.g. workspace vars) are not replaced;
                // old and new opr must agree on which outputs are volatile
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(!out1[i]->contain_flag(
                            VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src)) {
                        // additional process on endpoint var node
                        dst = on_graph_endpoint_var(dst, src);
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}
// Second stage of the pass: replace every RelayoutPlaceholder left by
// insert_pass() with real, executable oprs. Each LayoutType maps to a rule
// that emits either a dedicated RelayoutFormat opr or an equivalent
// Reshape/Dimshuffle/Reshape sequence built from the input's symbolic shape
// (so it also works with dynamic shapes).
void TensorReformatPass::translate_pass(OptState& opt) const {
    ThinHashMap<RelayoutPlaceholder::LayoutType,
                thin_function<VarNode*(VarNode*)>>
            reformat;
    using LayoutType = RelayoutPlaceholder::LayoutType;
    // nchw4 <-> chwn4 have dedicated RelayoutFormat modes
    reformat[LayoutType::NCHW4_TO_CHWN4] = [](VarNode* inp) -> VarNode* {
        megdnn::param::RelayoutFormat param;
        param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
        auto reformat = opr::RelayoutFormat::make(inp, param);
        return reformat.node();
    };
    reformat[LayoutType::CHWN4_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
        megdnn::param::RelayoutFormat param;
        param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
        auto reformat = opr::RelayoutFormat::make(inp, param);
        return reformat.node();
    };
    // nchw4 -> nchw32: split the channel-group dim by 8 and fold the 8 into
    // the trailing pack dim (4 -> 32)
    reformat[LayoutType::NCHW4_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        // cv(v): scalar constant; sub(i): i-th element of the input shape
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // nchw32 -> nchw4: split the pack dim (32 -> 8x4) and fold the 8 back
    // into the channel-group dim
    reformat[LayoutType::NCHW32_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // nchw -> nchw88: split channel dim into (c/8, 8) and move the 8 to the
    // innermost position
    reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 8, sub(2), sub(3), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // nchw88 -> nchw: move the pack dim back next to the channel dim and
    // merge them
    reformat[LayoutType::NCHW88_TO_NCHW] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 8, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp0);
        return y1.node();
    };
    // dense conv weight: {oc, ic, fh, fw} -> {oc/8, ic/8, fh, fw, 8, 8}
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_DENSE] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(1) / 8, sub(2), sub(3), cv(8), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // group conv weight: {g, ocpg, icpg, fh, fw} ->
    // {g, ocpg/8, icpg/8, fh, fw, 8, 8}
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 8, cv(8), sub(2) / 8,
                                        cv(8), sub(3), sub(4)},
                                       0),
             tshp1 = opr::Concat::make({sub(0), sub(1) / 8, sub(2) / 8, sub(3),
                                        sub(4), cv(8), cv(8)},
                                       0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // channel-wise conv weight: {g, 1, 1, fh, fw} -> {g/8, 1, 1, fh, fw, 8}
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_CHAN] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(1), sub(2), sub(3), sub(4), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // hybrid first-conv weight: {oc, ic, fh, fw} -> {oc/8, fh, fw, ic, 8}
    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 8, cv(8), sub(1), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 8, sub(2), sub(3), sub(1), cv(8)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // nchw -> nchw44: same scheme as NCHW_TO_NCHW88 with pack size 4
    reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // nchw44 -> nchw: same scheme as NCHW88_TO_NCHW with pack size 4
    reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp0);
        return y1.node();
    };
    // dense conv weight, pack size 4: {oc, ic, fh, fw} ->
    // {oc/4, ic/4, fh, fw, 4, 4}
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // group conv weight, pack size 4: {g, ocpg, icpg, fh, fw} ->
    // {g, ocpg/4, icpg/4, fh, fw, 4, 4}
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
                                        cv(4), sub(3), sub(4)},
                                       0),
             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
                                        sub(4), cv(4), cv(4)},
                                       0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // channel-wise conv weight, pack size 4: {g, 1, 1, fh, fw} ->
    // {g/4, 1, 1, fh, fw, 4}
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // hybrid first-conv weight, pack size 4: {oc, ic, fh, fw} ->
    // {oc/4, fh, fw, ic, 4}
    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    // walk the graph: expand every placeholder via its rule, pass everything
    // else through
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
        if (opr->same_type<RelayoutPlaceholder>()) {
            auto ph = try_cast_as_op<RelayoutPlaceholder>(opr);
            auto new_inp = rewriter.get_var(opr->input(0));
            mgb_assert(reformat.count(ph->layout_type()),
                       "no replace rule can be found for layout_type(%u)",
                       static_cast<uint32_t>(ph->layout_type()));
            auto new_var = reformat[ph->layout_type()](new_inp);
            rewriter.replace_var(opr->output(0), new_var,
                                 mgb_cstr_log("replace relayout placeholder"));
            return;
        }
        rewriter.auto_replace_outputs(opr);
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}
void TensorReformatPass::apply(OptState& opt) const {
    // Two-phase pass: insert_pass() (defined earlier in this file) is
    // expected to insert RelayoutPlaceholder oprs at layout boundaries;
    // translate_pass() then lowers every placeholder into concrete
    // Reshape/Dimshuffle/Concat subgraphs via its reformat rule table.
    // Order matters: translation requires the placeholders to already exist.
    insert_pass(opt);
    translate_pass(opt);
}
  595. /* ================ EnableTensorCorePass =============== */
  596. VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var,
  597. VarNode* orig_var) const {
  598. if (!orig_var->shape().eq_shape(new_var->shape())) {
  599. return RelayoutPlaceholder::make(
  600. new_var,
  601. RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4)
  602. .node();
  603. }
  604. return new_var;
  605. }
std::unique_ptr<EnableTensorCorePass>
EnableTensorCorePass::make_tensorcore_converter() {
    // Builds the replace-rule table for this pass: quantized-int8 NCHW4
    // conv_bias oprs are moved to NCHW32 layout when their channel counts
    // permit, layout-agnostic (elemwise-like) oprs follow the majority layout
    // of their inputs, and oprs that only support NCHW4 get NCHW32 inputs
    // converted back via RelayoutPlaceholder.
    // replace rule for conv bias opr
    auto replace_conv_bias_opr = [](OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        using Param = megdnn::param::ConvBias;
        using Format = Param::Format;
        using Sparse = Param::Sparse;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
        // Only NCHW4 conv_bias with QuantizedS8 output is an NCHW32
        // candidate; anything else must be shallow-copied unchanged, and its
        // inputs are asserted not to have been relayouted by this pass.
        if (conv_bias.param().format != Format::NCHW4 ||
            conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
            size_t nr_inps = opr->input().size();
            bool shape_has_changed = false;
            for (size_t i = 0; i < nr_inps; ++i) {
                if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                    shape_has_changed = true;
                }
            }
            MGB_MARK_USED_VAR(shape_has_changed);
            mgb_assert(
                    !shape_has_changed,
                    "EnableTensorCorePass assumes that the shape of inputs of"
                    "ConvBias operators whose output dtype is not QuantizedS8 "
                    "can not be changed in this opt pass");
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()),
                   "EnableTensorCorePass assumes that filter tensor of "
                   "conv_bias operator can not be changed by other operators");
        VarNode* orig_filter = opr->input(1);
        // In this pass a shape change between opr->input(i) and new_inp[i]
        // is the marker for "var was converted NCHW4 -> NCHW32".
        auto is_nchw4 = [](TensorShape shape) -> bool {
            return shape.ndim == 5 && shape[4] == 4;
        };
        auto is_nchw32 = [](TensorShape shape) -> bool {
            return shape.ndim == 5 && shape[4] == 32;
        };
        bool can_replace_nchw32 = false;
        VarNode *src = nullptr, *weight = nullptr, *bias = nullptr,
                *z_inp = nullptr;
        // process src tensor
        if (is_nchw4(new_inp[0]->shape())) { // new input is NCHW4 layout
            // Derive group/icpg/ocpg from the *original* filter shape; the
            // new filter was asserted unchanged above. The trailing pack of 4
            // channels is folded back in (hence the "* 4" adjustments).
            size_t group = 1, icpg, ocpg;
            if (conv_bias.param().sparse == Sparse::DENSE) {
                icpg = orig_filter->shape()[1] * 4;
                ocpg = orig_filter->shape()[0];
            } else {
                mgb_assert(conv_bias.param().sparse == Sparse::GROUP);
                group = orig_filter->shape()[0];
                icpg = orig_filter->shape()[2];
                ocpg = orig_filter->shape()[1];
                if (icpg == 1 && ocpg == 1) { // channel wise conv
                    group *= 4;
                } else {
                    icpg *= 4;
                }
            }
            // nchw32 layout need that input width and height are larger than 3
            // NOTE(review): the comment above says "larger than 3" but the
            // check below also accepts ih/iw == 3 -- confirm which is meant.
            size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
            if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 &&
                iw >= 3) {
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[0],
                        RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
                src = symvar.node();
                can_replace_nchw32 = true;
            } else {
                src = new_inp[0];
            }
        } else { // new input is NCHW32 layout
            mgb_assert(is_nchw32(new_inp[0]->shape()));
            size_t group = 1, ocpg;
            if (conv_bias.param().sparse == Sparse::DENSE) {
                ocpg = orig_filter->shape()[0];
            } else {
                mgb_assert(conv_bias.param().sparse == Sparse::GROUP);
                size_t icpg = orig_filter->shape()[2];
                ocpg = orig_filter->shape()[1];
                if (icpg == 1 && ocpg == 1) {
                    group *= 4;
                } else {
                    // NOTE(review): icpg is a block-local here, so this
                    // multiply has no observable effect (unlike the NCHW4
                    // branch above, where icpg feeds the condition). Verify
                    // it is intentional that icpg is not checked below.
                    icpg *= 4;
                }
            }
            size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
            if (group == 1 && ocpg % 32 == 0 && ih >= 3 && iw >= 3) {
                // input already NCHW32 and this conv still qualifies: keep it
                can_replace_nchw32 = true;
                src = new_inp[0];
            } else {
                // conv cannot run in NCHW32; convert the input back to NCHW4
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[0],
                        RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                src = symvar.node();
            }
        }
        // process filter tensor
        if (can_replace_nchw32) {
            auto symvar = RelayoutPlaceholder::make(
                    new_inp[1],
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
            weight = symvar.node();
        } else {
            weight = new_inp[1];
        }
        // 2-input form (no bias): rebuild with NCHW32 param or shallow-copy
        if (new_inp.size() == 2) {
            if (can_replace_nchw32) {
                auto param = conv_bias.param();
                param.format = Format::NCHW32;
                auto new_opr = opr::ConvBiasForward::make(
                        src, weight, param, conv_bias.execution_policy(),
                        conv_bias.config());
                return new_opr.node()->owner_opr();
            } else {
                VarNodeArray inps{src, weight};
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                return new_opr;
            }
        }
        // Convert an extra input (bias / z) to whichever layout was decided
        // for this conv above, relayouting only when it does not match.
        auto process_inp = [&](VarNode* inp) -> VarNode* {
            if (can_replace_nchw32) {
                if (is_nchw4(inp->shape())) {
                    auto symvar = RelayoutPlaceholder::make(
                            inp,
                            RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
                    return symvar.node();
                } else {
                    mgb_assert(is_nchw32(inp->shape()));
                    return inp;
                }
            } else {
                if (is_nchw4(inp->shape())) {
                    return inp;
                } else {
                    mgb_assert(is_nchw32(inp->shape()));
                    auto symvar = RelayoutPlaceholder::make(
                            inp,
                            RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                    return symvar.node();
                }
            }
        };
        // process bias tensor
        bias = process_inp(new_inp[2]);
        if (new_inp.size() == 3) {
            if (can_replace_nchw32) {
                auto param = conv_bias.param();
                param.format = Format::NCHW32;
                auto new_opr = opr::ConvBiasForward::make(
                        src, weight, bias, param, conv_bias.execution_policy(),
                        conv_bias.config());
                return new_opr.node()->owner_opr();
            } else {
                VarNodeArray inps{src, weight, bias};
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                return new_opr;
            }
        }
        // process z_inp tensor (4-input form)
        z_inp = process_inp(new_inp[3]);
        if (can_replace_nchw32) {
            auto param = conv_bias.param();
            param.format = Format::NCHW32;
            auto new_opr = opr::ConvBiasForward::make(
                    src, weight, bias, z_inp, param,
                    conv_bias.execution_policy(), conv_bias.config());
            return new_opr.node()->owner_opr();
        }
        VarNodeArray inps{src, weight, bias, z_inp};
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    // replace rule for elemwise like opr
    // for oprs support NCHW4 and NCHW32 layout
    auto replace_elemwise_like_opr = [](OperatorNodeBase* opr,
                                        const VarNodeArray new_inp) {
        // NOTE(review): new_inp is taken by value here (a copy); sibling
        // rules take const& -- likely unintended, though behavior-neutral.
        mgb_assert(opr->input().size() == new_inp.size());
        size_t nr_inps = new_inp.size();
        size_t nr_shape_changed = 0;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                nr_shape_changed++;
            }
        }
        if (nr_shape_changed) {
            auto inps = new_inp;
            // Majority vote: relayout the minority inputs so that every input
            // of the elemwise-like opr shares one layout.
            if (nr_shape_changed >=
                nr_inps / 2) { // NCHW32 > NCHW4 -> use NCHW32
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW4_TO_NCHW32);
                        inps[i] = symvar.node();
                    }
                }
            } else { // NCHW32 < NCHW4 -> use NCHW4
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW32_TO_NCHW4);
                        inps[i] = symvar.node();
                    }
                }
            }
            return serialization::copy_opr_shallow(*opr, inps, opr->config());
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // for oprs only supports NCHW4 layout
    auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr,
                                    const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray inps = new_inp;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                // A changed shape must be exactly NCHW4 -> NCHW32; undo it
                // for this NCHW4-only opr.
                mgb_assert(opr->input(i)->shape().ndim == 5 &&
                           opr->input(i)->shape()[4] == 4);
                mgb_assert(new_inp[i]->shape().ndim == 5 &&
                           new_inp[i]->shape()[4] == 32);
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[i],
                        RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
                inps[i] = symvar.node();
            }
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    // Oprs not in NCHW4 format must see exactly the shapes they had before
    // the pass; assert that and copy unchanged.
    auto replace_non_nchw4_opr = [](OperatorNodeBase* opr,
                                    const VarNodeArray new_inp) {
        size_t nr_inps = opr->input().size();
        bool shape_has_changed = false;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
                shape_has_changed = true;
            }
        }
        mgb_assert(!shape_has_changed,
                   "EnableTensorCorePass assumes that inputs' shape of "
                   "non-nchw4 operators "
                   "can not be changed in this opt "
                   "pass");
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // Format-aware rules: dispatch on the opr's own param().format -- NCHW4
    // oprs get NCHW4 inputs restored, everything else is checked untouched.
    auto replace_warp_affine_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpAffineForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_warp_perspective_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpPerspectiveForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp =
                        opr->cast_final_safe<opr::WarpPerspectiveForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
                                      OperatorNodeBase* opr,
                                      const VarNodeArray new_inp) {
        using Param = opr::ResizeForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& resize = opr->cast_final_safe<opr::ResizeForward>();
        if (resize.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        return replace_inps_to_nchw4(opr, new_inp);
    };
    auto replace_pooling_opr = [replace_non_nchw4_opr](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray new_inp) {
        using Param = opr::PoolingForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
        if (pooling.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        size_t nr_inps = opr->input().size();
        MGB_MARK_USED_VAR(nr_inps);
        mgb_assert(nr_inps == 1);
        // Pooling supports NCHW32 natively: when the input became NCHW32,
        // keep that layout and switch the pooling param instead of
        // relayouting the input back to NCHW4.
        if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) {
            mgb_assert(opr->input(0)->shape().ndim == 5 &&
                       opr->input(0)->shape()[4] == 4);
            mgb_assert(new_inp[0]->shape().ndim == 5 &&
                       new_inp[0]->shape()[4] == 32);
            auto new_param = pooling.param();
            new_param.format = Format::NCHW32;
            auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
                                                         opr->config());
            return new_pooling.node()->owner_opr();
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    auto ret = std::make_unique<EnableTensorCorePass>();
    // NOCHECK: replaced vars may legally change shape (NCHW4 <-> NCHW32).
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr;
    // elemwise like
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] =
            replace_elemwise_like_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr;
    // format aware
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    // to nchw4
    replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
    return ret;
}
  943. /* ================ EnableCHWN4Pass =============== */
  944. VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var,
  945. VarNode* /* orig_var */) const {
  946. if (m_varshape_changed.count(new_var)) {
  947. return RelayoutPlaceholder::make(
  948. new_var, RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4)
  949. .node();
  950. }
  951. return new_var;
  952. }
std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
    // Builds the replace-rule table that rewrites quantized-int8 NCHW4 graphs
    // into CHWN4 layout. NCHW4 and CHWN4 tensors have the same ndim, so a
    // shape comparison cannot detect conversion; instead every var converted
    // to CHWN4 is recorded in ret->m_varshape_changed and membership in that
    // set is the layout marker throughout the rules below.
    auto ret = std::make_unique<EnableCHWN4Pass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    auto&& replace_func = ret->m_opr_replace_func;
    // Reference to ret->m_varshape_changed; lives as long as ret, which also
    // owns the lambdas stored in replace_func, so the by-ref captures below
    // do not dangle.
    auto&& varshape_changed = ret->m_varshape_changed;
    // replace rule for conv bias opr
    auto replace_conv_bias_opr = [&varshape_changed](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        using Param = megdnn::param::ConvBias;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
        // Only NCHW4 conv_bias with QuantizedS8 output is a CHWN4 candidate;
        // any other conv_bias must see only unconverted inputs.
        if (conv_bias.param().format != Format::NCHW4 ||
            conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) {
            size_t nr_inps = new_inp.size();
            bool shape_has_changed = false;
            for (size_t i = 0; i < nr_inps; ++i) {
                if (varshape_changed.count(new_inp[i])) {
                    shape_has_changed = true;
                    break;
                }
            }
            mgb_assert(
                    !shape_has_changed,
                    "EnableCHWN4Pass assumes that the shape of inputs of"
                    "ConvBias operators whose output dtype is not QuantizedS8 "
                    "can not be changed in this opt pass");
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
        mgb_assert(varshape_changed.count(new_inp[1]) == 0,
                   "EnableCHWN4Pass assumes that filter tensor of "
                   "conv_bias operator can not be changed by other operators");
        VarNode *src = nullptr, *weight = nullptr, *bias = nullptr,
                *z_inp = nullptr;
        // process src tensor
        if (varshape_changed.count(new_inp[0]) ==
            0) { // new input is NCHW4 layout
            // currently not support group conv
            auto symvar = RelayoutPlaceholder::make(
                    new_inp[0],
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
            src = symvar.node();
        } else { // new input is already CHWN4 layout
            src = new_inp[0];
        }
        // process weight tensor (filter is always NCHW4 here, per the assert
        // above, so it is always converted)
        {
            auto symvar = RelayoutPlaceholder::make(
                    new_inp[1],
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
            weight = symvar.node();
        }
        // 2-input form (no bias): rebuild with CHWN4 param and record the
        // output var as layout-changed.
        if (new_inp.size() == 2) {
            auto param = conv_bias.param();
            param.format = Format::CHWN4;
            auto new_opr = opr::ConvBiasForward::make(
                    src, weight, param, conv_bias.execution_policy(),
                    conv_bias.config());
            varshape_changed.insert(new_opr.node());
            return new_opr.node()->owner_opr();
        }
        // Convert an extra input (bias / z) to CHWN4 unless it already is.
        auto process_inp = [&](VarNode* inp) -> VarNode* {
            if (varshape_changed.count(inp) == 0) {
                auto symvar = RelayoutPlaceholder::make(
                        inp, RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
                return symvar.node();
            } else {
                return inp;
            }
        };
        // process bias tensor
        bias = process_inp(new_inp[2]);
        if (new_inp.size() == 3) {
            auto param = conv_bias.param();
            param.format = Format::CHWN4;
            auto new_opr = opr::ConvBiasForward::make(
                    src, weight, bias, param, conv_bias.execution_policy(),
                    conv_bias.config());
            varshape_changed.insert(new_opr.node());
            return new_opr.node()->owner_opr();
        }
        // process z_inp tensor (4-input form)
        z_inp = process_inp(new_inp[3]);
        auto param = conv_bias.param();
        param.format = Format::CHWN4;
        auto new_opr = opr::ConvBiasForward::make(
                src, weight, bias, z_inp, param, conv_bias.execution_policy(),
                conv_bias.config());
        varshape_changed.insert(new_opr.node());
        return new_opr.node()->owner_opr();
    };
    // replace rule for elemwise like opr
    // for oprs support NCHW4 and CHWN4 layout
    auto replace_elemwise_like_opr = [&varshape_changed](
                                             OperatorNodeBase* opr,
                                             const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        size_t nr_inps = new_inp.size();
        size_t nr_shape_changed = 0;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (varshape_changed.count(new_inp[i])) {
                nr_shape_changed++;
            }
        }
        if (nr_shape_changed) {
            auto inps = new_inp;
            // Majority vote: relayout the minority inputs so all inputs of
            // the elemwise-like opr share one layout. Only the CHWN4 branch
            // records the output var, since only then the output is CHWN4.
            if (nr_shape_changed >= nr_inps / 2) { // CHWN4 > NCHW4 -> use CHWN4
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (varshape_changed.count(new_inp[i]) == 0) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    NCHW4_TO_CHWN4);
                        inps[i] = symvar.node();
                    }
                }
                auto new_opr = serialization::copy_opr_shallow(*opr, inps,
                                                               opr->config());
                varshape_changed.insert(new_opr->output(0));
                return new_opr;
            } else { // CHWN4 < NCHW4 -> use NCHW4
                for (size_t i = 0; i < nr_inps; ++i) {
                    if (varshape_changed.count(new_inp[i])) {
                        auto symvar = RelayoutPlaceholder::make(
                                new_inp[i], RelayoutPlaceholder::LayoutType::
                                                    CHWN4_TO_NCHW4);
                        inps[i] = symvar.node();
                    }
                }
                return serialization::copy_opr_shallow(*opr, inps,
                                                       opr->config());
            }
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // for oprs only supports NCHW4 layout
    auto replace_inps_to_nchw4 = [&varshape_changed](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray inps = new_inp;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (varshape_changed.count(new_inp[i])) {
                // input was converted to CHWN4 upstream; undo for this opr
                auto symvar = RelayoutPlaceholder::make(
                        new_inp[i],
                        RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4);
                inps[i] = symvar.node();
            }
        }
        auto new_opr =
                serialization::copy_opr_shallow(*opr, inps, opr->config());
        return new_opr;
    };
    // Oprs not in NCHW4 format must not receive any converted input.
    auto replace_non_nchw4_opr = [&varshape_changed](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray new_inp) {
        size_t nr_inps = opr->input().size();
        bool shape_has_changed = false;
        for (size_t i = 0; i < nr_inps; ++i) {
            if (varshape_changed.count(new_inp[i])) {
                shape_has_changed = true;
            }
        }
        mgb_assert(!shape_has_changed,
                   "EnableCHWN4Pass assumes that inputs' shape of "
                   "non-nchw4 operators "
                   "can not be changed in this opt "
                   "pass");
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    // capture by copy to avoid use after return
    auto replace_warp_affine_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpAffineForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_warp_perspective_opr =
            [replace_inps_to_nchw4, replace_non_nchw4_opr](
                    OperatorNodeBase* opr, const VarNodeArray new_inp) {
                using Param = opr::WarpPerspectiveForward::Param;
                using Format = Param::Format;
                mgb_assert(opr->input().size() == new_inp.size());
                auto& warp =
                        opr->cast_final_safe<opr::WarpPerspectiveForward>();
                if (warp.param().format != Format::NCHW4) {
                    return replace_non_nchw4_opr(opr, new_inp);
                }
                return replace_inps_to_nchw4(opr, new_inp);
            };
    auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
                                      OperatorNodeBase* opr,
                                      const VarNodeArray new_inp) {
        using Param = opr::ResizeForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& resize = opr->cast_final_safe<opr::ResizeForward>();
        if (resize.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        return replace_inps_to_nchw4(opr, new_inp);
    };
    auto replace_pooling_opr = [&varshape_changed, replace_non_nchw4_opr](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray new_inp) {
        using Param = opr::PoolingForward::Param;
        using Format = Param::Format;
        mgb_assert(opr->input().size() == new_inp.size());
        auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
        if (pooling.param().format != Format::NCHW4) {
            return replace_non_nchw4_opr(opr, new_inp);
        }
        size_t nr_inps = opr->input().size();
        MGB_MARK_USED_VAR(nr_inps);
        mgb_assert(nr_inps == 1);
        // Pooling supports CHWN4 natively: follow the input layout and mark
        // the output as converted instead of relayouting back.
        if (varshape_changed.count(new_inp[0])) {
            auto new_param = pooling.param();
            new_param.format = Format::CHWN4;
            auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
                                                         opr->config());
            varshape_changed.insert(new_pooling.node());
            return new_pooling.node()->owner_opr();
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr;
    // elemwise like
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] =
            replace_elemwise_like_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr;
    // format aware
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            replace_warp_perspective_opr;
    replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
    // to nchw4
    replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
    replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4;
    return ret;
}
  1207. /* ================ EnableNchwxxPass =============== */
  1208. VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
  1209. VarNode* orig_var) const {
  1210. if (!orig_var->shape().eq_shape(new_var->shape())) {
  1211. if (m_pack_c_size == 8) {
  1212. return RelayoutPlaceholder::make(
  1213. new_var,
  1214. RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
  1215. .node();
  1216. } else if (m_pack_c_size == 4) {
  1217. return RelayoutPlaceholder::make(
  1218. new_var,
  1219. RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
  1220. .node();
  1221. }
  1222. }
  1223. return new_var;
  1224. }
  1225. std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
  1226. size_t pack_c_size) {
  1227. auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
  1228. ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
  1229. //! First is whether the conv can trans to nchwxx, second is the filter
  1230. //! trans mode
  1231. using RelayoutMode = RelayoutPlaceholder::LayoutType;
  1232. using TestFilterResult = std::pair<TransType, RelayoutMode>;
  1233. RelayoutMode weight_to_nchwxx_mode_dense =
  1234. RelayoutMode::WEIGHT_NCHW_TO_NCHW88_DENSE;
  1235. RelayoutMode weight_to_nchwxx_mode_group =
  1236. RelayoutMode::WEIGHT_NCHW_TO_NCHW88_GROUP;
  1237. RelayoutMode weight_to_nchwxx_mode_chan =
  1238. RelayoutMode::WEIGHT_NCHW_TO_NCHW88_CHAN;
  1239. RelayoutMode hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW88;
  1240. RelayoutMode src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW88;
  1241. RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW88_TO_NCHW;
  1242. megdnn::param::ConvBias::Format conv_bias_format =
  1243. megdnn::param::ConvBias::Format::NCHW88;
  1244. megdnn::param::Convolution::Format conv_format =
  1245. megdnn::param::ConvolutionV0::Format::NCHW88;
  1246. megdnn::param::Pooling::Format pooling_format =
  1247. megdnn::param::Pooling::Format::NCHW88;
  1248. std::string convter_pass_name = "conv_format_nchw88";
  1249. if (pack_c_size == 4) {
  1250. weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
  1251. weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
  1252. weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
  1253. hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
  1254. src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
  1255. src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
  1256. conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
  1257. conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
  1258. pooling_format = megdnn::param::Pooling::Format::NCHW44;
  1259. convter_pass_name = "conv_format_nchw44";
  1260. }
  1261. auto test_trans_nchwxx =
  1262. [pack_c_size, weight_to_nchwxx_mode_dense,
  1263. weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan,
  1264. hybrid_nchw_nchwxx](
  1265. const megdnn::param::Convolution::Sparse conv_mode,
  1266. const VarNode* filter) -> TestFilterResult {
  1267. TestFilterResult ret{TransType::TRANS_NONE, {}};
  1268. if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
  1269. size_t IC = filter->shape()[1];
  1270. size_t OC = filter->shape()[0];
  1271. if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
  1272. ret.first = TransType::TRANS_PURE_NCHWXX;
  1273. ret.second = weight_to_nchwxx_mode_dense;
  1274. } else if (IC < pack_c_size && OC % pack_c_size == 0) {
  1275. ret.first = TransType::TRANS_HYBIRD_NCHWXX;
  1276. ret.second = hybrid_nchw_nchwxx;
  1277. }
  1278. } else {
  1279. mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
  1280. size_t group = filter->shape()[0];
  1281. size_t ocpg = filter->shape()[1];
  1282. size_t icpg = filter->shape()[2];
  1283. if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
  1284. ret.first = TransType::TRANS_PURE_NCHWXX;
  1285. ret.second = weight_to_nchwxx_mode_chan;
  1286. } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
  1287. ret.first = TransType::TRANS_PURE_NCHWXX;
  1288. ret.second = weight_to_nchwxx_mode_group;
  1289. }
  1290. }
  1291. return ret;
  1292. };
    //! Rewriter for ConvolutionForward during the NCHW -> NCHWxx conversion.
    //! test_trans_nchwxx inspects the filter (together with the sparse/group
    //! mode) and reports one of three outcomes:
    //!   * TRANS_NONE         -- conv must stay NCHW
    //!   * TRANS_PURE_NCHWXX  -- src, filter and dst all use the packed layout
    //!   * TRANS_HYBIRD_NCHWXX-- src stays NCHW, filter/dst use packed layout
    auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode,
                             src_to_nchw_mode](OperatorNodeBase* opr,
                                               const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        mgb_assert(conv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHWXX");
        auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]);
        //! can not trans to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, should RelayoutPlaceholder to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src =
                        RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode);
                temp_inp[0] = new_src.node();
            }
            //! keep the opr itself unchanged, only the inputs were fixed up
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            //! filter trans to nchwxx mode (is_trans.second carries the
            //! relayout mode chosen by test_trans_nchwxx for this filter)
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            //! src trans to nchwxx mode (it may already be 5-dim if produced
            //! by an earlier converted opr)
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                         src_to_nchwxx_mode);
                conv_src = new_src.node();
            }
            auto new_param = conv_opr.param();
            new_param.format = conv_format;
            //! packed src is 5-dim; packed (possibly grouped) filter is >= 6-dim
            mgb_assert(conv_src->shape().ndim == 5 &&
                               conv_filter->shape().ndim >= 6,
                       "The conv src dim is not trans to nchwxx");
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        } else {
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            //! hybrid mode: src stays 4-dim NCHW while the filter is
            //! relayouted to a 5-dim hybrid layout; the output comes out packed
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            mgb_assert(conv_src->shape().ndim == 4 &&
                               conv_filter->shape().ndim == 5,
                       "The src and filter is OK");
            auto new_param = conv_opr.param();
            new_param.format = conv_format;
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        }
    };
    //! Rewriter for ConvBiasForward; same three-way decision as
    //! replace_conv_opr, with the extra duty of keeping the bias input in a
    //! layout consistent with src/dst (the bias may also be a broadcasting
    //! scalar, which never needs relayout).
    auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format,
                                  src_to_nchwxx_mode, src_to_nchw_mode](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
        mgb_assert(conv_bias_opr.param().format ==
                           megdnn::param::ConvBias::Format::NCHW,
                   "ConvertFormat Pass only support converting NCHW to NCHWXX");
        auto is_trans =
                test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]);
        //! can not trans to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, should RelayoutPlaceholder to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src =
                        RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode);
                temp_inp[0] = new_src.node();
            }
            //! the bias is nchwxx: bring it back to nchw as well
            if (temp_inp[2]->shape().ndim == 5) {
                auto new_bias =
                        RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode);
                temp_inp[2] = new_bias.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            //! filter trans to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not NCHW mode");
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! src trans to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(new_inp[0],
                                                         src_to_nchwxx_mode);
                conv_bias_src = new_src.node();
            }
            //! bias trans to nchwxx mode, bias may be scale
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                          src_to_nchwxx_mode);
                conv_bias_bias = new_bias.node();
            }
            auto new_param = conv_bias_opr.param();
            new_param.format = conv_bias_format;
            //! packed src is 5-dim; packed (possibly grouped) filter is >= 6-dim
            mgb_assert(conv_bias_src->shape().ndim == 5 &&
                               conv_bias_filter->shape().ndim >= 6,
                       "The conv_bias src dim is not trans to nchwxx");
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst dim is not trans to nchwxx");
            return new_opr;
        } else {
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            //! hybrid mode: 4-dim NCHW src, 5-dim hybrid filter, packed output
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! bias trans to nchwxx mode, bias may be scale
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                          src_to_nchwxx_mode);
                conv_bias_bias = new_bias.node();
            }
            mgb_assert(conv_bias_src->shape().ndim == 4 &&
                       conv_bias_filter->shape().ndim == 5);
            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                       conv_bias_bias->shape().is_scalar());
            auto new_param = conv_bias_opr.param();
            new_param.format = conv_bias_format;
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv dst dim is not trans to nchwxx");
            return new_opr;
        }
    };
  1460. auto replace_pooling_opr = [=](OperatorNodeBase* opr,
  1461. const VarNodeArray& new_inp) {
  1462. mgb_assert(opr->input().size() == new_inp.size());
  1463. auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
  1464. mgb_assert(pooling_opr.param().format ==
  1465. megdnn::param::Pooling::Format::NCHW,
  1466. "ConvertFormat Pass only support converting NCHW to NCHWxx");
  1467. VarNode* inp = new_inp[0];
  1468. //! if input is nchwxx
  1469. if (inp->shape().ndim == 5) {
  1470. auto new_param = pooling_opr.param();
  1471. new_param.format = pooling_format;
  1472. auto new_pooling_opr =
  1473. opr::PoolingForward::make(inp, new_param, opr->config());
  1474. mgb_assert(new_pooling_opr.shape().ndim == 5,
  1475. "The pooling dst dim is not trans to nchwxx");
  1476. return new_pooling_opr.node()->owner_opr();
  1477. } else {
  1478. auto new_opr = serialization::copy_opr_shallow(*opr, new_inp,
  1479. opr->config());
  1480. return new_opr;
  1481. }
  1482. };
  1483. auto replace_concat_opr = [=](OperatorNodeBase* opr,
  1484. const VarNodeArray& new_inp) {
  1485. mgb_assert(opr->input().size() == new_inp.size());
  1486. bool has_inp_changed = false;
  1487. bool can_exec_ncwxx = true;
  1488. for (size_t i = 0; i < opr->input().size(); i++) {
  1489. if (new_inp[i]->shape().ndim == 5) {
  1490. has_inp_changed = true;
  1491. break;
  1492. } else if (new_inp[i]->shape().ndim == 4) {
  1493. if (new_inp[i]->shape()[1] % pack_c_size != 0) {
  1494. can_exec_ncwxx = false;
  1495. }
  1496. }
  1497. }
  1498. if (has_inp_changed) {
  1499. auto temp_inp = new_inp;
  1500. if (can_exec_ncwxx) {
  1501. for (size_t i = 0; i < opr->input().size(); i++) {
  1502. if (new_inp[i]->shape().ndim == 4) {
  1503. auto new_var = RelayoutPlaceholder::make(
  1504. new_inp[i], src_to_nchwxx_mode);
  1505. temp_inp[i] = new_var.node();
  1506. } else {
  1507. mgb_assert((new_inp[i]->shape().ndim == 5) ||
  1508. new_inp[i]->shape().is_scalar());
  1509. }
  1510. }
  1511. } else {
  1512. for (size_t i = 0; i < opr->input().size(); i++) {
  1513. if (new_inp[i]->shape().ndim == 5) {
  1514. auto new_var = RelayoutPlaceholder::make(
  1515. new_inp[i], src_to_nchw_mode);
  1516. temp_inp[i] = new_var.node();
  1517. }
  1518. }
  1519. }
  1520. return serialization::copy_opr_shallow(*opr, temp_inp,
  1521. opr->config());
  1522. } else {
  1523. return serialization::copy_opr_shallow(*opr, new_inp,
  1524. opr->config());
  1525. }
  1526. };
  1527. auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
  1528. const VarNodeArray& new_inp) {
  1529. mgb_assert(opr->input().size() == new_inp.size());
  1530. bool has_inp_changed = false;
  1531. for (size_t i = 0; i < opr->input().size(); i++) {
  1532. if (new_inp[i]->shape().ndim == 5) {
  1533. has_inp_changed = true;
  1534. break;
  1535. }
  1536. }
  1537. if (has_inp_changed) {
  1538. auto temp_inp = new_inp;
  1539. for (size_t i = 0; i < opr->input().size(); i++) {
  1540. if (new_inp[i]->shape().ndim == 4) {
  1541. auto new_var = RelayoutPlaceholder::make(
  1542. new_inp[i], src_to_nchwxx_mode);
  1543. temp_inp[i] = new_var.node();
  1544. } else {
  1545. mgb_assert((new_inp[i]->shape().ndim == 5) ||
  1546. new_inp[i]->shape().is_scalar());
  1547. }
  1548. }
  1549. return serialization::copy_opr_shallow(*opr, temp_inp,
  1550. opr->config());
  1551. } else {
  1552. return serialization::copy_opr_shallow(*opr, new_inp,
  1553. opr->config());
  1554. }
  1555. };
  1556. auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
  1557. const VarNodeArray& new_inp) {
  1558. mgb_assert(opr->input().size() == new_inp.size());
  1559. VarNodeArray temp_inp = new_inp;
  1560. for (size_t i = 0; i < opr->input().size(); i++) {
  1561. if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
  1562. mgb_assert(opr->input(i)->shape().ndim == 4);
  1563. mgb_assert(new_inp[i]->shape().ndim == 5);
  1564. auto new_var =
  1565. RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
  1566. temp_inp[i] = new_var.node();
  1567. }
  1568. }
  1569. return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
  1570. };
    ret->set_name(convter_pass_name);
    auto&& replace_func = ret->m_opr_replace_func;
    //! oprs with native nchwxx support: converted in place
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
    //! layout-agnostic oprs: follow whatever layout their inputs use
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
    replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
    //! not supported yet in nchwxx: force their inputs back to nchw
    replace_func[opr::ConvolutionBackwardData::typeinfo()] =
            relayout_inp_to_nchw;
    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            relayout_inp_to_nchw;
    replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
    return ret;
}
  1597. /* ==================== ShuffleShuffleRemovePass ================= */
class ShuffleShuffleRemovePass::Impl {
    using TensorFormat = opr::ConvBias::Param::Format;

    OptState& m_opt_state;
    //! maps an (input format, output format) pair to a builder that emits the
    //! concrete reshape/dimshuffle (or RelayoutFormat) subgraph realizing that
    //! conversion; used by do_replace() to materialize the AbstractShuffleOpr
    //! endpoints that survive chain folding
    ThinHashMap<std::pair<TensorFormat, TensorFormat>,
                thin_function<VarNode*(VarNode*)>>
            m_reformat;
    class AbstractShuffleOpr;

    //! collapse recognized reshape/dimshuffle/relayout patterns into
    //! AbstractShuffleOpr placeholders
    void detect_shuffle_operations();
    //! fold chains of AbstractShuffleOpr/TypeCvt and re-emit the minimal
    //! required reformat + dtype conversions
    void do_replace();

public:
    Impl(OptState& opt_state) : m_opt_state{opt_state} {
        //! NCHW -> NCHW4: (N, C, H, W) -> (N, C/4, H, W, 4)
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            //! cv: scalar constant; sub: i-th element of the input shape
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
            return y1.node();
        };
        //! NCHW -> NCHW32: (N, C, H, W) -> (N, C/32, H, W, 32)
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW32)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
            return y1.node();
        };
        //! NCHW4 -> NCHW: (N, C/4, H, W, 4) -> (N, C, H, W)
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        //! NCHW32 -> NCHW: (N, C/32, H, W, 32) -> (N, C, H, W)
        m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        //! NCHW4 -> NCHW32: (N, C/4, H, W, 4) -> (N, C/32, H, W, 32),
        //! regrouping 8 packs of 4 channels into one pack of 32
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW32)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp0 = opr::Concat::make(
                         {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)},
                         0),
                 tshp1 = opr::Concat::make(
                         {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
            auto y0 = opr::Reshape::make(x, tshp0);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
            auto y2 = opr::Reshape::make(y1, tshp1);
            return y2.node();
        };
        //! NCHW32 -> NCHW4: inverse of the conversion above
        m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp0 = opr::Concat::make(
                         {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8},
                         0),
                 tshp1 = opr::Concat::make(
                         {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
            auto y0 = opr::Reshape::make(x, tshp0);
            auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
            auto y2 = opr::Reshape::make(y1, tshp1);
            return y2.node();
        };
        //! NCHW4 -> CHWN4: delegated to the dedicated RelayoutFormat opr
        m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::CHWN4)] =
                [](VarNode* inp) -> VarNode* {
            megdnn::param::RelayoutFormat param;
            param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
            auto reformat = opr::RelayoutFormat::make(inp, param);
            return reformat.node();
        };
        //! CHWN4 -> NCHW4: delegated to the dedicated RelayoutFormat opr
        m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] =
                [](VarNode* inp) -> VarNode* {
            megdnn::param::RelayoutFormat param;
            param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
            auto reformat = opr::RelayoutFormat::make(inp, param);
            return reformat.node();
        };
        //! NCHW -> CHWN4: (N, C, H, W) -> (C/4, H, W, N, 4)
        m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::CHWN4)] =
                [](VarNode* inp) -> VarNode* {
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp = opr::Concat::make(
                    {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
            auto y0 = opr::Reshape::make(x, tshp);
            auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2});
            return y1.node();
        };
        //! CHWN4 -> NCHW: (C/4, H, W, N, 4) -> (N, C, H, W)
        m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] =
                [](VarNode* inp) -> VarNode* {
            mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
            auto x = SymbolVar(inp);
            auto xshp = opr::GetVarShape::make(x);
            auto cv = [&x](int v) { return x.make_scalar(v); };
            auto sub = [&xshp, &cv](int idx) {
                return opr::IndexAt::make(xshp, {{0, cv(idx)}});
            };
            auto tshp =
                    opr::Concat::make({sub(3), sub(0) * 4, sub(1), sub(2)}, 0);
            auto y0 = opr::Dimshuffle::make(x, {3, 0, 4, 1, 2});
            auto y1 = opr::Reshape::make(y0, tshp);
            return y1.node();
        };
        //! run the two phases of the pass
        detect_shuffle_operations();
        do_replace();
    }
};
  1752. /*!
  1753. * \brief abstract operator representation of shuffle operation
  1754. */
//! Placeholder opr that stands for "convert m_inp_format to m_out_format"
//! without committing to a concrete subgraph; it only infers shapes and must
//! be eliminated by do_replace() before execution.
MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
                     cg::SingleCNOperatorNodeBase) // {
public:
    AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
                       TensorFormat out_format);

    static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
                          TensorFormat out_format);

    //! layout the single input is expected in
    TensorFormat inp_format() const { return m_inp_format; }
    //! layout the single output is produced in
    TensorFormat out_format() const { return m_out_format; }

private:
    void init_output_static_infer_desc() override;
    void scn_do_execute() override;
    const TensorFormat m_inp_format;
    const TensorFormat m_out_format;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);

//! AbstractShuffleOpr exists only during optimization; do_replace() must have
//! substituted a concrete reformat subgraph before any execution, so reaching
//! this point is an internal error.
void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::scn_do_execute() {
    mgb_throw(InternalError, "AbstractShuffleOpr cannot be executed");
}
  1774. void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::
  1775. init_output_static_infer_desc() {
  1776. using namespace cg::static_infer;
  1777. auto&& mgr = owner_graph()->static_infer_manager();
  1778. DepVal deps;
  1779. for (auto i : input())
  1780. deps.push_back({i, DepType::SHAPE});
  1781. auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
  1782. TensorShape inp_shape = inp.val[0].shape();
  1783. if (m_inp_format == TensorFormat::NCHW4 &&
  1784. m_out_format == TensorFormat::NCHW32) {
  1785. mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
  1786. dst = inp_shape;
  1787. dst[0] = inp_shape[0];
  1788. dst[1] = inp_shape[1] / 8;
  1789. dst[2] = inp_shape[2];
  1790. dst[3] = inp_shape[3];
  1791. dst[4] = inp_shape[4] * 8;
  1792. } else if (m_inp_format == TensorFormat::NCHW32 &&
  1793. m_out_format == TensorFormat::NCHW4) {
  1794. mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
  1795. dst = inp_shape;
  1796. dst[0] = inp_shape[0];
  1797. dst[1] = inp_shape[1] * 8;
  1798. dst[2] = inp_shape[2];
  1799. dst[3] = inp_shape[3];
  1800. dst[4] = inp_shape[4] / 8;
  1801. } else if (m_inp_format == TensorFormat::NCHW &&
  1802. m_out_format == TensorFormat::NCHW4) {
  1803. mgb_assert(inp_shape.ndim == 4);
  1804. dst.ndim = 5;
  1805. dst[0] = inp_shape[0];
  1806. dst[1] = inp_shape[1] / 4;
  1807. dst[2] = inp_shape[2];
  1808. dst[3] = inp_shape[3];
  1809. dst[4] = 4;
  1810. } else if (m_inp_format == TensorFormat::NCHW4 &&
  1811. m_out_format == TensorFormat::NCHW) {
  1812. mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
  1813. dst.ndim = 4;
  1814. dst[0] = inp_shape[0];
  1815. dst[1] = inp_shape[1] * 4;
  1816. dst[2] = inp_shape[2];
  1817. dst[3] = inp_shape[3];
  1818. } else if (m_inp_format == TensorFormat::NCHW4 &&
  1819. m_out_format == TensorFormat::CHWN4) {
  1820. dst.ndim = 5;
  1821. dst[0] = inp_shape[1];
  1822. dst[1] = inp_shape[2];
  1823. dst[2] = inp_shape[3];
  1824. dst[3] = inp_shape[0];
  1825. dst[4] = inp_shape[4];
  1826. } else if (m_inp_format == TensorFormat::CHWN4 &&
  1827. m_out_format == TensorFormat::NCHW4) {
  1828. dst.ndim = 5;
  1829. dst[0] = inp_shape[3];
  1830. dst[1] = inp_shape[0];
  1831. dst[2] = inp_shape[1];
  1832. dst[3] = inp_shape[2];
  1833. dst[4] = inp_shape[4];
  1834. } else {
  1835. mgb_throw(InternalError,
  1836. "Unsupported input format and output format.");
  1837. }
  1838. return true;
  1839. };
  1840. mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
  1841. }
//! Construct the placeholder: single input, single output carrying the same
//! dtype. Both formats take part in the equivalence hash so that conversions
//! with different format pairs are never deduplicated into one node.
ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::AbstractShuffleOpr(
        VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format)
        : Super(inpvar->owner_graph(), {}, "AbstractShuffleOpr", {inpvar}),
          m_inp_format{inp_format},
          m_out_format{out_format} {
    add_input({inpvar});
    add_equivalence_component<ScalarHash<TensorFormat>>(m_inp_format);
    add_equivalence_component<ScalarHash<TensorFormat>>(m_out_format);
    add_output(None)->dtype(inpvar->dtype());
}
  1852. SymbolVar ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::make(
  1853. VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) {
  1854. return inpvar->owner_graph()
  1855. ->insert_opr(std::make_unique<AbstractShuffleOpr>(
  1856. inpvar, inp_format, out_format))
  1857. ->output(0);
  1858. }
  1859. void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
  1860. auto rewriter = m_opt_state.graph().make_rewriter();
  1861. auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()};
  1862. auto try_reshape_shuffle = [&rewriter,
  1863. &uniq_reader_check](OperatorNodeBase* opr) {
  1864. // check shuffle
  1865. auto shuffle = try_cast_as_op<opr::Dimshuffle>(opr);
  1866. if (shuffle == nullptr)
  1867. return false;
  1868. auto&& param = shuffle->param();
  1869. if (param.pattern_len != 5)
  1870. return false;
  1871. bool is_nchw2nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  1872. param.pattern[2] == 3 && param.pattern[3] == 4 &&
  1873. param.pattern[4] == 2 &&
  1874. opr->output(0)->shape()[4] == 4;
  1875. if (!is_nchw2nchw4)
  1876. return false;
  1877. if (!uniq_reader_check(shuffle->input(0)))
  1878. return false;
  1879. // check reshape
  1880. auto reshape = try_cast_as_op<opr::Reshape>(opr->input(0)->owner_opr());
  1881. if (reshape == nullptr)
  1882. return false;
  1883. auto inp_var = rewriter.get_var(reshape->input(0));
  1884. auto abstract_shuffle = AbstractShuffleOpr::make(
  1885. inp_var, TensorFormat::NCHW, TensorFormat::NCHW4);
  1886. rewriter.replace_var(
  1887. opr->output(0), abstract_shuffle.node(),
  1888. mgb_cstr_log("replace reformat(nchw -> nchw4) to "
  1889. "AbstractShuffleOpr(nchw -> nchw4)."));
  1890. return true;
  1891. };
  1892. auto try_reshape_shuffle_reshape = [&rewriter, &uniq_reader_check](
  1893. OperatorNodeBase* opr) {
  1894. // check reshape
  1895. auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
  1896. if (reshape1 == nullptr)
  1897. return false;
  1898. if (!uniq_reader_check(reshape1->input(0)))
  1899. return false;
  1900. // check shuffle
  1901. auto shuffle =
  1902. try_cast_as_op<opr::Dimshuffle>(opr->input(0)->owner_opr());
  1903. if (shuffle == nullptr)
  1904. return false;
  1905. auto&& param = shuffle->param();
  1906. if (param.pattern_len != 6)
  1907. return false;
  1908. bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  1909. param.pattern[2] == 3 && param.pattern[3] == 4 &&
  1910. param.pattern[4] == 2 && param.pattern[5] == 5 &&
  1911. shuffle->input(0)->shape()[5] == 4 &&
  1912. shuffle->input(0)->shape()[2] == 8;
  1913. bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  1914. param.pattern[2] == 4 && param.pattern[3] == 2 &&
  1915. param.pattern[4] == 3 && param.pattern[5] == 5 &&
  1916. shuffle->input(0)->shape()[4] == 8 &&
  1917. shuffle->input(0)->shape()[5] == 4;
  1918. if (!is_nchw42nchw32 && !is_nchw322nchw4)
  1919. return false;
  1920. if (!uniq_reader_check(shuffle->input(0)))
  1921. return false;
  1922. // check reshape
  1923. auto reshape2 =
  1924. try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
  1925. if (reshape2 == nullptr)
  1926. return false;
  1927. auto inp_var = rewriter.get_var(reshape2->input(0));
  1928. TensorFormat inp_format = is_nchw42nchw32 ? TensorFormat::NCHW4
  1929. : TensorFormat::NCHW32,
  1930. out_format = is_nchw42nchw32 ? TensorFormat::NCHW32
  1931. : TensorFormat::NCHW4;
  1932. auto abstract_shuffle =
  1933. AbstractShuffleOpr::make(inp_var, inp_format, out_format);
  1934. std::string reformat_type =
  1935. is_nchw42nchw32 ? "nchw4 -> nchw32" : "nchw32 -> nchw4";
  1936. rewriter.replace_var(opr->output(0), abstract_shuffle.node(),
  1937. mgb_cstr_log(ssprintf("replace reformat(%s) to "
  1938. "AbstractShuffleOpr(%s).",
  1939. reformat_type.c_str(),
  1940. reformat_type.c_str())
  1941. .c_str()));
  1942. return true;
  1943. };
  1944. auto try_shuffle_reshape = [&rewriter,
  1945. &uniq_reader_check](OperatorNodeBase* opr) {
  1946. // check reshape
  1947. auto reshape = try_cast_as_op<opr::Reshape>(opr);
  1948. if (reshape == nullptr)
  1949. return false;
  1950. if (!uniq_reader_check(reshape->input(0)))
  1951. return false;
  1952. // check shuffle
  1953. auto shuffle =
  1954. try_cast_as_op<opr::Dimshuffle>(opr->input(0)->owner_opr());
  1955. if (shuffle == nullptr)
  1956. return false;
  1957. auto&& param = shuffle->param();
  1958. if (param.pattern_len != 5)
  1959. return false;
  1960. bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
  1961. param.pattern[2] == 4 && param.pattern[3] == 2 &&
  1962. param.pattern[4] == 3 &&
  1963. shuffle->input(0)->shape()[4] == 4;
  1964. if (!is_nchw42nchw)
  1965. return false;
  1966. auto inp_var = rewriter.get_var(shuffle->input(0));
  1967. auto abstract_shuffle = AbstractShuffleOpr::make(
  1968. inp_var, TensorFormat::NCHW4, TensorFormat::NCHW);
  1969. rewriter.replace_var(
  1970. opr->output(0), abstract_shuffle.node(),
  1971. mgb_cstr_log("replace reformat(nchw4 -> nchw) to "
  1972. "AbstractShuffleOpr(nchw4 -> nchw)."));
  1973. return true;
  1974. };
  1975. auto try_relayout_format = [&rewriter](OperatorNodeBase* opr) {
  1976. // check relayout format
  1977. auto reformat = try_cast_as_op<opr::RelayoutFormat>(opr);
  1978. if (reformat == nullptr)
  1979. return false;
  1980. auto&& param = reformat->param();
  1981. if (param.mode != opr::RelayoutFormat::Param::Mode::CHWN4_NCHW4 &&
  1982. param.mode != opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4)
  1983. return false;
  1984. auto inp_var = rewriter.get_var(reformat->input(0));
  1985. cg::SymbolVar abstract_shuffle;
  1986. if (param.mode == opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) {
  1987. abstract_shuffle = AbstractShuffleOpr::make(
  1988. inp_var, TensorFormat::NCHW4, TensorFormat::CHWN4);
  1989. } else {
  1990. abstract_shuffle = AbstractShuffleOpr::make(
  1991. inp_var, TensorFormat::CHWN4, TensorFormat::NCHW4);
  1992. }
  1993. rewriter.replace_var(
  1994. opr->output(0), abstract_shuffle.node(),
  1995. mgb_cstr_log("replace reformat(nchw4 -> nchw) to "
  1996. "AbstractShuffleOpr(nchw4 -> nchw)."));
  1997. return true;
  1998. };
  1999. auto on_opr = [&try_reshape_shuffle, &try_shuffle_reshape,
  2000. &try_reshape_shuffle_reshape, &try_relayout_format,
  2001. &rewriter, &uniq_reader_check](OperatorNodeBase* opr) {
  2002. if (!try_reshape_shuffle_reshape(opr) && !try_reshape_shuffle(opr) &&
  2003. !try_shuffle_reshape(opr) && !try_relayout_format(opr)) {
  2004. auto new_opr = rewriter.auto_replace_outputs(opr);
  2005. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  2006. }
  2007. };
  2008. m_opt_state.graph().iter(on_opr);
  2009. rewriter.apply_inplace();
  2010. }
  2011. void ShuffleShuffleRemovePass::Impl::do_replace() {
  2012. auto rewriter = m_opt_state.graph().make_rewriter();
  2013. auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()};
  2014. ThinHashMap<VarNode*, VarNode*> var2endpoint;
  2015. ThinHashSet<VarNode*> trt_opr_inps;
  2016. SmallVector<OperatorNodeBase*> topo_order;
  2017. auto cb = [&topo_order, &trt_opr_inps](OperatorNodeBase* opr) {
  2018. topo_order.push_back(opr);
  2019. MGB_MARK_USED_VAR(trt_opr_inps);
  2020. #if MGB_ENABLE_TENSOR_RT
  2021. if (opr->same_type<opr::TensorRTOpr>()) {
  2022. for (auto&& inp : opr->input())
  2023. trt_opr_inps.insert(inp);
  2024. }
  2025. #endif
  2026. };
  2027. m_opt_state.graph().iter(cb);
  2028. for (auto&& opr : reverse_adaptor(topo_order)) {
  2029. if (opr->same_type<opr::TypeCvt>() ||
  2030. opr->same_type<AbstractShuffleOpr>()) {
  2031. auto find = var2endpoint.find(opr->output(0));
  2032. if (find != var2endpoint.end()) {
  2033. if (uniq_reader_check(opr->output(0))) {
  2034. var2endpoint[opr->input(0)] = find->second;
  2035. } else {
  2036. var2endpoint[opr->input(0)] = opr->output(0);
  2037. }
  2038. } else {
  2039. var2endpoint[opr->input(0)] = opr->output(0);
  2040. }
  2041. }
  2042. }
  2043. auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps,
  2044. &var2endpoint](OperatorNodeBase* opr) {
  2045. MGB_MARK_USED_VAR(trt_opr_inps);
  2046. bool cond_opr = opr->same_type<opr::TypeCvt>() ||
  2047. opr->same_type<AbstractShuffleOpr>();
  2048. if (cond_opr) {
  2049. bool cond_endpoint = var2endpoint[opr->input(0)] == opr->output(0);
  2050. if (!cond_endpoint)
  2051. return;
  2052. auto cur = opr;
  2053. auto var = opr->output(0), inp_var = opr->input(0);
  2054. bool force_folding_typecvt = false;
  2055. bool first_shuffle = false;
  2056. // initialize inp_format and out_format
  2057. TensorFormat out_format = TensorFormat::NCHW, inp_format = out_format;
  2058. megdnn::DType inp_dtype = cur->input(0)->dtype(),
  2059. out_dtype = cur->output(0)->dtype();
  2060. SmallVector<megdnn::DType> out_dtype_vec;
  2061. while (cond_opr) {
  2062. if (cur->same_type<AbstractShuffleOpr>()) {
  2063. auto shuffle = try_cast_as_op<AbstractShuffleOpr>(cur);
  2064. inp_format = shuffle->inp_format();
  2065. if (!first_shuffle) {
  2066. out_format = shuffle->out_format();
  2067. first_shuffle = true;
  2068. }
  2069. } else {
  2070. mgb_assert(cur->same_type<opr::TypeCvt>());
  2071. out_dtype_vec.push_back(cur->output(0)->dtype());
  2072. }
  2073. inp_var = cur->input(0);
  2074. bool cond_reader = uniq_reader_check(inp_var);
  2075. if (!cond_reader)
  2076. break;
  2077. cur = cur->input(0)->owner_opr();
  2078. cond_opr = cur->same_type<opr::TypeCvt>() ||
  2079. cur->same_type<AbstractShuffleOpr>();
  2080. }
  2081. std::reverse(out_dtype_vec.begin(), out_dtype_vec.end());
  2082. #if MGB_ENABLE_TENSOR_RT
  2083. force_folding_typecvt =
  2084. inp_var->owner_opr()->same_type<opr::TensorRTOpr>() ||
  2085. trt_opr_inps.count(var);
  2086. #endif
  2087. auto new_var = rewriter.get_var(inp_var);
  2088. if (inp_format != out_format) {
  2089. new_var = m_reformat[std::make_pair(inp_format, out_format)](
  2090. new_var);
  2091. }
  2092. if (force_folding_typecvt) {
  2093. inp_dtype = inp_var->dtype();
  2094. if (inp_dtype != out_dtype) {
  2095. auto type_cvt = opr::TypeCvt::make(new_var, out_dtype);
  2096. new_var = type_cvt.node();
  2097. }
  2098. } else {
  2099. if (out_dtype_vec.back() != var->dtype())
  2100. out_dtype_vec.push_back(var->dtype());
  2101. for (auto&& dtype : out_dtype_vec) {
  2102. auto type_cvt = opr::TypeCvt::make(new_var, dtype);
  2103. new_var = type_cvt.node();
  2104. }
  2105. }
  2106. rewriter.replace_var(
  2107. var, new_var,
  2108. mgb_cstr_log("replace Dimshuffle and TypeCvt chain"));
  2109. } else {
  2110. auto new_opr = rewriter.auto_replace_outputs(opr);
  2111. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  2112. }
  2113. };
  2114. m_opt_state.graph().iter(on_opr);
  2115. rewriter.apply_inplace();
  2116. }
  2117. const char* ShuffleShuffleRemovePass::name() const {
  2118. return mgb_cstr_log("shuffle shuffle remove pass");
  2119. }
  2120. void ShuffleShuffleRemovePass::apply(OptState& opt) const {
  2121. opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE |
  2122. VarReplaceCheckFlag::CHECK_DTYPE);
  2123. Impl{opt};
  2124. }
  2125. void gopt::reformat_to_chwn4_transform_dest_vars_inplace(
  2126. mgb::cg::VarNodeArray& dest_vars) {
  2127. gopt::GraphOptimizer optimizer;
  2128. optimizer.add_pass<FuseConvBiasNonlinPass>();
  2129. optimizer.add_pass<FuseConvBiasZPass>();
  2130. optimizer.add_pass(EnableCHWN4Pass::make_chwn4_converter());
  2131. optimizer.add_pass<ShuffleShuffleRemovePass>();
  2132. optimizer.add_pass<RemoveRedundantTypeCvtPass>();
  2133. optimizer.add_pass<ParamFusePass>();
  2134. optimizer.apply_inplace(dest_vars);
  2135. }
  2136. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台