You can not select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

framework.cpp 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815
  1. /**
  2. * \file src/gopt/impl/framework.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/gopt/framework.h"
  12. #include "megbrain/gopt/inference.h"
  13. #include "megbrain/gopt/basic_arith.h"
  14. #include "megbrain/gopt/misc.h"
  15. #include "megbrain/gopt/gtrans.h"
  16. #include "megbrain/graph/event.h"
  17. #include "megbrain/graph/exc_extra_info.h"
  18. #include "megbrain/serialization/serializer.h"
  19. #include "megbrain/serialization/opr_shallow_copy.h"
  20. #include "megbrain/utils/timer.h"
  21. #if MGB_JIT
  22. #include "megbrain/jit/fusion_pass.h"
  23. #endif
  24. #if MGB_ENABLE_TENSOR_RT
  25. #include "megbrain/tensorrt/opr_replace.h"
  26. #endif
  27. using namespace mgb;
  28. using namespace gopt;
  29. /* ================ SubGraph ================ */
/**
 * Shallow-copy \p opr with each input replaced by its current rewrite
 * target, recording the old-output -> new-output mapping in m_varmap.
 *
 * Returns the replacement operator, or \p opr itself when none of the
 * inputs has been replaced (in which case nothing is recorded).
 */
OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
        OperatorNodeBase *opr) {
    // reuse a cached vector to avoid reallocating on every call
    auto &&new_inp = m_opr_new_inp_cache;
    new_inp.clear();
    new_inp.reserve(opr->input().size());
    bool has_replaced_inp = false;
    for (auto i: opr->input()) {
        auto new_var = get_var(i);
        if (new_var != i) {
            has_replaced_inp = true;
            new_inp.push_back(new_var);
        } else {
            new_inp.push_back(i);
        }
    }
    if (has_replaced_inp) {
        auto new_opr = serialization::copy_opr_shallow(
                *opr, new_inp, opr->config());
        auto &&out0 = opr->output(), &&out1 = new_opr->output();
        size_t i = 0;
        auto err_msg = [opr, new_opr] {
            return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}",
                    opr->cname(), opr->dyn_typeinfo()->name,
                    new_opr->cname(), new_opr->dyn_typeinfo()->name);
        };
        MGB_MARK_USED_VAR(err_msg);
        // opr output size mismatch may be caused by:
        // 0) inplace arith optimization (e.g. PowC need an extra workspace)
        // 1) other post-insert optimization (e.g. const folding)
        // we can't handle only usable_output here, since some output var with
        // volatile flag could be the graph's endpoint (e.g. RemoteSend)
        for (; i < std::min(out0.size(), out1.size()); ++ i) {
            // volatile-ness must agree pairwise, otherwise the shallow copy
            // produced an incompatible operator
            bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                 v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT);
            mgb_assert(v0 == v1, "%s", err_msg().c_str());
            // {true, nullptr} marks an auto (as opposed to explicit)
            // replacement; see replace_var() for the explicit case
            auto &&ins = m_varmap.insert({out0[i], {true, nullptr}});
            mgb_assert(ins.second || ins.first->second.first,
                    "opr output already replaced");
            // handle repeated call on the same opr
            ins.first->second.second = out1[i];
            on_var_replaced(out0[i], out1[i], nullptr);
        }
        // any trailing outputs on either side must be volatile (see the
        // size-mismatch note above)
        for (; i < out0.size(); ++ i) {
            mgb_assert(out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                    "%s", err_msg().c_str());
        }
        for (; i < out1.size(); ++ i) {
            mgb_assert(out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                    "%s", err_msg().c_str());
        }
        return new_opr;
    }
    return opr;
}
/**
 * Record that \p src should be replaced by \p dst, with an optional
 * human-readable reason \p msg used for logging; notifies the owner
 * OptState via on_var_replaced().
 */
