
framework.cpp 29 kB

/**
 * \file src/gopt/impl/framework.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/utils/timer.h"

#if MGB_JIT
#include "megbrain/jit/fusion_pass.h"
#endif

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/opr_replace.h"
#endif

using namespace mgb;
using namespace gopt;
/* ================ SubGraph ================ */
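// auto_replace_outputs(): if any input of opr has already been replaced, make
// a shallow copy of opr on the new inputs and record the mapping from the old
// outputs to the outputs of the copy.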
OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
        OperatorNodeBase* opr) {
    auto&& new_inp = m_opr_new_inp_cache;
    new_inp.clear();
    new_inp.reserve(opr->input().size());
    bool has_replaced_inp = false;
    for (auto i : opr->input()) {
        auto new_var = get_var(i);
        if (new_var != i) {
            has_replaced_inp = true;
            new_inp.push_back(new_var);
        } else {
            new_inp.push_back(i);
        }
    }

    if (has_replaced_inp) {
        auto new_opr = serialization::copy_opr_shallow(
                *opr, new_inp, opr->config());
        auto &&out0 = opr->output(), &&out1 = new_opr->output();
        size_t i = 0;
        auto err_msg = [opr, new_opr] {
            return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}",
                            opr->cname(), opr->dyn_typeinfo()->name,
                            new_opr->cname(), new_opr->dyn_typeinfo()->name);
        };
        MGB_MARK_USED_VAR(err_msg);

        // opr output size mismatch may be caused by:
        // 0) inplace arith optimization (e.g. PowC needs an extra workspace)
        // 1) other post-insert optimization (e.g. const folding)
        // we can't handle only usable_output here, since some output var with
        // volatile flag could be the graph's endpoint (e.g. RemoteSend)
        for (; i < std::min(out0.size(), out1.size()); ++i) {
            bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                 v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT);
            mgb_assert(v0 == v1, "%s", err_msg().c_str());

            auto&& ins = m_varmap.insert({out0[i], {true, nullptr}});
            mgb_assert(ins.second || ins.first->second.first,
                       "opr output already replaced");
            // handle repeated call on the same opr
            ins.first->second.second = out1[i];
            on_var_replaced(out0[i], out1[i], nullptr);
        }
        for (; i < out0.size(); ++i) {
            mgb_assert(out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                       "%s", err_msg().c_str());
        }
        for (; i < out1.size(); ++i) {
            mgb_assert(out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                       "%s", err_msg().c_str());
        }
        return new_opr;
    }
    return opr;
}
void SubGraph::Rewriter::replace_var(
        VarNode* src, VarNode* dst, const char* msg) {
    if (src == dst)
        return;

    // Optimizers should not create a loop in the variable replace map.
    mgb_throw_if(
            get_var_internal(dst).second == src, InternalError,
            "dst %s maps back to src %s in SubGraph::Rewriter::replace_var",
            dst->cname(), src->cname());

    auto&& ins = m_varmap.insert({src, {false, dst}});
    if (!ins.second) {
        auto&& old_rep = ins.first->second;
        mgb_assert(old_rep.first || old_rep.second == dst,
                   "can not replace a var twice");
        old_rep.first = false;
        old_rep.second = dst;
    }
    on_var_replaced(src, dst, msg);
}
void SubGraph::Rewriter::on_var_replaced(
        VarNode* src, VarNode* dst, const char* msg) {
    if (auto state = m_owner_graph->owner_opt_state()) {
        state->on_var_replaced(src, dst, msg);
    }
}

void SubGraph::Rewriter::apply_inplace() const {
    m_owner_graph->m_endpoint_oprs.clear();
    m_owner_graph->m_endpoint_vars_set.clear();
    for (auto&& var : m_owner_graph->m_endpoint_vars) {
        var = get_var(var.node());
        m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr());
        m_owner_graph->m_endpoint_vars_set.insert(var.node());
    }
}
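// get_var_internal() follows the chain of replacements for var and writes the
// final target back into the entries it visits (path compression), so repeated
// lookups stay cheap even for long replacement chains.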
std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) {
    // The implementation is (manually) unrolled once, background:
    // git-core/brain-sdk/MegBrain/merge_requests/486#note_76971
    auto it = m_varmap.find(var);
    if (it == m_varmap.end()) {
        return {true, var};
    }
    mgb_assert(it->second.second != var, "loop detected in m_varmap");

    auto it_next = m_varmap.find(it->second.second);
    if (it_next == m_varmap.end()) {
        return it->second;
    }
    mgb_assert(it_next->second.second != it->second.second,
               "loop detected in m_varmap");

    auto next = get_var_internal(it_next->second.second);
    it_next->second = {next.first & it_next->second.first, next.second};
    return it->second = {it_next->second.first & it->second.first, next.second};
}
SubGraph::SubGraph(const SymbolVarArray& endpoint_vars):
        m_endpoint_vars(endpoint_vars)
{
    mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty");
    m_comp_graph = endpoint_vars[0].node()->owner_graph();
    for (auto i : endpoint_vars) {
        m_endpoint_oprs.insert(i.node()->owner_opr());
        m_endpoint_vars_set.insert(i.node());
        mgb_assert(m_comp_graph == i.node()->owner_graph(),
                   "endpoints belong to different computing graphs");
    }
}

void SubGraph::iter(
        const Callback& cb,
        std::shared_ptr<ExtraDep> extra_dep) const {
    Callback on_opr;
    if (m_owner_opt_state) {
        on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) {
            state->m_opr_property_flag = OprPropertyFlag::ALL;
            state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr);
            state->m_cur_iter_opr_priority =
                    opr->node_prop().attribute().priority;
            state->m_cur_iter_opr_stream_prop_type =
                    state->m_comp_node_opt.stream_prop_type(opr->output(0));
            mgb_assert(state->m_oprs_inserted.empty());
            cb(opr);
            state->m_opr_property_flag = OprPropertyFlag::NONE;
            state->m_cur_iter_src_opr = nullptr;
            state->m_oprs_inserted.clear();
        };
    } else {
        on_opr = cb;
    }
    cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)};
    for (auto i : m_endpoint_oprs)
        dep_iter.add(i);
}

ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
    ThinHashMap<VarNode*, size_t> ret;
    auto cb = [&](OperatorNodeBase* opr) {
        for (auto&& i : opr->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) {
                ++ret.at(i.first);
            }
        }
        for (auto i : opr->output()) {
            if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto ins = ret.insert({i, 0});
                mgb_assert(ins.second);
            }
        }
    };
    iter(cb);
    for (auto i : m_endpoint_vars_set) {
        auto iter = ret.find(i);
        if (iter == ret.end()) {
            mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
            ret[i] = 1;
        } else {
            ++ret.at(i);
        }
    }
    return ret;
}
/* ================ UniqReaderCheck ================ */
UniqReaderCheck::UniqReaderCheck(const SubGraph& graph):
        m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()}
{
}

void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr,
                                                 OperatorNodeBase* repl_opr) {
    auto non_volatile_size = [](const VarNodeArray& vars) -> size_t {
        size_t size = 0;
        for (size_t i = 0; i < vars.size(); ++i) {
            if (!vars[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                size++;
            }
        }
        return size;
    };
    if (opr != repl_opr) {
        auto &&o0 = opr->output(), &&o1 = repl_opr->output();
        mgb_assert(non_volatile_size(o0) == non_volatile_size(o1));
        for (size_t i = 0; i < o0.size(); ++i) {
            auto iter = m_var2nr_val_dep.find(o0[i]);
            if (iter != m_var2nr_val_dep.end()) {
                auto n = iter->second;
                m_var2nr_val_dep[o1[i]] = n;
            }
        }
    }
}
/* ================ OptState ================ */
OptState::OptState(
        const GraphOptimizer* owner_optimizer, const SubGraph& graph):
        m_owner_optimizer{owner_optimizer},
        m_var_replace_map{
                const_cast<ThinHashMap<VarNode*, VarNode*>*>(
                        &GraphOptimizer::var_replace_map(*graph.comp_graph()))},
        m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
        m_graph{graph}
{
    mgb_assert(!m_graph.m_owner_opt_state);
    m_var_replace_map->clear();
    m_graph.m_owner_opt_state = this;
    m_oprs_inserted.clear();

    auto on_opr_insert = [this](const cg::event::OprInserted& ev) {
        auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY;
        if (need_src_opr)
            mgb_assert(m_cur_iter_src_opr, "opr %s{%s} created outside from "
                       "SubGraph::iter",
                       ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
        if (ev.exc || ev.is_dedup)
            return;

        auto&& new_attr = ev.opr->node_prop().attribute();
        auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
        mgb_assert(ins.second);

        if (need_src_opr && !new_attr.src_opr) {
            auto src_opr = m_cur_iter_src_opr;
            if (ev.opr != src_opr)
                new_attr.src_opr = src_opr;
            ins.first->second |= OprPropertyFlag::SOURCE_OPR;
        }
        if (need_priority) {
            new_attr.priority = m_cur_iter_opr_priority;
            if (!ev.opr->update_priority()) {
                ins.first->second |= OprPropertyFlag::PRIORITY;
            }
        }

        auto csp = m_cur_iter_opr_stream_prop_type;
        if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) {
            for (auto i : ev.opr->output())
                m_comp_node_opt.register_stream_var(i, csp);
        }
    };
    m_on_opr_insert_handler = graph.comp_graph()->event().register_receiver<
            cg::event::OprInserted>(on_opr_insert);
}
void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
    if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
        // this can only happen in auto_replace_outputs()
        mgb_assert(dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
                   src->owner_opr()->dyn_typeinfo() ==
                           dst->owner_opr()->dyn_typeinfo());
        mgb_assert(!msg);
        return;
    }

    //! check_property
    {
        auto iter = m_oprs_inserted.find(dst->owner_opr());
        if (iter != m_oprs_inserted.end()) {
            auto &&src_attr = src->owner_opr()->node_prop().attribute(),
                 &&dst_attr = dst->owner_opr()->node_prop().attribute();
            auto opr_info = [&](OperatorNodeBase* opr) {
                return opr ? opr->name() + "(" + std::to_string(opr->id()) + ")"
                           : "NULL";
            };
            auto err_msg = [&] {
                std::string ret = "Please contact Engine group:\n";
                ret += "src opr: ";
                ret += opr_info(src->owner_opr());
                ret += ", dst opr: ";
                ret += opr_info(dst->owner_opr());
                return ret;
            };
            MGB_MARK_USED_VAR(err_msg);
            if (iter->second & OprPropertyFlag::SOURCE_OPR) {
                auto &&src_rt = get_opr_root_source_opr(src->owner_opr()),
                     &&dst_rt = get_opr_root_source_opr(dst->owner_opr());
                mgb_assert(dst_rt == src_rt,
                           "%s\nsrc source_opr: %s, dst source_opr: %s\n",
                           err_msg().c_str(), opr_info(src_rt).c_str(),
                           opr_info(dst_rt).c_str());
            }
            if (iter->second & OprPropertyFlag::PRIORITY) {
                mgb_assert(src_attr.priority == dst_attr.priority,
                           "%s\nsrc priority: %d, dst priority %d\n",
                           err_msg().c_str(), src_attr.priority,
                           dst_attr.priority);
            }
        }
    }

    {
        bool suc = true;
        SmallVector<std::string> fail_chks;
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_INFER_TYPE) {
            auto&& mgr = src->owner_graph()->static_infer_manager();
            auto it0 = mgr.get_infer_type(src), it1 = mgr.get_infer_type(dst);
            using cg::static_infer::InferType;
            // only check whether inferable
            auto norm = [](InferType::Flag f) -> bool {
                return f & (InferType::RT_STATIC | InferType::CONST);
            };
            if (!(norm(it0.shape) == norm(it1.shape) &&
                  norm(it0.value) <= norm(it1.value))) {
                suc = false;
                fail_chks.push_back("infer-type");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_DTYPE) {
            if (src->dtype() != dst->dtype()) {
                suc = false;
                fail_chks.push_back("dtype");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_SHAPE) {
            if (!(src->shape().eq_shape(dst->shape()))) {
                suc = false;
                fail_chks.push_back("shape");
            }
        }
        if (!suc) {
            std::string fail_msg = "{";
            for (size_t i = 0; i < fail_chks.size(); i++) {
                fail_msg += fail_chks[i];
                if (i < fail_chks.size() - 1) {
                    fail_msg += ",";
                }
            }
            fail_msg += "}";
            mgb_throw_raw(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{src->owner_opr()}
                            .make<InternalError>(ssprintf(
                                    "%s mismatch for replace_var: %s",
                                    fail_msg.c_str(),
                                    cg::dump_var_info({src, dst}).c_str())));
        }
    }

    if (src->has_name_set() && !dst->has_name_set()) {
        dst->name(src->name());
    }
    (*m_var_replace_map)[src] = dst;
    // dst should be considered as newly inserted, and previous replace
    // record should be ignored
    m_var_replace_map->erase(dst);

#if MGB_ENABLE_LOGGING
    if (msg && m_owner_optimizer->verbosity()) {
        m_log_msg.append("\n ")
                .append(std::to_string(m_log_nr_item))
                .append(": ")
                .append(src->owner_opr()->cname())
                .append(" => ")
                .append(dst->owner_opr()->cname())
                .append(" (")
                .append(msg)
                .append(")");
    }
    ++m_log_nr_item;
#endif
}
size_t OptState::flush_log(const char* title) {
    if (m_owner_optimizer->verbosity() >= 2) {
        if (m_log_msg.empty()) {
            m_log_msg = mgb_cstr_log(" no var replacement logged");
        }
        mgb_log("%s%s", title, m_log_msg.c_str());
        m_log_msg.clear();
    }
    auto ret = m_log_nr_item;
    m_log_nr_item = 0;
    return ret;
}
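// call_with_opr() temporarily installs opr's source opr, priority and stream
// property as the "current" iteration values (as selected by
// opr_property_flag) while running func, and restores the previous values
// afterwards even if func throws.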
void OptState::call_with_opr(OperatorNodeBase* opr, thin_function<void(void)> func,
                             OprPropertyFlag opr_property_flag) {
    auto src_opr = cg::get_opr_root_source_opr(opr);
    auto opr_priority = opr->node_prop().attribute().priority;
    auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0));
    ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted;
    auto swap_properties =
            [&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
        if (need_src_opr) {
            std::swap(m_cur_iter_src_opr, src_opr);
        }
        if (need_priority) {
            std::swap(m_cur_iter_opr_priority, opr_priority);
        }
        std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
        std::swap(m_opr_property_flag, opr_property_flag);
        std::swap(m_oprs_inserted, oprs_inserted);
    };
    MGB_TRY {
        swap_properties();
        func();
    } MGB_FINALLY({
        swap_properties();
    });
}
/* ================ RecursiveSubGraphRewriteHelper ================ */
RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept = default;

RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state):
        m_opt_state{state}, m_rewriter{state.graph().make_rewriter()}
{
}

void RecursiveSubGraphRewriteHelper::apply() {
    using namespace std::placeholders;
    m_opt_state.graph().iter(
            std::bind(&RecursiveSubGraphRewriteHelper::on_opr, this, _1));
    m_rewriter.apply_inplace();
}
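// on_opr() drives the recursive rewrite with an explicit stack: each frame
// pairs the original output var with the opr currently standing in for it;
// process_opr() may push new frames for the transformed result and its
// internal vars, and the final var is committed via replace_var().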
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) {
    auto on_new_opr = [this](OperatorNodeBase* opr) {
        auto repl_opr = m_rewriter.auto_replace_outputs(opr);
        return on_new_opr_check_should_process(opr, repl_opr);
    };

    if (!on_new_opr(opr))
        return;

    auto orig_out = get_opr_single_output_var(opr);
    if (!orig_out)
        return;

    mgb_assert(m_opr_stack.empty());
    m_opr_stack.push_back({
            orig_out, m_rewriter.get_var(orig_out)->owner_opr()});

    bool first = true;
    while (!m_opr_stack.empty()) {
        auto cur_frame = m_opr_stack.back();
        m_opr_stack.pop_back();

        auto cur_opr = cur_frame.opr;
        bool should_process;
        if (first) {
            should_process = true;
            first = false;
        } else {
            should_process = on_new_opr(cur_opr);
        }
        auto cur_out = get_opr_single_output_var(cur_opr);
        mgb_assert(cur_out);
        cur_out = m_rewriter.get_var(cur_out);

        if (should_process) {
            auto trans = process_opr(cur_out);
            if (trans.valid()) {
                m_opr_stack.push_back({
                        cur_frame.orig_var, trans->result->owner_opr()});
                for (auto i : reverse_adaptor(trans->internal)) {
                    if (i)
                        m_opr_stack.push_back({i, i->owner_opr()});
                }
                if (trans->msg) {
                    if (!m_log_msg.empty())
                        m_log_msg.push_back(';');
                    m_log_msg.append(trans->msg);
                }
                continue;
            }
        }

        auto src = cur_frame.orig_var;
        if (m_rewriter.get_var(src) != cur_out) {
            const char* msg = nullptr;
            if (m_opr_stack.empty()) {
                msg = m_log_msg.c_str();
            }
            m_rewriter.replace_var(src, cur_out, msg);
            after_replace_var(src, cur_out);
            if (m_opr_stack.empty()) {
                m_log_msg.clear();
                break;
            }
        }
    }
}
/* ================ GraphOptimizer ================ */
GraphOptimizer::~GraphOptimizer() noexcept = default;

class GraphOptimizer::VarReplaceMapStorage : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

public:
    ThinHashMap<VarNode*, VarNode*> map;
};
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage);

GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) {
    mgb_assert(!pass->m_owner_optimizer);
    pass->m_owner_optimizer = this;
    m_passes.emplace_back(std::move(pass));
    return *this;
}

SubGraph GraphOptimizer::apply(const SubGraph& graph) const {
    RealTimer timer;
    OptState state{this, graph};
    size_t tot_nr_replace = 0;

    // first update output var shapes of all oprs
    state.graph().iter(cg::update_output_var_shapes);

    auto&& opt = graph.comp_graph()->options();
    auto orig_setting = opt.graph_opt_level;
    Pass* cur_pass = nullptr;
    MGB_MARK_USED_VAR(cur_pass);
    MGB_TRY {
        for (auto&& i : m_passes) {
            state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL);
            cur_pass = i.get();
            opt.graph_opt_level = 1;
            i->apply(state);
            tot_nr_replace += state.flush_log(
                    mgb_ssprintf_log(
                            "apply optimization pass %s:", i->name()).c_str());
        }
    } MGB_CATCH(std::exception& exc, {
        mgb_log_error("error while applying optimization pass %s: %s",
                      cur_pass->name(), exc.what());
        opt.graph_opt_level = orig_setting;
        throw;
    })
    MGB_FINALLY(
        opt.graph_opt_level = orig_setting
    );
    if (verbosity() >= 1) {
        mgb_log_debug("graph optimization: applied %zu passes, "
                      "total %zu var(s) replaced; time=%.2fms",
                      m_passes.size(), tot_nr_replace, timer.get_msecs());
    }
    return state.graph();
}
const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const {
    if (m_passes.empty()) {
        // this check is necessary, since OptState would clear
        // var_replace_map()
        return *this;
    }

    auto g = apply({{vars.begin(), vars.end()}});
    for (size_t i = 0; i < vars.size(); ++i) {
        vars[i] = g.endpoint_vars()[i].node();
    }
    return *this;
}
GraphOptimizer& GraphOptimizer::add_preset_passes(
        bool after_grad, const OptimizeForInferenceOptions* inference_opt,
        const ComputingGraph::Options* comp_graph_opt) {
    auto cv_type = inference_opt ? ConstVarType::IMMUTABLE_AND_PARAM
                                 : ConstVarType::IMMUTABLE;
    if (inference_opt) {
        add_pass<ConvertBatchNormToElemwisePass>();
    }
    if (!after_grad || inference_opt) {
        add_pass<CondExecConstPredicateFolding>();
    }
    if (after_grad || inference_opt) {
        add_pass<RemoveNonComputingOprPass>();
    }
    add_pass<DelayBroadcastPass>();
    add_pass<ExpandFusedArithPass>();
    add_pass<NormalizeArithChainPass>();
    if (inference_opt) {
        add_pass<ParamRedistributePass>();
        add_pass<ParamFusePass>();
    }
    add_pass<ArithMulDistributePass>();
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<ArithFusePass>();
    // reorder again because shapes of fused oprs might change
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<FinalArithTransformPass>();
    add_pass<RemoveRedundantTypeCvtPass>();

#if MGB_JIT
    bool need_jit = false;
    if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
                           comp_graph_opt->graph_opt.jit)) {
        need_jit = true;
    }
    if (need_jit && after_grad) {
        add_pass<gopt::RecompTypeCvtPass>();
    }
#endif

    // combine astype and reduce.
    // Note: apply this pass before JITFusion, so a TypeCvt that is read by
    // both Reduce and Elemwise can be fused correctly.
    add_pass<CombineAstypeAndReducePass>();

#if MGB_JIT
    if (need_jit) {
        add_pass<gopt::JITFusionPass>(
                after_grad,
                std::max<uint8_t>(comp_graph_opt->graph_opt.jit, 1));
    }
#endif

    if (inference_opt) {
        add_pass<ParamFusePass>();
        add_passes_for_optimize_options(*inference_opt);
    }

    if (inference_opt) {
        // merge params to reduce loading time and graph overhead
        add_pass<ParamMergePass>();
        add_pass<FuseDeconvCvtPass>();
    }
    return *this;
}
const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map(
        ComputingGraph& graph) {
    auto storage = graph.options().user_data.get_user_data_or_create<
            VarReplaceMapStorage>();
    return storage->map;
}
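// Follow the replacement chain recorded in var_replace_map() until a var that
// was never replaced is reached, and return it.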
VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) {
    auto&& map = var_replace_map(*(var->owner_graph()));
    for (;;) {
        auto iter = map.find(var);
        if (iter == map.end())
            return var;
        var = iter->second;
    }
}
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        const cg::GraphCommonOptimizeOptions& options) {
    return add_passes_for_optimize_options(
            const_cast<cg::GraphCommonOptimizeOptions&>(options));
}

const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        cg::GraphCommonOptimizeOptions& options, bool reset) {
    bool need_param_fuse = false;
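    // cb(option, passes): if the given layout/fusion option has been requested
    // in `options`, register the listed passes, remember that a final
    // ParamFusePass is needed, and clear the option flag when `reset` is true.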
#define cb(_option, _passes)             \
    if (options.has_set_##_option()) {   \
        _passes need_param_fuse = true;  \
        if (reset) {                     \
            options.disable_##_option(); \
        }                                \
    }

    cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
    cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
    cb(nchw4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableNCHW4Pass::make_nchw4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(nhwcd4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(ConvertFormatPass::make_nhwcd4_converter());
    });
    cb(nchw88, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw44_dot, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableNchw44DotPass::make_nchw44_dot_converter());
        add_pass<ShuffleShuffleRemovePass>();
    });
    cb(nchw32, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableTensorCorePass::make_tensorcore_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(chwn4, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableCHWN4Pass::make_chwn4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    });
    cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
    cb(fuse_conv_bias_with_z, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
    });
#undef cb

    if (need_param_fuse) {
        add_pass<ParamFusePass>();
    }
    return *this;
}
/* ================ ConstVarPropogateBase ================ */
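// add_opr() recursively classifies an operator: it is marked const if it is
// itself a const var source, or if every input comes from an opr already
// marked const; the result also records whether any/all inputs are const and
// tracks the maximum size among the constant inputs.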
ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
        OperatorNodeBase* opr) {
    using ProfFlag = OperatorNodeBase::NodeProp::Flag;
    auto&& info = m_oprinfo[opr];
    if (info.processed)
        return info.result;
    info.processed = true;

#if MGB_ENABLE_JSON
    (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(false);
#endif

    AddOprResult ret{false, false, false};
    auto make_ret = [&ret, &info]() {
        info.result = ret;
        return ret;
    };

    if (is_const_var(m_const_var_type, opr)) {
        auto sz = var_mem_size(opr->output(0));
        mgb_assert(sz || opr->output(0)->contain_flag(
                VarNode::Flag::ALLOW_EMPTY_SHAPE));
        info.is_const = true;
        info.max_size = sz;
        return make_ret();
    }

    if (opr->input().empty())
        return make_ret();

    if (opr->node_prop().contain(
                ProfFlag::FORCE_UPDATE_INPUT_VAR |
                ProfFlag::IMPURE_FUNC)) {
        return make_ret();
    }

    size_t max_input_size = 0;
    ret.all_const_inp = true;
    for (auto i : opr->input()) {
        auto io = i->owner_opr();
        auto iter = m_oprinfo.find(io);
        if (iter == m_oprinfo.end()) {
            add_opr(io);
            iter = m_oprinfo.find(io);
            mgb_assert(iter != m_oprinfo.end());
        }
        auto&& src = iter->second;
        if (src.is_const) {
            update_max(max_input_size, src.max_size);
            ret.has_const_inp = true;
            if (!is_const_var(m_const_var_type, i->owner_opr())) {
                ret.has_midconst_inp = true;
            }
        } else {
            ret.all_const_inp = false;
        }
    }
    if (ret.all_const_inp) {
#if MGB_ENABLE_JSON
        (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(true);
#endif
        info.max_size = max_input_size;
        info.is_const = true;
    }
    return make_ret();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
