
inference.cpp 93 kB

  1. /**
  2. * \file src/gopt/impl/inference.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/gopt/inference.h"
  12. #include "megbrain/gopt/basic_arith.h"
  13. #include "megbrain/gopt/gtrans.h"
  14. #include "megbrain/graph/event.h"
  15. #include "megbrain/opr/basic_arith.h"
  16. #include "megbrain/opr/blas.h"
  17. #include "megbrain/opr/dnn/batch_norm.h"
  18. #include "megbrain/opr/dnn/convolution.h"
  19. #include "megbrain/opr/dnn/images2neibs.h"
  20. #include "megbrain/opr/dnn/local.h"
  21. #include "megbrain/opr/dnn/pooling.h"
  22. #include "megbrain/opr/imgproc.h"
  23. #include "megbrain/opr/misc.h"
  24. #include "megbrain/opr/nn_int.h"
  25. #include "megbrain/opr/search_policy/algo_chooser_helper.h"
  26. #include "megbrain/opr/search_policy/profiler.h"
  27. #include "megbrain/opr/tensor_gen.h"
  28. #include "megbrain/opr/tensor_manip.h"
  29. #include "megbrain/opr/utility.h"
  30. #include "megbrain/serialization/opr_shallow_copy.h"
  31. #include "megbrain/utils/hash_ct.h"
  32. #include "megbrain/utils/shared_set.h"
  33. #include "megdnn/tensor_format.h"
  34. #if MGB_ENABLE_TENSOR_RT
  35. #include "megbrain/tensorrt/tensorrt_opr.h"
  36. #endif
  37. #if MGB_CUDA
  38. #include <cudnn.h>
  39. #endif
  40. #include "megbrain/gopt/misc.h"
  41. #include "megbrain/utils/hash_ct.h"
  42. #include "midout.h"
  43. MIDOUT_DECL(megbrain_inference)
  44. #define MIDOUT_B(tag) MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) {
  45. #define MIDOUT_E \
  46. } \
  47. MIDOUT_END();
  48. using namespace mgb;
  49. using namespace gopt;
  50. namespace {
  51. template <typename SharedDeviceTensor, typename MultipleDeviceTensorHolder>
  52. void param_merge(OptState& opt_state) {
  53. auto rewriter = opt_state.graph().make_rewriter();
  54. ThinHashMap<OperatorNodeBase*, size_t> opr2idx;
  55. std::vector<OperatorNodeBase*> all_oprs;
  56. typename MultipleDeviceTensorHolder::ValueArray all_values;
  57. auto cb_find_opr = [&](cg::OperatorNodeBase* opr) {
  58. if (opr->same_type<SharedDeviceTensor>()) {
  59. auto p = &opr->cast_final<SharedDeviceTensor>();
  60. // ShredD may be manu
  61. opr2idx[p] = all_values.size();
  62. all_values.push_back(p->dev_data());
  63. all_oprs.push_back(p);
  64. }
  65. };
  66. opt_state.graph().iter(cb_find_opr);
  67. SymbolVarArray new_vars;
  68. auto cb_replace = [&](cg::OperatorNodeBase* opr) {
  69. auto iter = opr2idx.find(opr);
  70. if (iter == opr2idx.end()) {
  71. rewriter.auto_replace_outputs(opr);
  72. } else {
  73. if (new_vars.empty()) {
  74. // new oprs must be created in iter callback; so we populate
  75. // new_vars lazily
  76. new_vars = MultipleDeviceTensorHolder::make(
  77. *opt_state.graph().comp_graph(), std::move(all_values),
  78. {ssprintf("merged%zu", all_values.size())});
  79. for (size_t i = 0; i < new_vars.size(); ++i) {
  80. auto src = all_oprs[i]->output(0);
  81. if (src->has_name_set()) {
  82. new_vars[i].rename(src->name());
  83. }
  84. }
  85. }
  86. rewriter.replace_var(
  87. opr->output(0), new_vars.at(iter->second).node(),
  88. mgb_cstr_log("replace multi SharedDeviceTensor(Format) to "
  89. "MultipleDeviceTensorHolder(Format)"));
  90. }
  91. };
  92. opt_state.graph().iter(cb_replace);
  93. rewriter.apply_inplace();
  94. }
  95. } // namespace
  96. /* ================ global functions ================ */
  97. SymbolVarArray gopt::optimize_for_inference(
  98. const SymbolVarArray& dest_vars, const OptimizeForInferenceOptions& opt) {
  99. return gopt::GraphOptimizer()
  100. .add_preset_passes(
  101. false, &opt, &dest_vars[0].node()->owner_graph()->options())
  102. .apply({dest_vars})
  103. .endpoint_vars();
  104. }
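  // Usage sketch (editorial example, not part of the original file): callers
  // typically build an OptimizeForInferenceOptions value, toggle the desired
  // transforms, and pass the graph endpoints. The setter name used below
  // (enable_fuse_conv_bias_nonlinearity) is assumed from the public gopt API
  // and may differ between MegEngine versions.
  //
  //     gopt::OptimizeForInferenceOptions opt;
  //     opt.enable_fuse_conv_bias_nonlinearity();
  //     SymbolVarArray optimized = gopt::optimize_for_inference({y}, opt);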
  105. SymbolVarArray gopt::layout_transform(
  106. const SymbolVarArray& dest_vars, GraphTuningOptions::Target target) {
  107. GraphTuningOptions options;
  108. options.target = target;
  109. options.enable_layout_transform();
  110. return gopt::GraphOptimizer{}
  111. .add_passes_for_graph_tuning_options(options)
  112. .apply({dest_vars})
  113. .endpoint_vars();
  114. }
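  // Usage sketch (editorial example): layout_transform() wraps the graph-tuning
  // passes for a single device target; Target::CUDA below is an assumed member
  // of GraphTuningOptions::Target.
  //
  //     SymbolVarArray tuned =
  //             gopt::layout_transform({y}, GraphTuningOptions::Target::CUDA);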
  115. namespace {
  116. void modify_conv_strategy(
  117. opr::mixin::AlgoChooserHelper& conv,
  118. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  119. auto policy = conv.execution_policy_transient();
  120. policy.strategy = strategy;
  121. conv.set_execution_policy(policy);
  122. }
  123. template <typename Opr>
  124. void inplace_conv_opr_modifier(
  125. OperatorNodeBase& opr,
  126. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  127. modify_conv_strategy(opr.cast_final_safe<Opr>(), strategy);
  128. }
  129. void modify_conv_policy_workspace_limit(
  130. opr::mixin::AlgoChooserHelper& conv, size_t workspace_limit) {
  131. auto policy = conv.execution_policy_transient();
  132. policy.workspace_limit = workspace_limit;
  133. conv.set_execution_policy(policy);
  134. }
  135. template <typename Opr>
  136. void inplace_conv_opr_workspace_limit_modifier(
  137. OperatorNodeBase& opr, size_t workspace_limit) {
  138. modify_conv_policy_workspace_limit(opr.cast_final_safe<Opr>(), workspace_limit);
  139. }
  140. } // anonymous namespace
  141. void gopt::modify_opr_algo_strategy_inplace(
  142. const VarNodeArrayView& dest_vars,
  143. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
  144. #if !MGB_ENABLE_FASTRUN
  145. using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
  146. if ((strategy & S::PROFILE) && !(strategy & S::HEURISTIC)) {
  147. mgb_throw(MegBrainError, "fastrun is disabled at compile time");
  148. }
  149. #endif
  150. const ThinHashMap<Typeinfo*, std::function<void(OperatorNodeBase&)>> modifiers = {
  151. #define CONV(t) \
  152. {opr::t::typeinfo(), \
  153. std::bind(inplace_conv_opr_modifier<opr::t>, std::placeholders::_1, strategy)},
  154. MGB_FOREACH_FASTRUN_OPR(CONV)
  155. #undef CONV
  156. };
  157. auto on_opr = [&](OperatorNodeBase* opr) {
  158. auto iter = modifiers.find(opr->dyn_typeinfo());
  159. if (iter != modifiers.end()) {
  160. iter->second(*opr);
  161. }
  162. };
  163. cg::DepOprIter dep_iter{on_opr};
  164. for (auto i : dest_vars) {
  165. dep_iter.add(i);
  166. }
  167. }
  168. void gopt::enable_opr_algo_profiling_inplace(const VarNodeArrayView& dest_vars) {
  169. modify_opr_algo_strategy_inplace(
  170. dest_vars,
  171. opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE);
  172. }
  173. void gopt::enable_opr_use_profiling_cache_inplace(const VarNodeArrayView& dest_vars) {
  174. using S = megdnn::param::ExecutionPolicy::Strategy;
  175. modify_opr_algo_strategy_inplace(dest_vars, S::PROFILE | S::HEURISTIC);
  176. }
  177. void gopt::set_opr_algo_workspace_limit_inplace(
  178. const VarNodeArrayView& dest_vars, size_t workspace_limit) {
  179. static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> modifiers =
  180. {
  181. #define CONV(t) \
  182. {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>},
  183. MGB_FOREACH_FASTRUN_OPR(CONV)
  184. #undef CONV
  185. };
  186. auto on_opr = [&](OperatorNodeBase* opr) {
  187. auto iter = modifiers.find(opr->dyn_typeinfo());
  188. if (iter != modifiers.end()) {
  189. iter->second(*opr, workspace_limit);
  190. }
  191. };
  192. cg::DepOprIter dep_iter{on_opr};
  193. for (auto i : dest_vars) {
  194. dep_iter.add(i);
  195. }
  196. }
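  // Usage sketch (editorial example): both helpers above modify the operators
  // reachable from the given endpoints in place, typically before compilation;
  // `y` is an assumed SymbolVar endpoint and the 256 MiB limit is an arbitrary
  // illustrative value.
  //
  //     gopt::enable_opr_algo_profiling_inplace({y.node()});
  //     gopt::set_opr_algo_workspace_limit_inplace({y.node()}, 256 * 1024 * 1024);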
  197. /* ================ ParamRedistributePass ================ */
  198. const char* ParamRedistributePass::name() const {
  199. return mgb_cstr_log("param_redistribute");
  200. }
  201. class ParamRedistributePass::Impl final : public RecursiveSubGraphRewriteHelper {
  202. ConstVarPropogate m_cvprop;
  203. UniqReaderCheck m_uniq_reader_check;
  204. //! oprs already processed in try_distribute_then_reassociate() should be
  205. //! skipped in on_new_opr_check_should_process()
  206. ThinHashSet<OperatorNodeBase*> m_opr_blacklist;
  207. std::string m_distribute_reasso_log_msg;
  208. //! try applying BinaryTrans20::associtive
  209. GTransResult try_reassociate(OperatorNodeBase* opr);
  210. //! try applying BinaryTrans20::distributive_add
  211. GTransResult try_distribute_add(OperatorNodeBase* opr);
  212. //! try distribute MUL/DIV over ADD/SUB and then apply
  213. GTransResult try_distribute_then_reassociate(OperatorNodeBase* opr);
  214. GTransResult process_opr(VarNode* out_var) override;
  215. bool on_new_opr_check_should_process(
  216. OperatorNodeBase* opr, OperatorNodeBase* repl_opr) override {
  217. m_uniq_reader_check.update_on_opr_auto_replace(opr, repl_opr);
  218. auto ins = m_cvprop.add_opr(opr);
  219. return ins.has_const_inp && !ins.all_const_inp && !m_opr_blacklist.count(opr);
  220. };
  221. void after_replace_var(VarNode* orig_var, VarNode* new_var) override {
  222. m_uniq_reader_check.update_on_opr_auto_replace(
  223. orig_var->owner_opr(), new_var->owner_opr());
  224. }
  225. /*!
  226. * \brief try to reorder opr inputs to a const one and a non-const one
  227. *
  228. * return true if it can be reformulated as f(nci, ci), where nci is
  229. * non-const and ci is const.
  230. */
  231. bool reorder_for_normconst(
  232. OperatorNodeBase* opr, bool& swap_inp, VarNode*& nci, VarNode*& ci);
  233. public:
  234. Impl(OptState& state);
  235. };
  236. GTransResult ParamRedistributePass::Impl::process_opr(VarNode* out_var) {
  237. auto opr = out_var->owner_opr();
  238. auto trans = try_reassociate(opr);
  239. if (!trans.valid()) {
  240. trans = try_distribute_add(opr);
  241. if (!trans.valid())
  242. trans = try_distribute_then_reassociate(opr);
  243. }
  244. return trans;
  245. }
  246. GTransResult ParamRedistributePass::Impl::try_reassociate(OperatorNodeBase* opr) {
  247. // apply BinaryAssociative0 if opr is of the form f(g(a, b), c) and b and c are
  248. // const
  249. bool swap_fop_inp = false, swap_gop_inp = false;
  250. VarNode *a, *b, *c, *ab;
  251. if (!reorder_for_normconst(opr, swap_fop_inp, ab, c))
  252. return None;
  253. if (!m_uniq_reader_check(ab))
  254. return None;
  255. if (!reorder_for_normconst(ab->owner_opr(), swap_gop_inp, a, b))
  256. return None;
  257. return BinaryTrans20::associtive().apply(opr, swap_fop_inp, swap_gop_inp);
  258. }
  259. GTransResult ParamRedistributePass::Impl::try_distribute_add(OperatorNodeBase* opr) {
  260. if (opr->same_type<opr::Elemwise>() || opr->input().size() != 2)
  261. return None;
  262. if (!m_cvprop.is_const(opr->input(1)))
  263. return None;
  264. auto ab = as_elem_opr(opr->input(0)->owner_opr(), opr::Elemwise::Mode::ADD);
  265. if (ab) {
  266. bool swap;
  267. VarNode *a, *b;
  268. if (reorder_for_normconst(ab, swap, a, b)) {
  269. return BinaryTrans20::distributive_add().apply(opr, false, swap);
  270. }
  271. }
  272. return None;
  273. }
  274. GTransResult ParamRedistributePass::Impl::try_distribute_then_reassociate(
  275. OperatorNodeBase* opr) {
  276. if (!opr->same_type<opr::Elemwise>())
  277. return None;
  278. using Mode = opr::Elemwise::Mode;
  279. auto mode = opr->cast_final<opr::Elemwise>().param().mode;
  280. if (!(mode == Mode::MUL || mode == Mode::TRUE_DIV))
  281. return None;
  282. VarNode *a, *b;
  283. bool swap;
  284. if (!reorder_for_normconst(opr, swap, a, b))
  285. return None;
  286. auto chain_pred = [this](OperatorNodeBase* opr) {
  287. if (as_elem_opr(opr, Mode::ADD)) {
  288. auto var = opr->output(0);
  289. return m_uniq_reader_check(var) || m_cvprop.is_const(var);
  290. }
  291. return false;
  292. };
  293. auto chain = extract_opr_leaves(a, chain_pred);
  294. if (chain.size() <= 1)
  295. return None;
  296. std::unordered_map<VarNode*, VarNode*> repl_map;
  297. m_distribute_reasso_log_msg.clear();
  298. int nr_fail = 0, nr_succ = 0;
  299. for (auto&& var : chain) {
  300. {
  301. auto iter = repl_map.find(var);
  302. if (iter != repl_map.end()) {
  303. var = iter->second;
  304. continue;
  305. }
  306. }
  307. auto vnew = (SymbolVar{var} * b).node();
  308. m_opr_blacklist.insert(vnew->owner_opr());
  309. if (!m_cvprop.is_const(var)) {
  310. auto trans = try_reassociate(vnew->owner_opr());
  311. if (!trans.valid()) {
  312. // allow at most one failed redistribution
  313. if (nr_fail)
  314. return None;
  315. ++nr_fail;
  316. } else {
  317. ++nr_succ;
  318. vnew = trans->result;
  319. if (!m_distribute_reasso_log_msg.empty()) {
  320. m_distribute_reasso_log_msg.append(mgb_cstr_log(";"));
  321. }
  322. m_distribute_reasso_log_msg.append(trans->msg);
  323. }
  324. }
  325. repl_map[var] = vnew;
  326. var = vnew;
  327. }
  328. if (nr_succ) {
  329. m_distribute_reasso_log_msg.insert(0, mgb_cstr_log("distribute_mul("));
  330. m_distribute_reasso_log_msg.append(mgb_cstr_log(")"));
  331. return GTransResultItem{
  332. elemwise_reduce_var_list(chain, Mode::ADD),
  333. m_distribute_reasso_log_msg.c_str(),
  334. {}};
  335. }
  336. return None;
  337. }
  338. bool ParamRedistributePass::Impl::reorder_for_normconst(
  339. OperatorNodeBase* opr, bool& swap_inp, VarNode*& nci, VarNode*& ci) {
  340. if (opr->input().size() != 2)
  341. return false;
  342. nci = opr->input(0);
  343. ci = opr->input(1);
  344. if (!m_cvprop.is_const(ci)) {
  345. if (!is_commutable_binary(opr) || !m_cvprop.is_const(nci))
  346. return false;
  347. swap_inp = true;
  348. std::swap(nci, ci);
  349. } else {
  350. if (m_cvprop.is_const(nci))
  351. return false;
  352. swap_inp = false;
  353. }
  354. return true;
  355. }
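  // Worked example (editorial note): for a commutable binary opr computing
  // c * x with const c as input(0), reorder_for_normconst() returns true with
  // swap_inp set and (nci, ci) rebound to (x, c); for x * c it returns true
  // with swap_inp left false. Oprs whose inputs are both const or both
  // non-const are rejected.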
  356. ParamRedistributePass::Impl::Impl(OptState& state)
  357. : RecursiveSubGraphRewriteHelper{state},
  358. m_cvprop{ConstVarType::IMMUTABLE_AND_PARAM},
  359. m_uniq_reader_check{state.graph()} {
  360. auto cg = state.graph().comp_graph();
  361. auto on_new_opr = [this](const cg::event::OprInserted& ev) {
  362. if (!ev.is_dedup && !ev.exc) {
  363. // call add_opr eagerly to avoid deep recursion
  364. m_cvprop.add_opr(ev.opr);
  365. }
  366. };
  367. auto hdl = cg->event().register_receiver<cg::event::OprInserted>(on_new_opr);
  368. apply();
  369. }
  370. void ParamRedistributePass::apply(OptState& state) const {
  371. MIDOUT_B("ParamRedistributePass::apply")
  372. Impl{state};
  373. MIDOUT_E
  374. }
  375. /* ================ ParamFusePass ================ */
  376. /*!
  377. * \brief get name for new param
  378. */
  379. class ParamFusePass::VarNamer {
  380. #if MGB_BUILD_SLIM_SERVING
  381. public:
  382. const std::string& name(VarNode*) {
  383. static std::string ret("fuse");
  384. return ret;
  385. }
  386. #else
  387. using SrcSet = SharedSet<OperatorNodeBase*>;
  388. //! map from var to source SharedDeviceTensor/MultiSharedDeviceHolder oprs
  389. //! that it depends on
  390. ThinHashMap<OperatorNodeBase*, SrcSet> m_opr2srcs;
  391. std::string m_name_cache;
  392. std::vector<const char*> m_cur_name;
  393. SrcSet& get_src_set(OperatorNodeBase* opr) {
  394. auto opr_typeinfo = opr->dyn_typeinfo();
  395. auto iter = m_opr2srcs.find(opr);
  396. if (iter != m_opr2srcs.end()) {
  397. return iter->second;
  398. }
  399. auto&& ret = m_opr2srcs[opr];
  400. if (opr->input().empty()) {
  401. if (opr_typeinfo == opr::SharedDeviceTensor::typeinfo() ||
  402. opr_typeinfo == opr::MultipleDeviceTensorHolder::typeinfo()) {
  403. ret.insert(opr);
  404. } else {
  405. mgb_assert(opr_typeinfo == opr::ImmutableTensor::typeinfo());
  406. }
  407. return ret;
  408. }
  409. for (auto i : opr->input()) {
  410. ret.merge_from(get_src_set(i->owner_opr()));
  411. }
  412. return ret;
  413. }
  414. public:
  415. const std::string& name(VarNode* var) {
  416. m_cur_name.clear();
  417. for (auto i : get_src_set(var->owner_opr())) {
  418. m_cur_name.push_back(i->cname());
  419. }
  420. auto cmp = [](const char* x, const char* y) { return strcmp(x, y) < 0; };
  421. std::sort(m_cur_name.begin(), m_cur_name.end(), cmp);
  422. m_name_cache.clear();
  423. m_name_cache.append(mgb_cstr_log("fuse("));
  424. bool first = true;
  425. for (auto i : m_cur_name) {
  426. if (first) {
  427. first = false;
  428. } else {
  429. m_name_cache.push_back(',');
  430. }
  431. m_name_cache.append(i);
  432. }
  433. m_name_cache.append(
  434. mgb_cstr_log(ssprintf("):%s@%zu", var->cname(), var->id())));
  435. return m_name_cache;
  436. }
  437. #endif
  438. };
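  // Example (editorial note): for a fused var whose value depends on params
  // named "w" and "b", VarNamer::name() produces a string of the form
  // "fuse(b,w):<var_name>@<var_id>", with the source names sorted alphabetically.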
  439. const char* ParamFusePass::name() const {
  440. return mgb_cstr_log("param_fuse");
  441. }
  442. void ParamFusePass::apply(OptState& state) const {
  443. MIDOUT_B("ParamFusePass::apply")
  444. auto rewriter = state.graph().make_rewriter();
  445. auto cg = state.graph().comp_graph();
  446. ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
  447. state.graph().iter([&cvprop](OperatorNodeBase* opr) { cvprop.add_opr(opr); });
  448. ThinHashSet<VarNode*> processed_var;
  449. VarNamer var_namer;
  450. // reader: null if used as endvar
  451. auto replace_single_var = [&](VarNode* var, OperatorNodeBase* reader) {
  452. if (!processed_var.insert(var).second)
  453. return;
  454. auto inferred_val =
  455. std::make_shared<DeviceTensorND>(var->comp_node(), var->dtype());
  456. auto cb = [&](DeviceTensorND& val) {
  457. // retain format of val
  458. mgb_assert(val.format() == var->format());
  459. inferred_val->format(val.format())
  460. .resize(val.shape())
  461. .copy_from_fixlayout(val);
  462. };
  463. {
  464. auto orig_level = cg->options().log_level;
  465. cg->options().log_level = 0;
  466. MGB_TRY { cg->compile({{var, cb}})->execute(); }
  467. MGB_FINALLY(cg->options().log_level = orig_level);
  468. }
  469. SymbolVar new_var;
  470. bool is_default_format = var->format().is_default();
  471. bool is_lowbit_aligned = var->format().is_lowbit_aligned();
  472. if (cg::is_static_var_value(var) && (is_default_format || is_lowbit_aligned)) {
  473. // use ImmutableTensor for inferable vars
  474. HostTensorND hv;
  475. hv.copy_from(*inferred_val).sync();
  476. new_var = opr::ImmutableTensor::make(
  477. *var->owner_graph(), hv, var_namer.name(var));
  478. } else {
  479. if (is_default_format || is_lowbit_aligned) {
  480. new_var = opr::SharedDeviceTensor::make_const(
  481. *var->owner_graph(), inferred_val, var_namer.name(var));
  482. } else {
  483. new_var = opr::SharedDeviceTensorWithFormat::make_const(
  484. *var->owner_graph(), inferred_val, var_namer.name(var));
  485. }
  486. }
  487. std::string log;
  488. if (reader) {
  489. log = mgb_ssprintf_log(
  490. "due to read by %s{%s}", reader->cname(),
  491. reader->dyn_typeinfo()->name);
  492. } else {
  493. log = mgb_cstr_log("as endpoint");
  494. }
  495. rewriter.replace_var(var, new_var.node(), log.c_str());
  496. };
  497. auto replace_opr = [&](OperatorNodeBase* opr) {
  498. auto add_ret = cvprop.opr_rst(opr);
  499. if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
  500. for (auto i : opr->input()) {
  501. if (cvprop.is_midconst(i)) {
  502. state.call_with_opr(
  503. i->owner_opr(), [&] { replace_single_var(i, opr); });
  504. }
  505. }
  506. }
  507. rewriter.auto_replace_outputs(opr);
  508. //! we should deal with midconst var after auto_replace_outputs, as
  509. //! on_midconst_opr will replace the endpoint output which may cause
  510. //! double replace.
  511. if (add_ret.all_const_inp) {
  512. for (auto var : opr->output()) {
  513. if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
  514. continue;
  515. auto osize = ConstVarPropogate::var_mem_size(var);
  516. if (osize >= cvprop.max_size(opr) &&
  517. osize - cvprop.max_size(opr) > m_param_grow_limit) {
  518. return;
  519. }
  520. // const oprs should be evaluated when output is used by another
  521. // non-const opr or output is needed by the user
  522. if (state.graph().endpoint_contain(var)) {
  523. replace_single_var(var, nullptr);
  524. }
  525. }
  526. }
  527. };
  528. state.graph().iter(replace_opr);
  529. rewriter.apply_inplace();
  530. MIDOUT_E
  531. }
  532. /* ================ One2OneOprReplacePass ================ */
  533. const char* ConvertF32ToF16Pass::name() const {
  534. return mgb_cstr_log("convert_f32_to_f16");
  535. }
  536. void ConvertF32ToF16Pass::apply(OptState& state) const {
  537. MIDOUT_B("ConvertF32ToF16Pass::apply")
  538. state.set_var_replace_check_flag(m_var_replace_check_flag);
  539. auto rewriter = state.graph().make_rewriter();
  540. VarNodeArray new_inp_cache;
  541. // record original output dtype
  542. const SymbolVarArray& vars = state.graph().endpoint_vars();
  543. std::vector<DType> dtypes;
  544. for (size_t i = 0; i < vars.size(); i++) {
  545. dtypes.push_back(vars[i].node()->dtype());
  546. }
  547. auto on_opr = [this, &rewriter, &new_inp_cache](OperatorNodeBase* opr) {
  548. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  549. if (it != m_opr_replace_func.end()) {
  550. auto&& new_inp = new_inp_cache;
  551. new_inp.clear();
  552. new_inp.reserve(opr->input().size());
  553. for (auto i : opr->input()) {
  554. new_inp.push_back(rewriter.get_var(i));
  555. }
  556. auto new_opr = (it->second)(opr, new_inp);
  557. auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
  558. mgb_assert(
  559. origin_out.size() == cur_out.size(),
  560. "bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(),
  561. opr->dyn_typeinfo()->name, new_opr->cname(),
  562. new_opr->dyn_typeinfo()->name);
  563. for (size_t i = 0; i < origin_out.size(); i++) {
  564. rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
  565. }
  566. } else {
  567. rewriter.auto_replace_outputs(opr);
  568. }
  569. };
  570. state.graph().iter(on_opr);
  571. rewriter.apply_inplace();
  572. // recover output dtype
  573. rewriter = state.graph().make_rewriter();
  574. const SymbolVarArray& endpoints = state.graph().endpoint_vars();
  575. auto replace_output = [&]() {
  576. for (size_t i = 0; i < endpoints.size(); i++) {
  577. VarNode* var = endpoints[i].node();
  578. if (var->dtype().enumv() != dtypes[i].enumv()) {
  579. auto new_var = opr::TypeCvt::make(var, dtypes[i]).node();
  580. rewriter.replace_var(var, new_var, nullptr);
  581. }
  582. }
  583. };
  584. mgb_assert(endpoints.size() > 0);
  585. auto opr = endpoints[0].node()->owner_opr();
  586. state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
  587. rewriter.apply_inplace();
  588. MIDOUT_E
  589. }
  590. std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp) {
  591. #if MEGDNN_DISABLE_FLOAT16
  592. mgb_throw(SystemError, "float16 disabled at compile time.");
  593. #else
  594. auto replace_h2d_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  595. mgb_assert(opr->input().size() == new_inp.size());
  596. auto& h2d_opr = opr->cast_final_safe<opr::Host2DeviceCopy>();
  597. if (h2d_opr.output(0)->dtype() == dtype::Float32()) {
  598. auto cvt_var = opr::TypeCvt::make(h2d_opr.output(0), dtype::Float16(), {});
  599. return cvt_var.node()->owner_opr();
  600. }
  601. return opr;
  602. };
  603. auto replace_sdt_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  604. mgb_assert(opr->input().size() == new_inp.size());
  605. auto& sdt_opr = opr->cast_final_safe<opr::SharedDeviceTensor>();
  606. if (sdt_opr.output(0)->dtype() == dtype::Float32()) {
  607. auto cvt_var = opr::TypeCvt::make(sdt_opr.output(0), dtype::Float16(), {});
  608. return cvt_var.node()->owner_opr();
  609. }
  610. return opr;
  611. };
  612. auto replace_imt_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  613. mgb_assert(opr->same_type<opr::ImmutableTensor>());
  614. mgb_assert(opr->input().size() == new_inp.size());
  615. auto& imt_opr = opr->cast_final_safe<opr::ImmutableTensor>();
  616. if (imt_opr.output(0)->dtype() == dtype::Float32()) {
  617. auto cvt_var = opr::TypeCvt::make(imt_opr.output(0), dtype::Float16(), {});
  618. return cvt_var.node()->owner_opr();
  619. }
  620. return opr;
  621. };
  622. auto replace_lsp_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  623. mgb_assert(opr->same_type<opr::Linspace>());
  624. mgb_assert(opr->input().size() == new_inp.size());
  625. auto& lsp_opr = opr->cast_final_safe<opr::Linspace>();
  626. if (lsp_opr.output(0)->dtype() != dtype::Float16()) {
  627. auto cvt_var = opr::TypeCvt::make(lsp_opr.output(0), dtype::Float16(), {});
  628. return cvt_var.node()->owner_opr();
  629. }
  630. return opr;
  631. };
  632. auto replace_conv_opr = [use_f32_comp](
  633. OperatorNodeBase* opr,
  634. const VarNodeArray& new_inp) {
  635. mgb_assert(opr->input().size() == new_inp.size());
  636. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  637. auto new_param = conv_opr.param();
  638. if (use_f32_comp) {
  639. new_param.compute_mode = megdnn::param::Convolution::ComputeMode::FLOAT32;
  640. }
  641. mgb_assert(
  642. new_inp[0]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  643. new_inp[0]->dtype().name(), new_inp[0]->name().c_str(),
  644. new_inp[0]->owner_opr()->name().c_str());
  645. mgb_assert(
  646. new_inp[1]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  647. new_inp[1]->dtype().name(), new_inp[1]->name().c_str(),
  648. new_inp[1]->owner_opr()->name().c_str());
  649. auto new_conv_opr = opr::Convolution::make(
  650. new_inp[0], new_inp[1], new_param, conv_opr.execution_policy(),
  651. conv_opr.config());
  652. return new_conv_opr.node()->owner_opr();
  653. };
  654. auto replace_deconv_opr = [use_f32_comp](
  655. OperatorNodeBase* opr,
  656. const VarNodeArray& new_inp) {
  657. mgb_assert(opr->input().size() == new_inp.size());
  658. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  659. auto new_param = deconv_opr.param();
  660. if (use_f32_comp) {
  661. new_param.compute_mode = megdnn::param::Convolution::ComputeMode::FLOAT32;
  662. }
  663. mgb_assert(
  664. new_inp[0]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  665. new_inp[0]->dtype().name(), new_inp[0]->name().c_str(),
  666. new_inp[0]->owner_opr()->name().c_str());
  667. mgb_assert(
  668. new_inp[1]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  669. new_inp[1]->dtype().name(), new_inp[1]->name().c_str(),
  670. new_inp[1]->owner_opr()->name().c_str());
  671. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  672. new_inp[0], new_inp[1], new_param, deconv_opr.execution_policy(),
  673. deconv_opr.config());
  674. return new_deconv_opr.node()->owner_opr();
  675. };
  676. auto replace_convbias_opr = [use_f32_comp](
  677. OperatorNodeBase* opr,
  678. const VarNodeArray& new_inp) {
  679. auto& convbias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  680. auto new_param = convbias_opr.param();
  681. if (use_f32_comp) {
  682. new_param.compute_mode = megdnn::param::ConvBias::ComputeMode::FLOAT32;
  683. }
  684. mgb_assert(
  685. new_inp[0]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  686. new_inp[0]->dtype().name(), new_inp[0]->name().c_str(),
  687. new_inp[0]->owner_opr()->name().c_str());
  688. mgb_assert(
  689. new_inp[1]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  690. new_inp[1]->dtype().name(), new_inp[1]->name().c_str(),
  691. new_inp[1]->owner_opr()->name().c_str());
  692. if (opr->input().size() == 2) {
  693. auto new_conv_opr = opr::ConvBias::make(
  694. new_inp[0], new_inp[1], new_param, convbias_opr.execution_policy(),
  695. convbias_opr.config());
  696. return new_conv_opr.node()->owner_opr();
  697. } else if (opr->input().size() == 3) {
  698. auto new_conv_opr = opr::ConvBias::make(
  699. new_inp[0], new_inp[1], new_inp[2], new_param,
  700. convbias_opr.execution_policy(), convbias_opr.config());
  701. return new_conv_opr.node()->owner_opr();
  702. } else {
  703. mgb_assert(
  704. opr->input().size() == 4, "invalid input size %zu",
  705. opr->input().size());
  706. auto new_conv_opr = opr::ConvBias::make(
  707. new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param,
  708. convbias_opr.execution_policy(), convbias_opr.config());
  709. return new_conv_opr.node()->owner_opr();
  710. }
  711. };
  712. auto replace_matmul_opr =
  713. [use_f32_comp](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  714. mgb_assert(opr->input().size() == new_inp.size());
  715. auto& matmul_opr = opr->cast_final_safe<opr::MatrixMul>();
  716. auto new_param = matmul_opr.param();
  717. if (use_f32_comp) {
  718. new_param.compute_mode =
  719. megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  720. }
  721. auto new_matmul_opr = opr::MatrixMul::make(
  722. new_inp[0], new_inp[1], new_param,
  723. matmul_opr.execution_policy(), matmul_opr.config());
  724. return new_matmul_opr.node()->owner_opr();
  725. };
  726. auto replace_batched_matmul_opr = [use_f32_comp](
  727. OperatorNodeBase* opr,
  728. const VarNodeArray& new_inp) {
  729. mgb_assert(opr->input().size() == new_inp.size());
  730. auto& matmul_opr = opr->cast_final_safe<opr::BatchedMatrixMul>();
  731. auto new_param = matmul_opr.param();
  732. if (use_f32_comp) {
  733. new_param.compute_mode = megdnn::param::MatrixMul::ComputeMode::FLOAT32;
  734. }
  735. mgb_assert(
  736. new_inp[0]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  737. new_inp[0]->dtype().name(), new_inp[0]->name().c_str(),
  738. new_inp[0]->owner_opr()->name().c_str());
  739. mgb_assert(
  740. new_inp[1]->dtype() == dtype::Float16(), "inp %s:%s, owner_opr:%s",
  741. new_inp[1]->dtype().name(), new_inp[1]->name().c_str(),
  742. new_inp[1]->owner_opr()->name().c_str());
  743. auto new_matmul_opr = opr::BatchedMatrixMul::make(
  744. new_inp[0], new_inp[1], new_param, matmul_opr.execution_policy(),
  745. matmul_opr.config());
  746. return new_matmul_opr.node()->owner_opr();
  747. };
  748. auto replace_reduce_opr = [use_f32_comp](
  749. OperatorNodeBase* opr,
  750. const VarNodeArray& new_inp) {
  751. auto& reduce_opr = opr->cast_final_safe<opr::Reduce>();
  752. auto new_param = reduce_opr.param();
  753. if (use_f32_comp) {
  754. new_param.data_type = megdnn::param::Reduce::DataType::FLOAT_O16xC32;
  755. }
  756. if (opr->input().size() == 1) {
  757. auto new_matmul_opr =
  758. opr::Reduce::make(new_inp[0], new_param, {}, reduce_opr.config());
  759. return new_matmul_opr.node()->owner_opr();
  760. } else {
  761. mgb_assert(
  762. opr->input().size() == 2, "invalid input size %zu",
  763. opr->input().size());
  764. auto new_matmul_opr = opr::Reduce::make(
  765. new_inp[0], new_param, new_inp[1], reduce_opr.config());
  766. return new_matmul_opr.node()->owner_opr();
  767. }
  768. };
  769. auto replace_cvt_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  770. auto& cvt_opr = opr->cast_final_safe<opr::TypeCvt>();
  771. SymbolVar new_cvt;
  772. if (cvt_opr.output(0)->dtype() == dtype::Float32()) {
  773. new_cvt =
  774. opr::TypeCvt::make(new_inp[0], dtype::Float16(), cvt_opr.config());
  775. } else {
  776. new_cvt = opr::TypeCvt::make(
  777. new_inp[0], cvt_opr.output()[0]->dtype(), cvt_opr.config());
  778. }
  779. return new_cvt.node()->owner_opr();
  780. };
  781. auto replace_warp_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  782. mgb_assert(
  783. opr->input().size() == new_inp.size() &&
  784. (new_inp.size() == 3 || new_inp.size() == 4));
  785. auto& warp_opr = opr->cast_final<opr::WarpPerspective>();
  786. // mat tensor must be float32
  787. auto new_mat = new_inp[1];
  788. if (new_inp[1]->dtype() != dtype::Float32()) {
  789. if (try_cast_as_op<opr::TypeCvt>(new_mat->owner_opr()) &&
  790. new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
  791. new_mat = new_mat->owner_opr()->input(0);
  792. else
  793. new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
  794. }
  795. SymbolVar new_warp;
  796. if (new_inp.size() == 3) {
  797. new_warp = opr::WarpPerspective::make(
  798. new_inp[0], new_mat, new_inp[2], warp_opr.param(),
  799. warp_opr.config());
  800. } else {
  801. mgb_assert(new_inp.size() == 4);
  802. new_warp = opr::WarpPerspective::make(
  803. new_inp[0], new_mat, new_inp[2], new_inp[3], warp_opr.param(),
  804. warp_opr.config());
  805. }
  806. return new_warp.node()->owner_opr();
  807. };
  808. auto replace_remap_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  809. mgb_assert(opr->input().size() == new_inp.size() && (new_inp.size() == 2));
  810. auto& remap_opr = opr->cast_final<opr::Remap>();
  811. // map tensor must be float32
  812. auto new_map = new_inp[1];
  813. if (new_inp[1]->dtype() != dtype::Float32()) {
  814. if (try_cast_as_op<opr::TypeCvt>(new_map->owner_opr()) &&
  815. new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
  816. new_map = new_map->owner_opr()->input(0);
  817. else
  818. new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
  819. }
  820. SymbolVar new_remap;
  821. new_remap = opr::Remap::make(
  822. new_inp[0], new_map, remap_opr.param(), remap_opr.config());
  823. return new_remap.node()->owner_opr();
  824. };
  825. auto ret = std::make_unique<ConvertF32ToF16Pass>();
  826. // don't check dtype
  827. ret->set_var_replace_check_flag(
  828. VarReplaceCheckFlag::CHECK_ALL ^ VarReplaceCheckFlag::CHECK_DTYPE);
  829. auto&& replace_func = ret->m_opr_replace_func;
  830. replace_func[opr::Linspace::typeinfo()] = replace_lsp_opr;
  831. replace_func[opr::Host2DeviceCopy::typeinfo()] = replace_h2d_opr;
  832. replace_func[opr::SharedDeviceTensor::typeinfo()] = replace_sdt_opr;
  833. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  834. replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
  835. replace_func[opr::ConvBias::typeinfo()] = replace_convbias_opr;
  836. replace_func[opr::MatrixMul::typeinfo()] = replace_matmul_opr;
  837. replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr;
  838. replace_func[opr::ImmutableTensor::typeinfo()] = replace_imt_opr;
  839. replace_func[opr::TypeCvt::typeinfo()] = replace_cvt_opr;
  840. replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
  841. replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
  842. replace_func[opr::BatchedMatrixMul::typeinfo()] = replace_batched_matmul_opr;
  843. return ret;
  844. #endif
  845. }
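  // Usage sketch (editorial example): the pass is normally registered on a
  // GraphOptimizer rather than applied directly; an add_pass overload taking a
  // std::unique_ptr<Pass> is assumed from the GraphOptimizer API.
  //
  //     gopt::GraphOptimizer optimizer;
  //     optimizer.add_pass(ConvertF32ToF16Pass::make(/*use_f32_comp=*/true));
  //     auto out = optimizer.apply({dest_vars}).endpoint_vars();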
  846. /* ================ ConvertFormatPass ================ */
  847. void ConvertFormatPass::apply(OptState& state) const {
  848. MIDOUT_B("ConvertFormatPass::apply")
  849. state.set_var_replace_check_flag(m_var_replace_check_flag);
  850. auto rewriter = state.graph().make_rewriter();
  851. VarNodeArray new_inp_cache;
  852. auto on_opr = [this, &state, &rewriter, &new_inp_cache](OperatorNodeBase* opr) {
  853. auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
  854. if (it != m_opr_replace_func.end()) {
  855. auto&& new_inp = new_inp_cache;
  856. new_inp.clear();
  857. new_inp.reserve(opr->input().size());
  858. for (auto i : opr->input()) {
  859. new_inp.push_back(rewriter.get_var(i));
  860. }
  861. auto new_opr = (it->second)(opr, new_inp);
  862. auto &&out0 = opr->output(), &&out1 = new_opr->output();
  863. mgb_assert(
  864. out0.size() == out1.size(),
  865. "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
  866. "dst.size=%zu",
  867. opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(),
  868. new_opr->dyn_typeinfo()->name, out0.size(), out1.size());
  869. for (size_t i = 0; i < out0.size(); i++) {
  870. if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  871. mgb_assert(!out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
  872. auto src = out0[i];
  873. auto dst = out1[i];
  874. auto dst_is_image =
  875. dst->format().type() == TensorFormat::Type::IMAGE2D_PACK4;
  876. if (!dst_is_image &&
  877. !src->owner_opr()->same_type<opr::ImmutableTensor>()) {
  878. mgb_log_warn(
  879. "convert NHWCD4 replaced to non-img format: "
  880. "dst_opr=%s{%s} format=%s",
  881. dst->owner_opr()->cname(),
  882. dst->owner_opr()->dyn_typeinfo()->name,
  883. dst->format().to_string().c_str());
  884. }
  885. if (state.graph().endpoint_contain(src) && dst_is_image) {
  886. // relayout back to NCHW for output vars
  887. dst = opr::RelayoutFormat::make(
  888. dst,
  889. {opr::RelayoutFormat::Param::Mode::NHWCD4I_NCHW})
  890. .node();
  891. }
  892. rewriter.replace_var(src, dst, nullptr);
  893. }
  894. }
  895. } else {
  896. rewriter.auto_replace_outputs(opr);
  897. }
  898. };
  899. state.graph().iter(on_opr);
  900. rewriter.apply_inplace();
  901. MIDOUT_E
  902. }
  903. std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
  904. MIDOUT_B("ConvertFormatPass::make")
  905. auto filter_mode =
  906. [](const megdnn::param::Convolution::Sparse conv_mode,
  907. const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
  908. bool use_dot = false;
  909. if (filter->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
  910. filter->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm)
  911. use_dot = true;
  912. if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
  913. if (use_dot)
  914. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI_DOT;
  915. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI;
  916. } else {
  917. mgb_throw_if(
  918. conv_mode != megdnn::param::Convolution::Sparse::GROUP,
  919. MegBrainError, "mode error");
  920. if (filter->shape()[1] == 1 && filter->shape()[2] == 1) {
  921. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI;
  922. } else {
  923. if (use_dot)
  924. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_GROUPI_DOT;
  925. return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_GROUPI;
  926. }
  927. }
  928. };
  929. auto size_one_conv_to_dense_conv = [](VarNode* origin_filter_input,
  930. megdnn::param::Convolution::Sparse sparse) {
  931. VarNode* reshaped_filter = origin_filter_input;
  932. bool is_size_one_group_conv = false;
  933. if (sparse == megdnn::param::Convolution::Sparse::GROUP &&
  934. origin_filter_input->shape()[0] == 1) {
  935. is_size_one_group_conv = true;
  936. auto new_shape = origin_filter_input->shape();
  937. new_shape.ndim = 4;
  938. for (int i = 0; i < 4; i++) {
  939. new_shape[i] = origin_filter_input->shape()[i + 1];
  940. }
  941. SymbolVar new_var(origin_filter_input);
  942. reshaped_filter = new_var.reshape(new_shape).node();
  943. }
  944. return std::make_tuple(reshaped_filter, is_size_one_group_conv);
  945. };
  946. auto replace_conv_opr = [&filter_mode, &size_one_conv_to_dense_conv](
  947. OperatorNodeBase* opr,
  948. const VarNodeArray& new_inp) {
  949. mgb_assert(opr->input().size() == new_inp.size());
  950. auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
  951. mgb_throw_if(
  952. conv_opr.param().format != megdnn::param::Convolution::Format::NCHW,
  953. MegBrainError,
  954. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  955. VarNode *conv_src = nullptr, *conv_weights = nullptr;
  956. if (new_inp[0]->shape().ndim == 4) {
  957. // new input src is NCHW
  958. size_t group, icpg, ocpg;
  959. if (conv_opr.param().sparse == megdnn::param::Convolution::Sparse::DENSE) {
  960. group = 1;
  961. icpg = new_inp[1]->shape()[1];
  962. ocpg = new_inp[1]->shape()[0];
  963. } else {
  964. mgb_throw_if(
  965. conv_opr.param().sparse !=
  966. megdnn::param::Convolution::Sparse::GROUP,
  967. MegBrainError, "ERROR mode");
  968. group = new_inp[1]->shape()[0];
  969. icpg = new_inp[1]->shape()[2];
  970. ocpg = new_inp[1]->shape()[1];
  971. }
  972. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  973. auto param = megdnn::param::RelayoutFormat();
  974. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  975. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  976. conv_src = rf.node();
  977. } else {
  978. // can not convert to hwcd4
  979. return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  980. }
  981. } else {
  982. size_t ocpg;
  983. bool is_channel_wise = false;
  984. if (conv_opr.param().sparse == megdnn::param::Convolution::Sparse::DENSE) {
  985. ocpg = new_inp[1]->shape()[0];
  986. } else {
  987. mgb_throw_if(
  988. conv_opr.param().sparse !=
  989. megdnn::param::Convolution::Sparse::GROUP,
  990. MegBrainError, "ERROR mode");
  991. size_t icpg = new_inp[1]->shape()[2];
  992. ocpg = new_inp[1]->shape()[1];
  993. if (icpg == 1 && ocpg == 1) {
  994. is_channel_wise = true;
  995. }
  996. }
  997. if (ocpg % 4 != 0 && !is_channel_wise) {
  998. VarNodeArray t_inp = new_inp;
  999. auto param = megdnn::param::RelayoutFormat();
  1000. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1001. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1002. t_inp[0] = rf.node();
  1003. auto new_opr =
  1004. serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1005. return new_opr;
  1006. }
  1007. // new input src is NHWCD4
  1008. auto&& fmt =
  1009. new_inp[0]->format().as_impl<megdnn::Image2DPack4TensorFormat>();
  1010. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1011. conv_src = new_inp[0];
  1012. }
  1013. VarNode* reshaped_filter;
  1014. bool is_size_one_group_conv;
  1015. std::tie(reshaped_filter, is_size_one_group_conv) =
  1016. size_one_conv_to_dense_conv(new_inp[1], conv_opr.param().sparse);
  1017. auto new_conv_param = conv_opr.param();
  1018. if (is_size_one_group_conv) {
  1019. new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
  1020. }
  1021. mgb_assert(new_inp[1]->format().type() != TensorFormat::Type::IMAGE2D_PACK4);
  1022. auto param = megdnn::param::RelayoutFormat();
  1023. param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
  1024. auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
  1025. conv_weights = relayout_weight.node();
  1026. new_conv_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1027. mgb_assert(
  1028. conv_src->shape().ndim == 5 &&
  1029. conv_src->format().type() == TensorFormat::Type::IMAGE2D_PACK4);
  1030. auto new_conv_opr = opr::Convolution::make(
  1031. conv_src, conv_weights, new_conv_param, conv_opr.execution_policy(),
  1032. conv_opr.config());
  1033. OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
  1034. mgb_assert(
  1035. new_conv_opr.shape().ndim == 5 &&
  1036. new_conv_opr.format().type() == TensorFormat::Type::IMAGE2D_PACK4);
  1037. return ret;
  1038. };
  1039. auto replace_conv_bias_opr = [&filter_mode, &size_one_conv_to_dense_conv](
  1040. OperatorNodeBase* opr,
  1041. const VarNodeArray& new_inp) {
  1042. mgb_assert(opr->input().size() == new_inp.size());
  1043. auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
  1044. mgb_throw_if(
  1045. conv_bias_opr.param().format != megdnn::param::ConvBias::Format::NCHW,
  1046. MegBrainError,
  1047. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1048. VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr,
  1049. *conv_bias_bias = nullptr;
  1050. if (new_inp[0]->shape().ndim == 4) {
  1051. // new input src is NCHW
  1052. size_t group, icpg, ocpg;
  1053. if (conv_bias_opr.param().sparse ==
  1054. megdnn::param::ConvBias::Sparse::DENSE) {
  1055. group = 1;
  1056. icpg = new_inp[1]->shape()[1];
  1057. ocpg = new_inp[1]->shape()[0];
  1058. } else {
  1059. mgb_throw_if(
  1060. conv_bias_opr.param().sparse !=
  1061. megdnn::param::ConvBias::Sparse::GROUP,
  1062. MegBrainError, "mode error");
  1063. group = new_inp[1]->shape()[0];
  1064. icpg = new_inp[1]->shape()[2];
  1065. ocpg = new_inp[1]->shape()[1];
  1066. }
  1067. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1068. auto param = megdnn::param::RelayoutFormat();
  1069. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1070. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1071. conv_bias_src = rf.node();
  1072. } else {
  1073. // can not convert to hwcd4
  1074. return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1075. }
  1076. } else {
  1077. size_t ocpg;
  1078. bool is_channel_wise = false;
  1079. if (conv_bias_opr.param().sparse ==
  1080. megdnn::param::ConvBias::Sparse::DENSE) {
  1081. ocpg = new_inp[1]->shape()[0];
  1082. } else {
  1083. mgb_throw_if(
  1084. conv_bias_opr.param().sparse !=
  1085. megdnn::param::ConvBias::Sparse::GROUP,
  1086. MegBrainError, "ERROR mode");
  1087. size_t icpg = new_inp[1]->shape()[2];
  1088. ocpg = new_inp[1]->shape()[1];
  1089. if (icpg == 1 && ocpg == 1) {
  1090. is_channel_wise = true;
  1091. }
  1092. }
  1093. if (ocpg % 4 != 0 && !is_channel_wise) {
  1094. VarNodeArray t_inp = new_inp;
  1095. auto param = megdnn::param::RelayoutFormat();
  1096. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1097. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1098. t_inp[0] = rf.node();
  1099. auto new_opr =
  1100. serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1101. return new_opr;
  1102. }
  1103. // new input src is NHWCD4
  1104. auto&& fmt =
  1105. new_inp[0]->format().as_impl<megdnn::Image2DPack4TensorFormat>();
  1106. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1107. conv_bias_src = new_inp[0];
  1108. }
  1109. mgb_assert(new_inp[1]->format().type() != TensorFormat::Type::IMAGE2D_PACK4);
  1110. VarNode* reshaped_filter;
  1111. bool is_size_one_group_conv;
  1112. std::tie(reshaped_filter, is_size_one_group_conv) =
  1113. size_one_conv_to_dense_conv(new_inp[1], conv_bias_opr.param().sparse);
  1114. auto new_conv_param = conv_bias_opr.param();
  1115. if (is_size_one_group_conv) {
  1116. new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
  1117. }
  1118. auto param = megdnn::param::RelayoutFormat();
  1119. param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
  1120. auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
  1121. conv_bias_weights = relayout_weight.node();
  1122. mgb_assert(new_inp.size() < 4, "ConvertFormat pass does not support fuse Z");
  1123. bool has_bias = new_inp.size() > 2;
  1124. if (has_bias && new_inp[2]->format().type() == TensorFormat::Type::DEFAULT) {
  1125. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1126. auto relayout_bias = opr::RelayoutFormat::make(new_inp[2], param);
  1127. conv_bias_bias = relayout_bias.node();
  1128. } else if (has_bias) {
  1129. conv_bias_bias = new_inp[2];
  1130. }
  1131. new_conv_param.format = megdnn::param::ConvBias::Format::NHWCD4;
  1132. mgb_assert(
  1133. conv_bias_src->shape().ndim == 5 &&
  1134. conv_bias_src->format().type() == TensorFormat::Type::IMAGE2D_PACK4);
  1135. SymbolVar new_conv_bias_opr;
  1136. if (has_bias) {
  1137. new_conv_bias_opr = opr::ConvBias::make(
  1138. conv_bias_src, conv_bias_weights, conv_bias_bias, new_conv_param,
  1139. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1140. } else {
  1141. new_conv_bias_opr = opr::ConvBias::make(
  1142. conv_bias_src, conv_bias_weights, new_conv_param,
  1143. conv_bias_opr.execution_policy(), conv_bias_opr.config());
  1144. }
  1145. OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
  1146. mgb_assert(
  1147. new_conv_bias_opr.shape().ndim == 5 &&
  1148. new_conv_bias_opr.format().type() == TensorFormat::Type::IMAGE2D_PACK4);
  1149. return ret;
  1150. };
  1151. auto replace_deconv_opr = [&filter_mode](
  1152. OperatorNodeBase* opr,
  1153. const VarNodeArray& new_inp) {
  1154. mgb_assert(opr->input().size() == new_inp.size());
  1155. auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
  1156. mgb_throw_if(
  1157. deconv_opr.param().format != megdnn::param::Convolution::Format::NCHW,
  1158. MegBrainError,
  1159. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1160. VarNode *deconv_src = nullptr, *deconv_weights = nullptr;
  1161. if (new_inp[1]->shape().ndim == 4) {
  1162. // new input src is NCHW
  1163. size_t group, icpg, ocpg;
  1164. if (deconv_opr.param().sparse ==
  1165. megdnn::param::Convolution::Sparse::DENSE) {
  1166. group = 1;
  1167. icpg = new_inp[0]->shape()[0];
  1168. ocpg = new_inp[0]->shape()[1];
  1169. } else {
  1170. mgb_throw_if(
  1171. deconv_opr.param().sparse !=
  1172. megdnn::param::Convolution::Sparse::GROUP,
  1173. MegBrainError, "mode error");
  1174. group = new_inp[0]->shape()[0];
  1175. icpg = new_inp[0]->shape()[1];
  1176. ocpg = new_inp[0]->shape()[2];
  1177. }
  1178. if (ocpg % 4 == 0 && (icpg % 4 == 0 || group == 1)) {
  1179. auto param = megdnn::param::RelayoutFormat();
  1180. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1181. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1182. deconv_src = rf.node();
  1183. } else {
  1184. // can not convert to hwcd4
  1185. return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1186. }
  1187. } else {
  1188. //! XXXX, fix me, check filter size
  1189. size_t ocpg;
  1190. if (deconv_opr.param().sparse ==
  1191. megdnn::param::Convolution::Sparse::DENSE) {
  1192. ocpg = new_inp[0]->shape()[1];
  1193. } else {
  1194. mgb_throw_if(
  1195. deconv_opr.param().sparse !=
  1196. megdnn::param::Convolution::Sparse::GROUP,
  1197. MegBrainError, "mode error");
  1198. ocpg = new_inp[0]->shape()[2];
  1199. }
  1200. if (ocpg % 4 != 0) {
  1201. VarNodeArray t_inp = new_inp;
  1202. auto param = megdnn::param::RelayoutFormat();
  1203. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1204. auto rf = opr::RelayoutFormat::make(new_inp[1], param);
  1205. t_inp[1] = rf.node();
  1206. auto new_opr =
  1207. serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1208. return new_opr;
  1209. }
  1210. // new input src is NHWCD4
  1211. auto&& fmt =
  1212. new_inp[1]->format().as_impl<megdnn::Image2DPack4TensorFormat>();
  1213. mgb_assert(new_inp[1]->shape().ndim == 5 && fmt.align_axis() == 2);
  1214. deconv_src = new_inp[1];
  1215. }
  1216. mgb_assert(new_inp[0]->format().type() != TensorFormat::Type::IMAGE2D_PACK4);
  1217. auto param = megdnn::param::RelayoutFormat();
  1218. param.mode = filter_mode(deconv_opr.param().sparse, new_inp[0]);
  1219. auto relayout_weight = opr::RelayoutFormat::make(new_inp[0], param);
  1220. deconv_weights = relayout_weight.node();
  1221. auto new_param = deconv_opr.param();
  1222. new_param.format = megdnn::param::Convolution::Format::NHWCD4;
  1223. mgb_assert(
  1224. deconv_src->shape().ndim == 5 &&
  1225. deconv_src->format().type() == TensorFormat::Type::IMAGE2D_PACK4);
  1226. auto new_deconv_opr = opr::ConvolutionBackwardData::make(
  1227. deconv_weights, deconv_src, new_param, deconv_opr.execution_policy(),
  1228. deconv_opr.config());
  1229. OperatorNodeBase* ret = new_deconv_opr.node()->owner_opr();
  1230. mgb_assert(
  1231. new_deconv_opr.shape().ndim == 5 &&
  1232. new_deconv_opr.format().type() == TensorFormat::Type::IMAGE2D_PACK4);
  1233. return ret;
  1234. };
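// Note: for ConvolutionBackwardData the filter is input(0) and the source
// feature map is input(1), the reverse of Convolution/ConvBias; that is why
// the checks above inspect new_inp[1] for the data layout and new_inp[0] for
// the weights before calling ConvolutionBackwardData::make(weights, src, ...).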
1235. /* This helper function guarantees that the format-convert pass won't change
1236. * an output var's channel count. Changing the output channel count would cause
1237. * a channel-mismatch problem when replacing the conv/conv_bias operator.
1238. */
  1239. auto replace_helper = [](OperatorNodeBase* opr,
  1240. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1241. auto&& new_shp = new_inp[0]->shape();
  1242. size_t inp_channel = new_shp[1];
  1243. if (new_shp.eq_shape(opr->input(0)->shape()) && inp_channel % 4 != 0) {
  1244. auto new_opr =
  1245. serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1246. return new_opr;
  1247. }
  1248. return nullptr;
  1249. };
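// e.g. if the var kept its original NCHW shape (no upstream relayout happened)
// and its channel count is not a multiple of 4, the operator is copied
// shallowly and stays in NCHW; a nullptr return lets the caller proceed with
// the NHWCD4 conversion below.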
  1250. auto replace_resize_opr = [replace_helper](
  1251. OperatorNodeBase* opr,
  1252. const VarNodeArray& new_inp) {
  1253. mgb_assert(opr->input().size() == new_inp.size());
  1254. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1255. return opr_shallow_copy;
  1256. }
  1257. auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
  1258. mgb_throw_if(
  1259. resize_opr.param().format != megdnn::param::Resize::Format::NCHW,
  1260. MegBrainError,
  1261. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1262. VarNode* inp = nullptr;
  1263. if (new_inp[0]->shape().ndim == 4) {
  1264. auto param = megdnn::param::RelayoutFormat();
  1265. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1266. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1267. inp = rf.node();
  1268. } else {
1269. // new input src is NHWCD4
  1270. auto&& fmt =
  1271. new_inp[0]->format().as_impl<megdnn::Image2DPack4TensorFormat>();
  1272. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1273. inp = new_inp[0];
  1274. }
  1275. auto new_param = resize_opr.param();
  1276. new_param.format = megdnn::param::Resize::Format::NHWCD4;
  1277. auto new_resize_opr =
  1278. opr::ResizeForward::make(inp, new_inp[1], new_param, opr->config());
  1279. return new_resize_opr.node()->owner_opr();
  1280. };
  1281. auto replace_warp_perspective_opr = [replace_helper](
  1282. OperatorNodeBase* opr,
  1283. const VarNodeArray& new_inp) {
  1284. mgb_assert(opr->input().size() == new_inp.size());
  1285. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1286. return opr_shallow_copy;
  1287. }
  1288. auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
  1289. mgb_throw_if(
  1290. warp_opr.param().format != megdnn::param::WarpPerspective::Format::NCHW,
  1291. MegBrainError,
  1292. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1293. VarNode* inp = nullptr;
  1294. if (new_inp[0]->shape().ndim == 4) {
  1295. // new input src is NCHW
  1296. auto param = megdnn::param::RelayoutFormat();
  1297. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1298. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1299. inp = rf.node();
  1300. } else {
1301. // new input src is NHWCD4
  1302. auto&& fmt =
  1303. new_inp[0]->format().as_impl<megdnn::Image2DPack4TensorFormat>();
  1304. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1305. inp = new_inp[0];
  1306. }
  1307. auto new_param = warp_opr.param();
  1308. new_param.format = megdnn::param::WarpPerspective::Format::NHWCD4;
  1309. SymbolVar new_warp_opr;
  1310. if (new_inp.size() == 3) {
  1311. new_warp_opr = opr::WarpPerspectiveForward::make(
  1312. inp, new_inp[1], nullptr, new_inp[2], new_param, opr->config());
  1313. } else {
  1314. mgb_assert(new_inp.size() == 4);
  1315. new_warp_opr = opr::WarpPerspectiveForward::make(
  1316. inp, new_inp[1], new_inp[2], new_inp[3], new_param, opr->config());
  1317. }
  1318. return new_warp_opr.node()->owner_opr();
  1319. };
  1320. auto replace_warp_affine_opr = [replace_helper](
  1321. OperatorNodeBase* opr,
  1322. const VarNodeArray& new_inp) {
  1323. mgb_assert(opr->input().size() == new_inp.size());
  1324. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1325. return opr_shallow_copy;
  1326. }
  1327. auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
  1328. mgb_throw_if(
  1329. warp_opr.param().format != megdnn::param::WarpAffine::Format::NCHW,
  1330. MegBrainError,
  1331. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1332. VarNode* inp = nullptr;
  1333. if (new_inp[0]->shape().ndim == 4) {
  1334. // new input src is NCHW
  1335. auto param = megdnn::param::RelayoutFormat();
  1336. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1337. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1338. inp = rf.node();
  1339. } else {
1340. // new input src is NHWCD4
  1341. auto&& fmt =
  1342. new_inp[0]->format().as_impl<megdnn::Image2DPack4TensorFormat>();
  1343. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1344. inp = new_inp[0];
  1345. }
  1346. auto new_param = warp_opr.param();
  1347. new_param.format = megdnn::param::WarpAffine::Format::NHWCD4;
  1348. SymbolVar new_warp_opr;
  1349. new_warp_opr = opr::WarpAffineForward::make(
  1350. inp, new_inp[1], new_inp[2], new_param, opr->config());
  1351. return new_warp_opr.node()->owner_opr();
  1352. };
  1353. auto replace_pooling_opr = [replace_helper](
  1354. OperatorNodeBase* opr,
  1355. const VarNodeArray& new_inp) {
  1356. mgb_assert(opr->input().size() == new_inp.size());
  1357. if (auto opr_shallow_copy = replace_helper(opr, new_inp)) {
  1358. return opr_shallow_copy;
  1359. }
  1360. auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
  1361. mgb_throw_if(
  1362. pooling_opr.param().format != megdnn::param::Pooling::Format::NCHW,
  1363. MegBrainError,
  1364. "ConvertFormat Pass only support converting NCHW to NHWCD4");
  1365. VarNode* inp = nullptr;
  1366. if (new_inp[0]->shape().ndim == 4) {
  1367. // new input src is NCHW
  1368. auto param = megdnn::param::RelayoutFormat();
  1369. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1370. auto rf = opr::RelayoutFormat::make(new_inp[0], param);
  1371. inp = rf.node();
  1372. } else {
1373. // new input src is NHWCD4
  1374. auto&& fmt =
  1375. new_inp[0]->format().as_impl<megdnn::Image2DPack4TensorFormat>();
  1376. mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
  1377. inp = new_inp[0];
  1378. }
  1379. auto new_param = pooling_opr.param();
  1380. new_param.format = megdnn::param::Pooling::Format::NHWCD4;
  1381. auto new_pooling_opr = opr::PoolingForward::make(
  1382. inp, new_param, pooling_opr.execution_policy(), opr->config());
  1383. return new_pooling_opr.node()->owner_opr();
  1384. };
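// Resize, WarpPerspective, WarpAffine and Pooling above all follow the same
// recipe: relayout a 4-d NCHW input to NHWCD4 (or accept an already converted
// 5-d IMAGE2D_PACK4 input), switch param().format to NHWCD4, and rebuild the
// operator with the original config.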
  1385. auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
  1386. if (!inp->shape().eq_shape(new_inp->shape())) {
  1387. mgb_assert(
  1388. inp->shape().ndim == 4 &&
  1389. inp->format().type() != TensorFormat::Type::IMAGE2D_PACK4);
  1390. mgb_assert(
  1391. new_inp->shape().ndim == 5 &&
  1392. new_inp->format().type() == TensorFormat::Type::IMAGE2D_PACK4);
  1393. auto param = megdnn::param::RelayoutFormat();
  1394. param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
  1395. auto rf = opr::RelayoutFormat::make(new_inp, param);
  1396. return rf.node();
  1397. }
  1398. return new_inp;
  1399. };
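// Rough shape sketch for var_to_chw, assuming a 16-channel feature map:
//   NHWCD4 (IMAGE2D_PACK4) input : (N, H, C/4, W, 4) = (1, 32, 4, 32, 4)
//   after NHWCD4I_NCHW relayout  : (N, C, H, W)      = (1, 16, 32, 32)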
  1400. auto relayout_inp_to_chw =
  1401. [var_to_chw](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
  1402. mgb_assert(opr->input().size() == new_inp.size());
  1403. VarNodeArray t_inp = new_inp;
  1404. for (size_t i = 0; i < opr->input().size(); i++) {
  1405. t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
  1406. }
  1407. auto new_opr =
  1408. serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1409. return new_opr;
  1410. };
  1411. auto replace_concat_opr = [&relayout_inp_to_chw](
  1412. OperatorNodeBase* opr,
  1413. const VarNodeArray& new_inp) {
  1414. //! map nchw axis to CD4 axis(n h c/4 w 4)
  1415. auto axis_nchw_to_cd4_map = [=](int32_t org_axis) -> int32_t {
  1416. mgb_assert(org_axis >= 0 && org_axis <= 3);
  1417. int32_t ret = 0;
  1418. if (0 == org_axis) {
  1419. ret = 0;
  1420. } else if (1 == org_axis) {
  1421. ret = 2;
  1422. } else if (2 == org_axis) {
  1423. ret = 1;
  1424. } else if (3 == org_axis) {
  1425. mgb_throw(
  1426. InternalError,
  1427. "Do not support axis=3 for concat bypass for CD4!");
  1428. } else {
  1429. mgb_throw(
  1430. InternalError,
  1431. "Do not support axis for concat pass, may input is "
  1432. "not NCHW format!");
  1433. }
  1434. return ret;
  1435. };
  1436. mgb_assert(opr->input().size() == new_inp.size());
  1437. auto nchw_axis = opr->cast_final_safe<opr::Concat>().param().axis;
  1438. if (nchw_axis < 0 || nchw_axis > 3) {
1439. mgb_log_warn("concat pass falls back to relayout to chw");
  1440. return relayout_inp_to_chw(opr, new_inp);
  1441. }
  1442. bool can_exec_cd4 = true;
1443. //! only OpenCL CD4 is considered here; if other backends have relayout
1444. //! performance issues, additional bypass formats may be added
  1445. for (size_t i = 0; i < opr->input().size(); i++) {
  1446. if (opr->input(i)->format().type() != TensorFormat::Type::DEFAULT ||
  1447. opr->input(i)->shape()[1] % 4 != 0 || new_inp[i]->shape().ndim != 5 ||
  1448. new_inp[i]->format().type() != TensorFormat::Type::IMAGE2D_PACK4 ||
  1449. nchw_axis == 3) {
  1450. can_exec_cd4 = false;
  1451. break;
  1452. }
  1453. }
  1454. if (!can_exec_cd4) {
1455. mgb_log_warn("concat pass falls back to relayout to chw");
  1456. return relayout_inp_to_chw(opr, new_inp);
  1457. }
  1458. megdnn::param::Axis param;
1459. //! currently only the NCHW-to-CD4 bypass is supported
  1460. mgb_log_warn("concat pass bypass to CD4");
  1461. param.axis = axis_nchw_to_cd4_map(nchw_axis);
  1462. return opr::Concat::make(VarNodeArrayView(new_inp), param, opr->config())
  1463. .node()
  1464. ->owner_opr();
  1465. };
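// Axis mapping used by the CD4 bypass above (NCHW -> (n, h, c/4, w, 4)):
//   concat on N (axis 0) -> axis 0
//   concat on C (axis 1) -> axis 2 (valid because each input has C % 4 == 0)
//   concat on H (axis 2) -> axis 1
//   concat on W (axis 3) -> unsupported, falls back to relayout_inp_to_chw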
  1466. auto replace_elemwise_opr = [&relayout_inp_to_chw](
  1467. OperatorNodeBase* opr,
  1468. const VarNodeArray& new_inp) {
  1469. mgb_assert(opr->input().size() == new_inp.size());
  1470. bool has_inp_changed = false;
  1471. bool can_exec_cd4 = true;
  1472. for (size_t i = 0; i < opr->input().size(); i++) {
  1473. if (!new_inp[i]->format().is_default()) {
  1474. has_inp_changed = true;
  1475. } else if (new_inp[i]->shape().ndim == 4) {
  1476. if (new_inp[i]->shape()[1] % 4 != 0) {
  1477. can_exec_cd4 = false;
  1478. }
1479. //! cd4 elemwise with a scalar operand is still allowed; any other non-4-d input is not
  1480. } else if (!new_inp[i]->shape().is_scalar()) {
  1481. can_exec_cd4 = false;
  1482. }
  1483. }
  1484. if (!can_exec_cd4) {
  1485. return relayout_inp_to_chw(opr, new_inp);
  1486. }
  1487. if (has_inp_changed) {
  1488. // assumption: all inputs are changed from nchw to nhwcd4
  1489. auto t_inp = new_inp;
  1490. for (size_t i = 0; i < opr->input().size(); i++) {
  1491. if (new_inp[i]->shape().ndim == 4) {
  1492. auto param = megdnn::param::RelayoutFormat();
  1493. param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I;
  1494. auto rf = opr::RelayoutFormat::make(new_inp[i], param);
  1495. t_inp[i] = rf.node();
  1496. } else {
  1497. mgb_assert(
  1498. (new_inp[i]->shape().ndim == 5 &&
  1499. new_inp[i]->format().type() ==
  1500. TensorFormat::Type::IMAGE2D_PACK4) ||
  1501. new_inp[i]->shape().is_scalar());
  1502. }
  1503. }
  1504. return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1505. } else {
  1506. return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
  1507. }
  1508. };
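// e.g. elemwise(conv_out /* NHWCD4 */, scalar) can stay in CD4, while
// elemwise(conv_out, nchw_var) with the nchw_var's channel not divisible by 4
// falls back to relayout_inp_to_chw: a 4-d non-scalar operand must have
// C % 4 == 0 before it can be relayouted to NHWCD4 next to the other inputs.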
  1509. /* This helper function converts the first input to the NCHW format to
  1510. * handle operations that do not support NHWCD4 format
  1511. */
  1512. auto relayout_first_inp_to_chw =
  1513. [var_to_chw](
  1514. OperatorNodeBase* opr,
  1515. const VarNodeArray& new_inp) -> OperatorNodeBase* {
  1516. mgb_assert(opr->input().size() == new_inp.size());
  1517. VarNodeArray t_inp = new_inp;
  1518. t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
  1519. return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
  1520. };
  1521. auto ret = std::make_unique<ConvertFormatPass>();
  1522. ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
  1523. auto&& replace_func = ret->m_opr_replace_func;
  1524. replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
  1525. replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
  1526. replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
  1527. replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
  1528. replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
  1529. replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
  1530. replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
  1531. replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
  1532. replace_func[opr::Images2NeibsBackward::typeinfo()] = relayout_inp_to_chw;
  1533. replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
  1534. replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_chw;
  1535. replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_chw;
  1536. replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
  1537. replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
  1538. replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
  1539. replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
  1540. replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
  1541. replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
  1542. replace_func[opr::WarpPerspectiveForward::typeinfo()] =
  1543. replace_warp_perspective_opr;
  1544. replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
  1545. replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
  1546. replace_func[opr::GroupLocalForward::typeinfo()] = relayout_first_inp_to_chw;
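// Roughly: operators that have an NHWCD4 kernel (Convolution, ConvBias,
// ConvolutionBackwardData, Pooling, Resize, WarpPerspective, WarpAffine,
// Elemwise/TypeCvt, Concat) are rewritten in place, while shape-sensitive
// operators (Reshape, Dimshuffle, Subtensor, Reduce, ...) first get their
// inputs relayouted back to NCHW. A minimal sketch of the relayout that both
// directions rely on:
//   auto param = megdnn::param::RelayoutFormat();
//   param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; // or NHWCD4I_NCHW
//   auto converted = opr::RelayoutFormat::make(var, param);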
  1547. return ret;
  1548. MIDOUT_E
  1549. }
1550. /* ================ ConvertBatchNormToElemwisePass ================ */
  1551. const char* ConvertBatchNormToElemwisePass::name() const {
  1552. return "convert_batch_norm";
  1553. }
  1554. void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
  1555. MIDOUT_B("ConvertBatchNormToElemwisePass::apply")
  1556. auto rewriter = state.graph().make_rewriter();
  1557. auto on_opr = [&](OperatorNodeBase* opr) {
  1558. if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
  1559. if (bn->input().size() == 5) {
  1560. mgb_assert(
  1561. bn->param().fwd_mode ==
  1562. opr::BatchNorm::Param::FwdMode::INFERENCE);
  1563. SymbolVar x = {rewriter.get_var(bn->input(0))};
  1564. SymbolVar scale = {rewriter.get_var(bn->input(1))};
  1565. SymbolVar bias = {rewriter.get_var(bn->input(2))};
  1566. SymbolVar mean = {rewriter.get_var(bn->input(3))};
  1567. SymbolVar variance = {rewriter.get_var(bn->input(4))};
  1568. SymbolVar invsqrt_variance = opr::PowC::make(
  1569. variance + variance.make_scalar_dt(float(bn->param().epsilon)),
  1570. {-0.5});
  1571. auto res = scale * (x - mean) * invsqrt_variance + bias;
  1572. if (x.dtype() != res.dtype()) {
  1573. mgb_throw(
  1574. MegBrainError,
  1575. "BN's input dtype %s is not compatible with "
  1576. "param dtype %s when fusing BN. You may need to "
  1577. "dump FP32 model.",
  1578. x.dtype().name(), res.dtype().name());
  1579. }
  1580. rewriter.replace_var(
  1581. opr->output(5), res.node(),
  1582. mgb_cstr_log("replace batch_norm(x, scale, bias, mean, "
  1583. "varience) "
  1584. "-> (sclae * (x - mean) / sqrt(variance)) + b)"));
  1585. return;
  1586. }
  1587. }
  1588. rewriter.auto_replace_outputs(opr);
  1589. };
  1590. state.graph().iter(on_opr);
  1591. rewriter.apply_inplace();
  1592. MIDOUT_E
  1593. }
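// The inference-mode BN above is folded into plain elementwise arithmetic:
//   y = scale * (x - mean) * rsqrt(variance + epsilon) + bias
// where rsqrt(v) is expressed as PowC(v, -0.5), so the whole normalization
// collapses into broadcastable Elemwise/PowC operators.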
  1594. /* ================ FuseConvBiasNonlinPass ================ */
  1595. const char* FuseConvBiasNonlinPass::name() const {
  1596. return "combine_conv_bias_and_relu";
  1597. }
  1598. void FuseConvBiasNonlinPass::apply(OptState& state) const {
  1599. MIDOUT_B("FuseConvBiasNonlinPass::apply")
  1600. std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
  1601. state.graph().iter([&m_deps](OperatorNodeBase* opr) {
  1602. for (auto& inp : opr->input()) {
  1603. m_deps[inp].push_back(opr);
  1604. }
  1605. });
  1606. auto rewriter = state.graph().make_rewriter();
  1607. using Mode = opr::Elemwise::Param::Mode;
  1608. using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
  1609. auto get_nonlinearity_mode = [&](opr::Elemwise* elem) -> NonlineMode {
  1610. if (elem->param().mode == Mode::FUSE_ADD_RELU ||
  1611. elem->param().mode == Mode::RELU) {
  1612. return NonlineMode::RELU;
  1613. } else if (
  1614. elem->param().mode == Mode::FUSE_ADD_SIGMOID ||
  1615. elem->param().mode == Mode::SIGMOID) {
  1616. return NonlineMode::SIGMOID;
  1617. } else {
  1618. return NonlineMode::IDENTITY;
  1619. }
  1620. };
  1621. auto try_fuse_bias_nonlinearity = [&](opr::Elemwise* elem) -> bool {
  1622. bool can_be_fused = true;
  1623. can_be_fused &= (elem->input().size() == 2);
  1624. can_be_fused &= (elem->param().mode == Mode::FUSE_ADD_RELU) ||
  1625. (elem->param().mode == Mode::FUSE_ADD_TANH) ||
  1626. (elem->param().mode == Mode::FUSE_ADD_SIGMOID);
  1627. return can_be_fused;
  1628. };
  1629. auto try_fuse_bias = [&](opr::Elemwise* elem) -> bool {
  1630. bool can_be_fused = true;
  1631. can_be_fused &= (elem->input().size() == 2);
  1632. can_be_fused &= (elem->param().mode == Mode::ADD);
  1633. return can_be_fused;
  1634. };
  1635. auto try_fuse_nonlinearity = [&](opr::Elemwise* elem) -> bool {
  1636. bool can_be_fused = true;
  1637. can_be_fused &= (elem->input().size() == 1);
  1638. can_be_fused &= (elem->param().mode == Mode::RELU) ||
  1639. (elem->param().mode == Mode::TANH) ||
  1640. (elem->param().mode == Mode::SIGMOID);
  1641. return can_be_fused;
  1642. };
  1643. auto convert_to_conv_bias_param =
  1644. [&](const opr::Convolution::Param& param) -> opr::ConvBiasForward::Param {
  1645. using Param = opr::ConvBiasForward::Param;
  1646. return opr::ConvBiasForward::Param{
  1647. Param::NonlineMode::IDENTITY,
  1648. param.mode,
  1649. param.sparse,
  1650. param.format,
  1651. param.pad_h,
  1652. param.pad_w,
  1653. param.stride_h,
  1654. param.stride_w,
  1655. param.dilate_h,
  1656. param.dilate_w,
  1657. param.compute_mode};
  1658. };
  1659. auto check_bias_shape = [&](opr::Convolution* conv, VarNode* bias) -> bool {
  1660. bool valid_bias_shape = true;
  1661. using Format = opr::Convolution::Param::Format;
  1662. using Sparse = opr::Convolution::Param::Sparse;
  1663. auto dst_shape = conv->output(0)->shape();
  1664. auto filter_shape = conv->input(1)->shape();
  1665. auto bias_shape = bias->shape();
1666. //! pay attention: make sure the bias node is not a const provider when
1667. //! batch > 1, because that causes a shape-assert problem in convbias;
1668. //! if the input shape is resized, the bias shape cannot be updated along
1669. //! with it, so do not fuse conv and bias in that situation
  1670. if (dst_shape.eq_shape(bias_shape) && !cg::is_const_var_shape(bias)) {
  1671. return valid_bias_shape;
  1672. }
  1673. size_t OC = filter_shape[0];
  1674. if (conv->param().sparse == Sparse::GROUP) {
  1675. OC *= filter_shape[1];
  1676. }
  1677. if (conv->param().format == Format::NCHW) {
  1678. valid_bias_shape &=
  1679. ((bias_shape.ndim == 4) && (bias_shape[0] == 1) &&
  1680. (bias_shape[1] == OC) && (bias_shape[2] == 1) &&
  1681. (bias_shape[3] == 1));
  1682. } else if (conv->param().format == Format::NCHW4) {
  1683. valid_bias_shape &=
  1684. ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
  1685. (bias_shape[1] == OC / 4) && (bias_shape[2] == 1) &&
  1686. (bias_shape[3] == 1) && bias_shape[4] == 4);
  1687. } else if (conv->param().format == Format::NHWC) {
  1688. valid_bias_shape &=
  1689. ((bias_shape.ndim == 4) && (bias_shape[0] == 1) &&
  1690. (bias_shape[1] == 1) && (bias_shape[2] == 1) &&
  1691. (bias_shape[3] == OC));
  1692. } else {
  1693. valid_bias_shape &=
  1694. ((bias_shape.ndim == 5) && (bias_shape[0] == 1) &&
  1695. (bias_shape[1] == 1) && (bias_shape[2] == OC) &&
  1696. (bias_shape[3] == 1) && (bias_shape[4] == 4));
  1697. mgb_assert(conv->param().format == Format::NHWCD4);
  1698. }
  1699. return valid_bias_shape;
  1700. };
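// Accepted broadcast-style bias shapes, per convolution format (OC is the
// output channel count derived from the filter shape):
//   NCHW   : (1, OC,   1, 1)
//   NCHW4  : (1, OC/4, 1, 1, 4)
//   NHWC   : (1, 1, 1, OC)
//   NHWCD4 : the 5-d layout checked in the final branch above
// A bias whose shape equals the conv output is also accepted, provided its
// shape is not constant (see the note about batch > 1 above).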
  1701. auto try_fuse_typecvt = [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
  1702. mgb_assert(typecvt->input().size() == 1);
  1703. auto conv_bias = try_cast_as_op<opr::ConvBias>(
  1704. rewriter.get_var(typecvt->input(0))->owner_opr());
  1705. if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
  1706. typecvt->output(0)->dtype().enumv() !=
  1707. DTypeTrait<dtype::QuantizedS8>::enumv ||
  1708. typecvt->input(0)->dtype().enumv() !=
  1709. DTypeTrait<dtype::QuantizedS32>::enumv)
  1710. return nullptr;
  1711. auto config = conv_bias->config();
  1712. config.output_dtype(typecvt->output(0)->dtype());
  1713. if (conv_bias->input().size() == 3) {
  1714. // conv + bias
  1715. return opr::ConvBias::make(
  1716. conv_bias->input(0), conv_bias->input(1),
  1717. conv_bias->input(2), conv_bias->param(),
  1718. conv_bias->execution_policy(), config)
  1719. .node()
  1720. ->owner_opr();
  1721. } else {
  1722. // conv without bias
  1723. return opr::ConvBias::make(
  1724. conv_bias->input(0), conv_bias->input(1), conv_bias->param(),
  1725. conv_bias->execution_policy(), config)
  1726. .node()
  1727. ->owner_opr();
  1728. }
  1729. };
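// e.g. a QuantizedS32 ConvBias whose only reader is a TypeCvt to QuantizedS8
// is rebuilt as a single ConvBias with config.output_dtype() set to the
// QuantizedS8 dtype, so the requantization happens inside the fused kernel
// rather than in a separate TypeCvt operator.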
  1730. auto on_opr = [&](OperatorNodeBase* opr) {
  1731. auto check_conv = [](opr::Convolution* conv) -> bool {
  1732. return conv->param().format == megdnn::param::Convolution::Format::NHWCD4 ||
  1733. conv->param().format == megdnn::param::Convolution::Format::NHWC ||
  1734. conv->param().format == megdnn::param::Convolution::Format::NCHW ||
  1735. conv->param().format == megdnn::param::Convolution::Format::NCHW4
  1736. ;
  1737. };
  1738. if (auto elem = try_cast_as_op<opr::Elemwise>(opr)) {
  1739. if (try_fuse_bias_nonlinearity(elem) || try_fuse_bias(elem)) {
  1740. auto inp1 = rewriter.get_var(elem->input(0));
  1741. auto inp2 = rewriter.get_var(elem->input(1));
  1742. opr::Convolution* conv = nullptr;
  1743. size_t bias_idx = 0;
  1744. if (inp1->owner_opr()->same_type<opr::Convolution>() &&
  1745. m_deps[elem->input(0)].size() == 1) {
  1746. conv = try_cast_as_op<opr::Convolution>(inp1->owner_opr());
  1747. bias_idx = 1;
  1748. } else if (
  1749. inp2->owner_opr()->same_type<opr::Convolution>() &&
  1750. m_deps[elem->input(1)].size() == 1) {
  1751. conv = try_cast_as_op<opr::Convolution>(inp2->owner_opr());
  1752. bias_idx = 0;
  1753. }
  1754. auto bias_inp = rewriter.get_var(elem->input(bias_idx));
  1755. if (conv && check_conv(conv) && check_bias_shape(conv, bias_inp)) {
  1756. opr::ConvBiasForward::Param param =
  1757. convert_to_conv_bias_param(conv->param());
  1758. param.nonlineMode = get_nonlinearity_mode(elem);
  1759. auto new_var =
  1760. opr::ConvBiasForward::make(
  1761. conv->input(0), conv->input(1), bias_inp, param,
  1762. conv->execution_policy(), conv->config())
  1763. .node();
  1764. rewriter.replace_var(
  1765. opr->output(0), new_var,
  1766. mgb_cstr_log("replace nonlinearity(conv(x, w) + b) "
  1767. "-> conv_bias(x, w, b)"));
  1768. return;
  1769. }
  1770. } else if (try_fuse_nonlinearity(elem)) {
  1771. auto inp = rewriter.get_var(elem->input(0));
  1772. {
  1773. auto conv = try_cast_as_op<opr::Convolution>(inp->owner_opr());
  1774. if (conv && check_conv(conv) &&
  1775. m_deps[elem->input(0)].size() == 1) {
  1776. opr::ConvBiasForward::Param param =
  1777. convert_to_conv_bias_param(conv->param());
  1778. param.nonlineMode = get_nonlinearity_mode(elem);
  1779. auto new_var = opr::ConvBiasForward::make(
  1780. conv->input(0), conv->input(1), param,
  1781. conv->execution_policy(), conv->config())
  1782. .node();
  1783. rewriter.replace_var(
  1784. opr->output(0), new_var,
  1785. mgb_cstr_log("replace nonlinearity(conv(x, w)) "
  1786. "-> conv_bias(x, w)"));
  1787. return;
  1788. }
  1789. }
  1790. {
  1791. auto conv = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
  1792. auto check_conv_bias = [&](opr::ConvBias* opr) {
  1793. return opr->param().format ==
  1794. opr::ConvBias::Param::Format::NHWC ||
  1795. opr->param().format ==
  1796. opr::ConvBias::Param::Format::NCHW ||
  1797. opr->param().format ==
  1798. opr::ConvBias::Param::Format::NCHW4
  1799. ;
  1800. };
  1801. if (conv && check_conv_bias(conv) &&
  1802. m_deps[elem->input(0)].size() == 1) {
  1803. auto param = conv->param();
  1804. param.nonlineMode = get_nonlinearity_mode(elem);
  1805. auto new_var =
  1806. opr::ConvBiasForward::make(
  1807. conv->input(0), conv->input(1), conv->input(2),
  1808. param, conv->execution_policy(), conv->config())
  1809. .node();
  1810. rewriter.replace_var(
  1811. opr->output(0), new_var,
1812. mgb_cstr_log("replace nonlinearity(conv_bias(x, w, b)) "
1813. "-> conv_bias(x, w, b)"));
  1814. return;
  1815. }
  1816. }
  1817. }
  1818. } else if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
  1819. auto new_opr = try_fuse_typecvt(typecvt);
  1820. if (new_opr) {
  1821. rewriter.replace_var(
  1822. opr->output(0), new_opr->output(0),
  1823. mgb_cstr_log("replace typecvt(conv_bias(x, w, b)) -> "
  1824. "conv_bias(x, w, b)"));
  1825. return;
  1826. }
  1827. }
  1828. rewriter.auto_replace_outputs(opr);
  1829. };
  1830. state.graph().iter(on_opr);
  1831. rewriter.apply_inplace();
  1832. MIDOUT_E
  1833. }
  1834. /* ================ FuseConvBiasZPass ================ */
  1835. const char* FuseConvBiasZPass::name() const {
  1836. return "combine_conv_bias_and_z";
  1837. }
  1838. void FuseConvBiasZPass::apply(OptState& state) const {
  1839. MIDOUT_B("FuseConvBiasZPass::apply")
  1840. UniqReaderCheck uniq_reader_check{state.graph()};
  1841. auto rewriter = state.graph().make_rewriter();
  1842. using Mode = opr::Elemwise::Param::Mode;
  1843. using MultiMode = opr::ElemwiseMultiType::Param::Mode;
  1844. using NonlineMode = opr::ConvBiasForward::Param::NonlineMode;
  1845. auto check_conv_bias = [](opr::ConvBias* conv_bias) -> bool {
  1846. return conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC ||
  1847. conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW ||
  1848. conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW4
  1849. ;
  1850. };
  1851. auto check_fuse_shape = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1852. bool valid_fuse_shape = true;
  1853. auto z_shape = z->shape();
  1854. auto bias_shape = conv_bias->input(2)->shape();
  1855. auto conv_bias_shape = conv_bias->output(0)->shape();
  1856. valid_fuse_shape &= (!conv_bias_shape.eq_shape(bias_shape));
  1857. valid_fuse_shape &= conv_bias_shape.eq_shape(z_shape);
  1858. return valid_fuse_shape;
  1859. };
  1860. auto check_fuse_dtype = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1861. return conv_bias->output(0)->dtype().enumv() == z->dtype().enumv();
  1862. };
  1863. #if MGB_CUDA && (CUDNN_MAJOR == 8)
  1864. auto check_fuse_param = [&](opr::ConvBias* conv_bias, VarNode* z) -> bool {
  1865. return conv_bias->input(0) != z;
  1866. };
  1867. #endif
  1868. auto get_convbias_nonline_mode = [&](OperatorNodeBase* opr) -> NonlineMode {
  1869. if (opr->same_type<opr::Elemwise>()) {
  1870. auto elem = try_cast_as_op<opr::Elemwise>(opr);
  1871. if (elem->param().mode == Mode::FUSE_ADD_RELU)
  1872. return NonlineMode::RELU;
  1873. }
  1874. if (opr->same_type<opr::ElemwiseMultiType>()) {
  1875. auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
  1876. if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
  1877. return NonlineMode::RELU;
  1878. else if (elem->param().mode == MultiMode::QFUSE_ADD_H_SWISH)
  1879. return NonlineMode::H_SWISH;
  1880. }
  1881. return NonlineMode::IDENTITY;
  1882. };
  1883. auto try_replace_var_node = [&](OperatorNodeBase* opr) {
  1884. opr::ConvBias* conv_bias = nullptr;
  1885. size_t z_idx = 0;
  1886. size_t nr_inps = opr->input().size();
  1887. for (size_t i = 0; i < nr_inps; i++) {
  1888. auto inp = rewriter.get_var(opr->input(i));
  1889. if (inp->owner_opr()->same_type<opr::ConvBias>()) {
  1890. auto cb = try_cast_as_op<opr::ConvBias>(inp->owner_opr());
  1891. if (cb->input().size() == 3 &&
  1892. cb->param().nonlineMode ==
  1893. opr::ConvBias::Param::NonlineMode::IDENTITY &&
  1894. uniq_reader_check(opr->input(i))) {
  1895. conv_bias = cb;
  1896. z_idx = nr_inps - i - 1;
  1897. break;
  1898. }
  1899. }
  1900. }
  1901. auto z_inp = rewriter.get_var(opr->input(z_idx));
  1902. if (conv_bias && check_conv_bias(conv_bias) &&
  1903. check_fuse_shape(conv_bias, z_inp) &&
  1904. #if MGB_CUDA && (CUDNN_MAJOR == 8)
  1905. check_fuse_param(conv_bias, z_inp) &&
  1906. #endif
  1907. check_fuse_dtype(conv_bias, z_inp)) {
  1908. auto param = conv_bias->param();
  1909. param.nonlineMode = get_convbias_nonline_mode(opr);
  1910. auto config = conv_bias->config();
  1911. auto new_var = opr::ConvBiasForward::make(
  1912. conv_bias->input(0), conv_bias->input(1),
  1913. conv_bias->input(2), z_inp, param,
  1914. conv_bias->execution_policy(),
  1915. config.output_dtype(opr->output(0)->dtype()))
  1916. .node();
  1917. rewriter.replace_var(
  1918. opr->output(0), new_var,
  1919. mgb_cstr_log("replace "
  1920. "nonlinearity(conv_bias(x,w,b) + z) "
  1921. "-> conv_bias(x, w, b, z)"));
  1922. uniq_reader_check.update_on_opr_auto_replace(opr, new_var->owner_opr());
  1923. return true;
  1924. }
  1925. return false;
  1926. };
  1927. auto try_fuse_elemwise = [&](OperatorNodeBase* opr) {
  1928. if (!opr->same_type<opr::Elemwise>())
  1929. return false;
  1930. auto elem = try_cast_as_op<opr::Elemwise>(opr);
  1931. if (elem->input().size() != 2)
  1932. return false;
  1933. if (elem->param().mode != Mode::ADD &&
  1934. elem->param().mode != Mode::FUSE_ADD_RELU)
  1935. return false;
  1936. return try_replace_var_node(opr);
  1937. };
  1938. auto try_fuse_elemwise_multi_type = [&](OperatorNodeBase* opr) {
  1939. if (!opr->same_type<opr::ElemwiseMultiType>())
  1940. return false;
  1941. auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
  1942. if (elem->input().size() != 2)
  1943. return false;
  1944. if (elem->param().mode != MultiMode::QADD &&
  1945. elem->param().mode != MultiMode::QFUSE_ADD_RELU &&
  1946. elem->param().mode != MultiMode::QFUSE_ADD_H_SWISH)
  1947. return false;
  1948. return try_replace_var_node(opr);
  1949. };
  1950. auto on_opr = [&](OperatorNodeBase* opr) {
  1951. if (try_fuse_elemwise(opr))
  1952. return;
  1953. if (try_fuse_elemwise_multi_type(opr))
  1954. return;
  1955. auto new_opr = rewriter.auto_replace_outputs(opr);
  1956. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  1957. };
  1958. state.graph().iter(on_opr);
  1959. rewriter.apply_inplace();
  1960. MIDOUT_E
  1961. }
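// Example rewrite performed by FuseConvBiasZPass, roughly:
//   y = FUSE_ADD_RELU(conv_bias(x, w, b), z)   // Elemwise / ElemwiseMultiType
// becomes
//   y = conv_bias(x, w, b, z)                  // nonlineMode = RELU
// provided the conv_bias currently has IDENTITY nonlinearity, is read only by
// the elementwise operator, and z matches the conv output shape and dtype.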
  1962. /* ================ FuseDeconvCvtPass ================ */
  1963. const char* FuseDeconvCvtPass::name() const {
  1964. return "combine_deconv_and_typecvt";
  1965. }
  1966. void FuseDeconvCvtPass::apply(OptState& state) const {
  1967. MIDOUT_B("FuseDeconvCvtPass::apply")
  1968. std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
  1969. state.graph().iter([&m_deps](OperatorNodeBase* opr) {
  1970. for (auto& inp : opr->input()) {
  1971. m_deps[inp].push_back(opr);
  1972. }
  1973. });
  1974. UniqReaderCheck uniq_reader_check{state.graph()};
  1975. auto rewriter = state.graph().make_rewriter();
  1976. auto try_fuse_deconv_typecvt = [&](opr::TypeCvt* typecvt) -> OperatorNodeBase* {
  1977. mgb_assert(typecvt->input().size() == 1);
  1978. auto deconv = try_cast_as_op<opr::ConvolutionBackwardData>(
  1979. rewriter.get_var(typecvt->input(0))->owner_opr());
  1980. if (!deconv
  1981. || m_deps.count(typecvt->input(0)) != 1 ||
  1982. typecvt->output(0)->dtype().enumv() !=
  1983. DTypeTrait<dtype::QuantizedS8>::enumv) {
  1984. return nullptr;
  1985. }
  1986. if (!uniq_reader_check(deconv->output(0)))
  1987. return nullptr;
  1988. auto config = deconv->config();
  1989. config.output_dtype(typecvt->output(0)->dtype());
  1990. return opr::ConvolutionBackwardData::make(
  1991. deconv->input(0), deconv->input(1), deconv->param(),
  1992. deconv->execution_policy(), config)
  1993. .node()
  1994. ->owner_opr();
  1995. };
  1996. auto on_opr = [&](OperatorNodeBase* opr) {
  1997. if (auto typecvt = try_cast_as_op<opr::TypeCvt>(opr)) {
  1998. if (auto deconv_new = try_fuse_deconv_typecvt(typecvt)) {
  1999. rewriter.replace_var(
  2000. opr->output(0), deconv_new->output(0),
  2001. mgb_cstr_log("replace typecvt(deconv(x, w)) -> "
  2002. "deconv(x, w)"));
  2003. uniq_reader_check.update_on_opr_auto_replace(opr, deconv_new);
  2004. return;
  2005. }
  2006. }
  2007. auto new_opr = rewriter.auto_replace_outputs(opr);
  2008. uniq_reader_check.update_on_opr_auto_replace(opr, new_opr);
  2009. };
  2010. state.graph().iter(on_opr);
  2011. rewriter.apply_inplace();
  2012. MIDOUT_E
  2013. }
  2014. /* ================ ParamMergePass ================ */
  2015. const char* ParamMergePass::name() const {
  2016. return mgb_cstr_log("param_merge");
  2017. }
  2018. void ParamMergePass::apply(OptState& opt_state) const {
  2019. MIDOUT_B("ParamMergePass::apply")
  2020. param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(opt_state);
  2021. param_merge<
  2022. opr::SharedDeviceTensorWithFormat,
  2023. opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
  2024. MIDOUT_E
  2025. }
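// param_merge<SrcOpr, HolderOpr> (a sketch of its intent: gather the device
// tensors held by individual SrcOpr nodes into one HolderOpr node) is
// instantiated here for both the plain and the with-format shared-tensor
// variants, so parameters stored in NHWCD4 layouts are merged as well.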
  2026. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}