void SubGraph::Rewriter::replace_var(
        VarNode *src, VarNode *dst, const char *msg) {
    if (src == dst)
        return;
    // Optimizers should not create a loop in variable replace map.
    mgb_throw_if(
            get_var_internal(dst).second == src, InternalError,
            "dst %s maps back to src %s in SubGraph::Rewriter::replace_var",
            dst->cname(), src->cname());
    // first == false marks an explicit replacement (auto_replace_outputs()
    // inserts with first == true)
    auto &&ins = m_varmap.insert({src, {false, dst}});
    if (!ins.second) {
        // entry already exists: overwriting is only allowed if it was
        // created by auto_replace_outputs() or maps to the same dst
        auto &&old_rep = ins.first->second;
        mgb_assert(old_rep.first || old_rep.second == dst,
                "can not replace a var twice");
        old_rep.first = false;
        old_rep.second = dst;
    }
    on_var_replaced(src, dst, msg);
}
  103. void SubGraph::Rewriter::on_var_replaced(
  104. VarNode* src, VarNode* dst, const char* msg) {
  105. if (auto state = m_owner_graph->owner_opt_state()) {
  106. state->on_var_replaced(src, dst, msg);
  107. }
  108. }
  109. void SubGraph::Rewriter::apply_inplace() const {
  110. m_owner_graph->m_endpoint_oprs.clear();
  111. m_owner_graph->m_endpoint_vars_set.clear();
  112. for (auto &&var: m_owner_graph->m_endpoint_vars) {
  113. var = get_var(var.node());
  114. m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr());
  115. m_owner_graph->m_endpoint_vars_set.insert(var.node());
  116. }
  117. }
/**
 * Resolve the final replacement target of \p var by chasing m_varmap,
 * compressing visited chain entries so later lookups are cheap.
 *
 * Returns {flag, final_var} where the flag is the AND of the per-entry
 * flags along the chain (true for untouched/auto-replaced entries; see
 * auto_replace_outputs() and replace_var()).
 *
 * The implementation is (manually) unrolled once, background:
 * git-core/brain-sdk/MegBrain/merge_requests/486#note_76971
 */
std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) {
    auto it = m_varmap.find(var);
    if (it == m_varmap.end()) {
        // never replaced
        return {true, var};
    }
    mgb_assert(it->second.second != var, "loop detected in m_varmap");
    auto it_next = m_varmap.find(it->second.second);
    if (it_next == m_varmap.end()) {
        // chain of length one: nothing to compress
        return it->second;
    }
    mgb_assert(it_next->second.second != it->second.second,
            "loop detected in m_varmap");
    // recurse on the remainder of the chain, then compress both entries
    // visited at this level (path compression)
    auto next = get_var_internal(it_next->second.second);
    it_next->second = {next.first & it_next->second.first, next.second};
    return it->second = {it_next->second.first & it->second.first, next.second};
}
  136. SubGraph::SubGraph(const SymbolVarArray &endpoint_vars):
  137. m_endpoint_vars(endpoint_vars)
  138. {
  139. mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty");
  140. m_comp_graph = endpoint_vars[0].node()->owner_graph();
  141. for (auto i: endpoint_vars) {
  142. m_endpoint_oprs.insert(i.node()->owner_opr());
  143. m_endpoint_vars_set.insert(i.node());
  144. mgb_assert(m_comp_graph == i.node()->owner_graph(),
  145. "endpoints belong to different computing graphs");
  146. }
  147. }
/**
 * Visit every operator this sub-graph depends on, invoking \p cb on each
 * (dependencies are added before dependents via cg::DepOprIter).
 *
 * When the sub-graph is owned by an OptState, the callback is wrapped so
 * that the per-opr iteration context (source opr, priority, stream prop
 * type) is published on the state before \p cb runs and reset afterwards;
 * oprs created inside \p cb then inherit these properties (see the
 * OprInserted handler registered in OptState's constructor).
 *
 * \param extra_dep extra opr dependencies honored by the traversal
 */
void SubGraph::iter(
        const Callback& cb,
        std::shared_ptr<ExtraDep> extra_dep) const {
    Callback on_opr;
    if (m_owner_opt_state) {
        on_opr = [state=m_owner_opt_state, &cb](OperatorNodeBase *opr) {
            state->m_opr_property_flag = OprPropertyFlag::ALL;
            state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr);
            state->m_cur_iter_opr_priority =
                opr->node_prop().attribute().priority;
            state->m_cur_iter_opr_stream_prop_type =
                state->m_comp_node_opt.stream_prop_type(
                        opr->output(0));
            mgb_assert(state->m_oprs_inserted.empty());
            cb(opr);
            // reset the context so oprs created outside the callback are
            // detected (see the assert in the OprInserted handler)
            state->m_opr_property_flag = OprPropertyFlag::NONE;
            state->m_cur_iter_src_opr = nullptr;
            state->m_oprs_inserted.clear();
        };
    } else {
        on_opr = cb;
    }
    cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)};
    for (auto i: m_endpoint_oprs)
        dep_iter.add(i);
}
/**
 * For every non-volatile var in the sub-graph, count the operators that
 * depend on its device value; endpoint vars receive one extra count (they
 * are read by the outside world and must never look single-reader).
 */
ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
    ThinHashMap<VarNode*, size_t> ret;
    auto cb = [&](OperatorNodeBase *opr) {
        // inputs are visited before this opr during iter(), so the .at()
        // below relies on their entries already existing
        for (auto &&i: opr->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) {
                ++ ret.at(i.first);
            }
        }
        for (auto i: opr->output()) {
            if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto ins = ret.insert({i, 0});
                mgb_assert(ins.second);
            }
        }
    };
    iter(cb);
    for (auto i: m_endpoint_vars_set) {
        auto iter = ret.find(i);
        if (iter == ret.end()) {
            // volatile endpoints are not registered by the callback above
            mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
            ret[i] = 1;
        } else {
            ++ ret.at(i);
        }
    }
    return ret;
}
  201. /* ================ UniqReaderCheck ================ */
  202. UniqReaderCheck::UniqReaderCheck(const SubGraph &graph):
  203. m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()}
  204. {
  205. }
  206. void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr,
  207. OperatorNodeBase* repl_opr) {
  208. auto non_volatile_size = [](const VarNodeArray& vars) -> size_t {
  209. size_t size = 0;
  210. for (size_t i = 0; i < vars.size(); ++i) {
  211. if (!vars[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
  212. size++;
  213. }
  214. }
  215. return size;
  216. };
  217. if (opr != repl_opr) {
  218. auto &&o0 = opr->output(), &&o1 = repl_opr->output();
  219. mgb_assert(non_volatile_size(o0) == non_volatile_size(o1));
  220. for (size_t i = 0; i < o0.size(); ++i) {
  221. auto iter = m_var2nr_val_dep.find(o0[i]);
  222. if (iter != m_var2nr_val_dep.end()) {
  223. auto n = iter->second;
  224. m_var2nr_val_dep[o1[i]] = n;
  225. }
  226. }
  227. }
  228. }
  229. /* ================ OptState ================ */
/**
 * Bind this optimization state to \p graph: clears the graph-level
 * var_replace_map and registers an OprInserted handler that forces the
 * current iteration context (source opr, priority, stream prop type; see
 * SubGraph::iter) onto newly created operators.
 */
OptState::OptState(
        const GraphOptimizer *owner_optimizer, const SubGraph& graph):
    m_owner_optimizer{owner_optimizer},
    m_var_replace_map{
        const_cast<ThinHashMap<VarNode*, VarNode*>*>(
                &GraphOptimizer::var_replace_map(*graph.comp_graph()))},
    m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
    m_graph{graph}
{
    // a SubGraph may be owned by at most one OptState at a time
    mgb_assert(!m_graph.m_owner_opt_state);
    m_var_replace_map->clear();
    m_graph.m_owner_opt_state = this;
    m_oprs_inserted.clear();
    auto on_opr_insert = [this](const cg::event::OprInserted &ev) {
        auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR,
             need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY;
        if (need_src_opr)
            mgb_assert(m_cur_iter_src_opr, "opr %s{%s} created outside from "
                    "SubGraph::iter",
                    ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
        if (ev.exc || ev.is_dedup)
            return;
        auto &&new_attr = ev.opr->node_prop().attribute();
        // remember which properties were forced onto this opr so that
        // on_var_replaced() can verify them later
        auto &&ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
        mgb_assert(ins.second);
        if (need_src_opr && !new_attr.src_opr) {
            auto src_opr = m_cur_iter_src_opr;
            // an opr is never its own source opr
            if (ev.opr != src_opr)
                new_attr.src_opr = src_opr;
            ins.first->second |= OprPropertyFlag::SOURCE_OPR;
        }
        if (need_priority) {
            new_attr.priority = m_cur_iter_opr_priority;
            if (!ev.opr->update_priority()) {
                ins.first->second |= OprPropertyFlag::PRIORITY;
            }
        }
        auto csp = m_cur_iter_opr_stream_prop_type;
        if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) {
            // propagate stream assignment to all outputs of the new opr
            for (auto i: ev.opr->output())
                m_comp_node_opt.register_stream_var(i, csp);
        }
    };
    m_on_opr_insert_handler = graph.comp_graph()->event().register_receiver<
            cg::event::OprInserted>(on_opr_insert);
}
/**
 * Hook invoked for every var replacement performed during optimization.
 *
 * Verifies that the replacement preserves the opr properties that were
 * forced at insertion time (source opr, priority) and — depending on
 * m_var_replace_check_flag — static-infer type, dtype and shape; then
 * records the replacement in the graph-level var_replace_map and
 * optionally appends a log entry.
 */
void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) {
    if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
        // this can only happen in auto_replace_outputs()
        mgb_assert(dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
                src->owner_opr()->dyn_typeinfo() ==
                dst->owner_opr()->dyn_typeinfo());
        mgb_assert(!msg);
        return;
    }
    //! check_property
    {
        auto iter = m_oprs_inserted.find(dst->owner_opr());
        if (iter != m_oprs_inserted.end()) {
            auto &&src_attr = src->owner_opr()->node_prop().attribute(),
                 &&dst_attr = dst->owner_opr()->node_prop().attribute();
            auto opr_info = [&](OperatorNodeBase* opr) {
                return opr ? opr->name() + "(" + std::to_string(opr->id()) + ")"
                           : "NULL";
            };
            auto err_msg = [&] {
                std::string ret = "Please contact Engine group:\n";
                ret += "src opr: ";
                ret += opr_info(src->owner_opr());
                ret += ", dst opr: ";
                ret += opr_info(dst->owner_opr());
                return ret;
            };
            MGB_MARK_USED_VAR(err_msg);
            // verify only the properties that were forced at insertion
            // time (flags recorded by the OprInserted handler)
            if (iter->second & OprPropertyFlag::SOURCE_OPR) {
                auto &&src_rt = get_opr_root_source_opr(src->owner_opr()),
                     &&dst_rt = get_opr_root_source_opr(dst->owner_opr());
                mgb_assert(dst_rt == src_rt,
                        "%s\nsrc source_opr: %s, dst source_opr: %s\n",
                        err_msg().c_str(), opr_info(src_rt).c_str(),
                        opr_info(dst_rt).c_str());
            }
            if (iter->second & OprPropertyFlag::PRIORITY) {
                mgb_assert(src_attr.priority == dst_attr.priority,
                        "%s\nsrc priority: %d, dst priority %d\n",
                        err_msg().c_str(), src_attr.priority,
                        dst_attr.priority);
            }
        }
    }
    {
        // collect all failed checks first so the error lists every one
        bool suc = true;
        SmallVector<std::string> fail_chks;
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_INFER_TYPE) {
            auto&& mgr = src->owner_graph()->static_infer_manager();
            auto it0 = mgr.get_infer_type(src), it1 = mgr.get_infer_type(dst);
            using cg::static_infer::InferType;
            // only check whether inferable
            auto norm = [](InferType::Flag f) -> bool {
                return f & (InferType::RT_STATIC | InferType::CONST);
            };
            if (!(norm(it0.shape) == norm(it1.shape) &&
                  norm(it0.value) <= norm(it1.value))) {
                suc = false;
                fail_chks.push_back("infer-type");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_DTYPE) {
            if (src->dtype() != dst->dtype()) {
                suc = false;
                fail_chks.push_back("dtype");
            }
        }
        if (m_var_replace_check_flag & VarReplaceCheckFlag::CHECK_SHAPE) {
            if (!(src->shape().eq_shape(dst->shape()))) {
                suc = false;
                fail_chks.push_back("shape");
            }
        }
        if (!suc) {
            std::string fail_msg = "{";
            for (size_t i = 0; i < fail_chks.size(); i++) {
                fail_msg += fail_chks[i];
                if (i < fail_chks.size() - 1) {
                    fail_msg += ",";
                }
            }
            fail_msg += "}";
            mgb_throw_raw(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{src->owner_opr()}
                            .make<InternalError>(ssprintf(
                                    "%s mismatch for replace_var: %s",
                                    fail_msg.c_str(),
                                    cg::dump_var_info({src, dst}).c_str())));
        }
    }
    // keep a user-assigned name across the replacement
    if (src->has_name_set() && !dst->has_name_set()) {
        dst->name(src->name());
    }
    (*m_var_replace_map)[src] = dst;
    // dst should be considered as newly inserted, and previous replace
    // record should be ignored
    m_var_replace_map->erase(dst);
#if MGB_ENABLE_LOGGING
    if (msg && m_owner_optimizer->verbosity()) {
        m_log_msg.
            append("\n ").
            append(std::to_string(m_log_nr_item)).
            append(": ").
            append(src->owner_opr()->cname()).
            append(" => ").
            append(dst->owner_opr()->cname()).
            append(" (").
            append(msg).
            append(")");
    }
    ++ m_log_nr_item;
#endif
}
  389. size_t OptState::flush_log(const char *title) {
  390. if (m_owner_optimizer->verbosity() >= 2) {
  391. if (m_log_msg.empty()) {
  392. m_log_msg = mgb_cstr_log(" no var replacement logged");
  393. }
  394. mgb_log("%s%s", title, m_log_msg.c_str());
  395. m_log_msg.clear();
  396. }
  397. auto ret = m_log_nr_item;
  398. m_log_nr_item = 0;
  399. return ret;
  400. }
/**
 * Invoke \p func with the iteration context (source opr, priority, stream
 * prop type, inserted-opr map) temporarily switched to that of \p opr,
 * restoring the previous context afterwards even if \p func throws.
 *
 * \param opr_property_flag which properties should be forced onto oprs
 *      created while \p func runs
 */
void OptState::call_with_opr(OperatorNodeBase *opr, thin_function<void(void)> func,
        OprPropertyFlag opr_property_flag) {
    auto src_opr = cg::get_opr_root_source_opr(opr);
    auto opr_priority = opr->node_prop().attribute().priority;
    auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0));
    ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted;
    // the same lambda both installs and restores the context: swapping
    // twice is a no-op, so calling it in MGB_FINALLY undoes the first call
    auto swap_properties = [&,
         need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
         need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
        if (need_src_opr) {
            std::swap(m_cur_iter_src_opr, src_opr);
        }
        if (need_priority) {
            std::swap(m_cur_iter_opr_priority, opr_priority);
        }
        std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
        std::swap(m_opr_property_flag, opr_property_flag);
        std::swap(m_oprs_inserted, oprs_inserted);
    };
    MGB_TRY {
        swap_properties();
        func();
    } MGB_FINALLY({
        swap_properties();
    });
}
  427. /* ================ RecursiveSubGraphRewriteHelper ================ */
// out-of-line defaulted destructor
RecursiveSubGraphRewriteHelper::
~RecursiveSubGraphRewriteHelper() noexcept = default;
  430. RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState &state):
  431. m_opt_state{state}, m_rewriter{state.graph().make_rewriter()}
  432. {
  433. }
  434. void RecursiveSubGraphRewriteHelper::apply() {
  435. using namespace std::placeholders;
  436. m_opt_state.graph().iter(
  437. std::bind(&RecursiveSubGraphRewriteHelper::on_opr, this, _1));
  438. m_rewriter.apply_inplace();
  439. }
/**
 * Visit one operator: auto-replace its inputs, then repeatedly apply
 * process_opr() to the (single-output) result until no more transforms
 * apply, using an explicit stack so that transforms which expose new
 * intermediate operators are themselves re-processed.
 */
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) {
    auto on_new_opr = [this](OperatorNodeBase *opr) {
        auto repl_opr = m_rewriter.auto_replace_outputs(opr);
        return on_new_opr_check_should_process(opr, repl_opr);
    };
    if (!on_new_opr(opr))
        return;
    // only single-output oprs take part in recursive rewriting
    auto orig_out = get_opr_single_output_var(opr);
    if (!orig_out)
        return;
    mgb_assert(m_opr_stack.empty());
    m_opr_stack.push_back({
            orig_out, m_rewriter.get_var(orig_out)->owner_opr()});
    bool first = true;
    while (!m_opr_stack.empty()) {
        auto cur_frame = m_opr_stack.back();
        m_opr_stack.pop_back();
        auto cur_opr = cur_frame.opr;
        bool should_process;
        if (first) {
            // the root opr already passed on_new_opr() above
            should_process = true;
            first = false;
        } else {
            should_process = on_new_opr(cur_opr);
        }
        auto cur_out = get_opr_single_output_var(cur_opr);
        mgb_assert(cur_out);
        cur_out = m_rewriter.get_var(cur_out);
        if (should_process) {
            auto trans = process_opr(cur_out);
            if (trans.valid()) {
                // re-process the transform result under the original var,
                // then its internal vars (pushed in reverse so they pop in
                // forward order)
                m_opr_stack.push_back({
                        cur_frame.orig_var, trans->result->owner_opr()});
                for (auto i: reverse_adaptor(trans->internal)) {
                    if (i)
                        m_opr_stack.push_back({i, i->owner_opr()});
                }
                // accumulate transform descriptions, ';'-separated
                if (trans->msg) {
                    if (!m_log_msg.empty())
                        m_log_msg.push_back(';');
                    m_log_msg.append(trans->msg);
                }
                continue;
            }
        }
        auto src = cur_frame.orig_var;
        if (m_rewriter.get_var(src) != cur_out) {
            // only the outermost replacement carries the accumulated
            // log message
            const char *msg = nullptr;
            if (m_opr_stack.empty()) {
                msg = m_log_msg.c_str();
            }
            m_rewriter.replace_var(src, cur_out, msg);
            after_replace_var(src, cur_out);
            if (m_opr_stack.empty()) {
                m_log_msg.clear();
                break;
            }
        }
    }
}
  500. /* ================ GraphOptimizer ================ */
// out-of-line defaulted destructor
GraphOptimizer::~GraphOptimizer() noexcept = default;
// Per-graph storage of the var replacement map, attached to the graph's
// user-data container so it shares the graph's lifetime; accessed through
// GraphOptimizer::var_replace_map().
class GraphOptimizer::VarReplaceMapStorage :public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

public:
    // maps an original var to the var that replaced it
    ThinHashMap<VarNode*, VarNode*> map;
};
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage);
  508. GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) {
  509. mgb_assert(!pass->m_owner_optimizer);
  510. pass->m_owner_optimizer = this;
  511. m_passes.emplace_back(std::move(pass));
  512. return *this;
  513. }
/**
 * Run all registered passes on \p graph and return the optimized sub-graph.
 *
 * graph_opt_level is forced to 1 while each pass runs and restored
 * afterwards, including on the error path.
 */
SubGraph GraphOptimizer::apply(const SubGraph &graph) const {
    RealTimer timer;
    OptState state{this, graph};
    size_t tot_nr_replace = 0;
    // first update output var shapes of all oprs
    state.graph().iter(cg::update_output_var_shapes);
    auto &&opt = graph.comp_graph()->options();
    auto orig_setting = opt.graph_opt_level;
    Pass *cur_pass = nullptr;    // tracked only for error reporting
    MGB_MARK_USED_VAR(cur_pass);
    MGB_TRY {
        for (auto &&i: m_passes) {
            // each pass starts with all replacement checks enabled; a pass
            // may relax them via the state
            state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL);
            cur_pass = i.get();
            opt.graph_opt_level = 1;
            i->apply(state);
            tot_nr_replace += state.flush_log(
                    mgb_ssprintf_log(
                        "apply optimization pass %s:", i->name()).c_str());
        }
    } MGB_CATCH(std::exception &exc, {
        mgb_log_error("error while applying optimization pass %s: %s",
                cur_pass->name(), exc.what());
        opt.graph_opt_level = orig_setting;
        throw;
    })
    MGB_FINALLY(
        opt.graph_opt_level = orig_setting
    );
    if (verbosity() >= 1) {
        mgb_log_debug("graph optimization: applied %zu passes, "
                "total %zu var(s) replaced; time=%.2fms",
                m_passes.size(), tot_nr_replace, timer.get_msecs());
    }
    return state.graph();
}
  550. const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray &vars) const {
  551. if (m_passes.empty()) {
  552. // this check is necessary, since OptState would clear
  553. // var_replace_map()
  554. return *this;
  555. }
  556. auto g = apply({{vars.begin(), vars.end()}});
  557. for (size_t i = 0; i < vars.size(); ++ i) {
  558. vars[i] = g.endpoint_vars()[i].node();
  559. }
  560. return *this;
  561. }
/**
 * Register the standard preset sequence of optimization passes.
 *
 * \param after_grad whether this optimizer runs after gradient computation
 *      (gates several passes below)
 * \param inference_opt non-null to enable inference-only transformations
 *      (params treated as const, format conversions, param fusion/merge)
 * \param comp_graph_opt used only to decide whether JIT fusion is enabled
 *
 * Note: the pass order below is significant.
 */
GraphOptimizer& GraphOptimizer::add_preset_passes(
        bool after_grad, const OptimizeForInferenceOptions* inference_opt,
        const ComputingGraph::Options* comp_graph_opt) {
    // params may be treated as const only when optimizing for inference
    auto cv_type = inference_opt ? ConstVarType::IMMUTABLE_AND_PARAM
                                 : ConstVarType::IMMUTABLE;
    if (inference_opt) {
        add_pass<ConvertBatchNormToElemwisePass>();
    }
    if (!after_grad || inference_opt) {
        add_pass<CondExecConstPredicateFolding>();
    }
    if (after_grad || inference_opt) {
        add_pass<RemoveNonComputingOprPass>();
    }
    add_pass<DelayBroadcastPass>();
    add_pass<ExpandFusedArithPass>();
    add_pass<NormalizeArithChainPass>();
    if (inference_opt) {
        add_pass<ParamRedistributePass>();
        add_pass<ParamFusePass>();
    }
    add_pass<ArithMulDistributePass>();
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<ArithFusePass>();
    // reorder again because shapes of fused oprs might change
    add_pass<ReorderArithChainPass>(cv_type);
    add_pass<FinalArithTransformPass>();
    add_pass<RemoveRedundantTypeCvtPass>();
#if MGB_JIT
    // JIT is enabled either by |graph_opt_level| >= 3 or by the explicit
    // graph_opt.jit setting
    bool need_jit = false;
    if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
                           comp_graph_opt->graph_opt.jit)) {
        need_jit = true;
    }
    if (need_jit && after_grad) {
        add_pass<gopt::RecompTypeCvtPass>();
    }
#endif
    // combine astype and reduce.
    // Note: apply this pass before JITFusion, so the TypeCvt which
    // read by both Reduce and Elemwise could be fused correctly.
    add_pass<CombineAstypeAndReducePass>();
#if MGB_JIT
    if (need_jit) {
        add_pass<gopt::JITFusionPass>(
                after_grad,
                std::max<uint8_t>(comp_graph_opt->graph_opt.jit, 1));
    }
#endif
    apply_optimize_options(inference_opt);
    if (inference_opt) {
        // merge params to reduce loading time and graph overhead
        add_pass<ParamMergePass>();
        add_pass<FuseDeconvCvtPass>();
    }
    return *this;
}
  619. const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map(
  620. ComputingGraph &graph) {
  621. auto storage = graph.options().user_data.get_user_data_or_create<
  622. VarReplaceMapStorage>();
  623. return storage->map;
  624. }
  625. VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {
  626. auto &&map = var_replace_map(*(var->owner_graph()));
  627. for (; ; ) {
  628. auto iter = map.find(var);
  629. if (iter == map.end())
  630. return var;
  631. var = iter->second;
  632. }
  633. }
/**
 * Translate the flags in \p options into concrete passes; no-op when
 * \p options is null. The order of the format-conversion groups below is
 * significant, and a final ParamFusePass is always appended.
 */
void GraphOptimizer::apply_optimize_options(
        const OptimizeOptions* options) {
    if (!options) return;
    if (options->f16_io_comp) {
        add_pass(ConvertF32ToF16Pass::make(false));
    }
    if (options->f16_io_f32_comp) {
        add_pass(ConvertF32ToF16Pass::make(true));
    }
    if (options->transform_nchw2nhwcd4()) {
        add_pass(ConvertFormatPass::make_nhwcd4_converter());
        add_pass<FuseConvBiasNonlinPass>();
    }
    if (options->transform_nchw2nchw88()) {
        add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
    }
    if (options->transform_nchw2nchw44()) {
        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
    }
    if (options->transform_nchw2nchw32()) {
        // fuse first so the converter sees fused conv-bias oprs; clean up
        // shuffles and typecvts introduced by the conversion afterwards
        add_pass<FuseConvBiasNonlinPass>();
        add_pass(EnableTensorCorePass::make_tensorcore_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    }
    if (options->transform_nchw42chwn4()) {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        add_pass(EnableCHWN4Pass::make_chwn4_converter());
        add_pass<ShuffleShuffleRemovePass>();
        add_pass<RemoveRedundantTypeCvtPass>();
    }
    if (options->fuse_conv_bias_nonlinearity) {
        add_pass<FuseConvBiasNonlinPass>();
    }
    if (options->fuse_conv_bias_with_z) {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
    }
    add_pass<ParamFusePass>();
}
  675. /* ================ ConstVarPropogateBase ================ */
/**
 * Classify \p opr for const-var propagation and memoize the result in
 * m_oprinfo.
 *
 * An opr is marked const if it is itself a const var source (per
 * m_const_var_type), or if all of its inputs are const and its node
 * properties allow folding (no FORCE_UPDATE_INPUT_VAR / IMPURE_FUNC).
 * The returned AddOprResult also records whether any/all inputs are const
 * and whether some const input is itself derived (mid-const).
 */
ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
        OperatorNodeBase *opr) {
    using ProfFlag = OperatorNodeBase::NodeProp::Flag;
    auto &&info = m_oprinfo[opr];
    if (info.processed)
        return info.result;
    info.processed = true;
#if MGB_ENABLE_JSON
    (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(false);
#endif
    AddOprResult ret{false, false, false};
    auto make_ret = [&ret, &info]() {
        // store the result before returning so future calls are memoized
        info.result = ret;
        return ret;
    };
    if (is_const_var(m_const_var_type, opr)) {
        auto sz = var_mem_size(opr->output(0));
        mgb_assert(sz || opr->output(0)->contain_flag(
                    VarNode::Flag::ALLOW_EMPTY_SHAPE));
        info.is_const = true;
        info.max_size = sz;
        return make_ret();
    }
    // an opr without inputs that is not a const source can not be folded
    if (opr->input().empty())
        return make_ret();
    if (opr->node_prop().contain(
                ProfFlag::FORCE_UPDATE_INPUT_VAR |
                ProfFlag::IMPURE_FUNC)) {
        return make_ret();
    }
    size_t max_input_size = 0;
    ret.all_const_inp = true;
    for (auto i: opr->input()) {
        auto io = i->owner_opr();
        auto iter = m_oprinfo.find(io);
        if (iter == m_oprinfo.end()) {
            // input opr not classified yet: recurse
            add_opr(io);
            iter = m_oprinfo.find(io);
            mgb_assert(iter != m_oprinfo.end());
        }
        auto &&src = iter->second;
        if (src.is_const) {
            update_max(max_input_size, src.max_size);
            ret.has_const_inp = true;
            // const but not a direct const source => mid-const
            if (!is_const_var(m_const_var_type, i->owner_opr())) {
                ret.has_midconst_inp = true;
            }
        } else {
            ret.all_const_inp = false;
        }
    }
    if (ret.all_const_inp) {
#if MGB_ENABLE_JSON
        (*opr->to_json_extra_json)["gopt::cvprop"] = json::Bool::make(true);
#endif
        info.max_size = max_input_size;
        info.is_const = true;
        on_midconst_opr(opr, max_input_size);
    }
    return make_ret();
}
  737. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台