

/**
 * \file src/core/impl/graph/cg_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./cg_impl.h"
#include "./cg_impl_partial.h"
#include "./cg_impl_seq.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/utility.h"
#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/opr_replace.h"
#endif
#if MGB_JIT
#include "megbrain/jit/fusion_pass.h"
#endif
using namespace mgb;
using namespace cg;
namespace {
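// Sanity check run on every newly inserted operator: unless the operator
// explicitly declares CROSS_COMP_NODE_MEMORY, all of its input and output
// vars must reside on the same memory node.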
void check_opr_not_cross_mem(OperatorNodeBase* opr) {
    if (opr->node_prop().contain(
                OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY))
        return;
    MemNode mem_node_id;
    bool first = true;
    auto check = [&](VarNode* var) {
        auto cur = var->comp_node().mem_node();
        mgb_assert(cur);
        if (first) {
            first = false;
            mem_node_id = cur;
        } else
            mgb_assert(
                    mem_node_id == cur,
                    "for non cross-memory oprs, "
                    "all vars should reside on the same memory node");
    };
    for (auto i : opr->input()) {
        check(i);
    }
    for (auto i : opr->output()) {
        check(i);
    }
}
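// Eagerly run static shape inference for all non-volatile outputs of \p opr,
// clearing the shape when no static inference is available; optionally add
// the FLAG_FREEZED flag to freeze each output var's flag set.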
void update_output_shapes(
        static_infer::StaticInferManagerImpl& infer_mgr, OperatorNodeBase* opr,
        bool add_freeze_flag) {
    for (auto i : opr->output()) {
        if (add_freeze_flag) {
            i->add_flag(VarNode::Flag::FLAG_FREEZED);
        }
        if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            using namespace static_infer;
            if (infer_mgr.get_infer_type(i).shape &
                (InferType::CONST | InferType::RT_STATIC)) {
                auto shp = infer_mgr.infer_shape_fallible(i);
                if (shp) {
                    i->shape(*shp);
                } else {
                    i->shape({});
                }
            } else {
                i->shape({});
            }
        }
    }
}
}  // anonymous namespace
/* ========================== global helpers ========================== */
void cg::update_output_var_shapes(OperatorNodeBase* opr) {
    update_output_shapes(
            static_cast<static_infer::StaticInferManagerImpl&>(
                    opr->owner_graph()->static_infer_manager()),
            opr, false);
}
/* ========================= DeviceMemoryAllocator ========================= */
void DeviceMemoryAllocator::alloc_static(
        ComputingGraph*, DeviceTensorStorage& dest, size_t size) {
    dest.ensure_size(size);
}
void DeviceMemoryAllocator::alloc_dynamic(
        VarNode*, DeviceTensorStorage& dest, size_t size) {
    dest.ensure_size(size);
}
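// Default defragmentation hook: try to allocate a contiguous chunk of
// \p size bytes on \p comp_node and release it immediately, so that later
// static allocation can obtain contiguous storage; allocation failure is
// silently ignored.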
void DeviceMemoryAllocator::defrag_prealloc_contig(
        ComputingGraph* graph, CompNode comp_node, size_t size) {
    MGB_TRY { comp_node.free_device(comp_node.alloc_device(size)); }
    MGB_CATCH(MemAllocError&, {})
}
size_t DeviceMemoryAllocator::static_alloc_version(ComputingGraph*) const {
    return 0;
}
/* ========================== ComputingGraph ========================== */
ComputingGraph::ComputingGraph() {
    static std::atomic_size_t tot_id{0};
    m_id = (tot_id++);
}
void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) {
    mgb_assert(
            ptr.use_count() <= 2, "unexpected use_count: %zu", size_t(ptr.use_count()));
    ptr.reset();
}
#if !MGB_THREAD_SAFE
size_t ComputingGraph::prealloc_static_storage(size_t size) {
    // note that in single-threaded mode, all cpus map to the same comp node
    static int version = 0;
    auto cn = CompNode::load("cpu0");
    mgb_assert(cn == CompNode::load("cpu1"));
    auto inst = StaticDeviceMemoryManager::make_default_impl();
    auto ret = inst->get_size(cn);
    inst->alloc(nullptr, cn, size, version).ptr();
    version = inst->version(nullptr);
    return ret;
}
#endif
/* ========================== JITConfig ========================== */
bool ComputingGraph::Options::GraphOpt::JITConfig::enabled() const {
    if (fuse_dimshuffle != UNSET)
        return true;
    if (fuse_reduce != UNSET)
        return true;
    return false;
}
void ComputingGraph::Options::GraphOpt::JITConfig::update(const JITConfig& modifier) {
    if (modifier.fuse_dimshuffle != UNSET) {
        this->fuse_dimshuffle = modifier.fuse_dimshuffle;
    }
    if (modifier.fuse_reduce != UNSET) {
        this->fuse_reduce = modifier.fuse_reduce;
    }
}
/* ========================== CallbackCaller ========================== */
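// Internal operator inserted during compile(): it takes the vars requested by
// the user as inputs and invokes the callbacks from the OutputSpec on their
// device tensors, so each callback request becomes an ordinary graph
// dependency of the computing sequence.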
MGB_DEFINE_OPR_CLASS(
        ComputingGraphImpl::CallbackCaller, SingleCNOperatorNodeBase) // {
    std::vector<std::vector<ComputingGraph::Callback>> m_cb;
    void scn_do_execute() override {
        for (size_t i = 0; i < input().size(); ++i) {
            auto&& in = input(i)->dev_tensor();
            for (auto&& callback : m_cb[i]) {
                // const cast for backward API compatibility
                callback(const_cast<DeviceTensorND&>(in));
            }
        }
    }
    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;
        owner_graph()->static_infer_manager().register_shape_infer(
                output(0), ShapeInferDesc::make_const({}));
    }
    void add_input_layout_constraint() override {
        if (owner_graph()->options().comp_node_seq_record_level) {
            // the user callback usually copies from device to host, which
            // involves tmp alloc if input is not contiguous
            for (auto&& inp : input()) {
                inp->add_layout_constraint_contiguous();
            }
        }
    }
    void init_output_dtype() override {
        if (output(0)->dtype().valid()) {
            return;
        }
        mgb_assert(!input().empty());
        DType dtype = input(0)->dtype();
        mgb_assert(dtype.valid() && dtype != dtype::Byte());
        output(0)->dtype(dtype);
    }
    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        for (auto&& inp : input()) {
            ret->add_dep_type_existing_var(inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
        }
        return ret;
    }
    bool update_priority() const override {
        node_prop().attribute().priority = std::numeric_limits<int>::min();
        return true;
    }

public:
    CallbackCaller(const VarNodeArrayView& inp)
            : Super{inp[0]->owner_graph(), {}, "callback", inp} {
        mgb_assert(!inp.empty());
        m_cb.resize(inp.size());
        for (auto&& i : inp) {
            add_input({i});
        }
        using F = VarNode::Flag;
        add_output(None)->add_flag(F::ALLOW_EMPTY_SHAPE).add_flag(F::VOLATILE_CONTENT);
    }
    static SymbolVar make(const VarNodeArrayView& inp) {
        mgb_assert(!inp.empty());
        return SymbolVar{inp[0]}
                .node()
                ->owner_graph()
                ->insert_opr(std::make_unique<CallbackCaller>(inp))
                ->output(0);
    }
    void add_callback(const ComputingGraph::Callback& cb, size_t i = 0) {
        mgb_assert(cb && i < m_cb.size());
        m_cb[i].push_back(cb);
    }
    void clear_callback() {
        for (size_t i = 0; i < m_cb.size(); ++i) {
            m_cb[i].clear();
        }
    }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller);
/* ========================== ComputingGraphImpl ========================== */
ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
        : topo_sorter{owner},
          var_node_mem_manager{owner},
          seq_comp_node_opt{owner},
          static_infer_manager{owner},
          static_infer_comp_seq_manager{owner},
          grad_manager{owner},
#if MGB_ENABLE_SUBLINEAR
          seq_modifier_for_sublinear_memory{
                  owner, &(owner->options().sublinear_mem_config)},
#endif
#if MGB_ENABLE_DTR
          seq_modifier_for_dtr{owner, &(owner->options().dtr_config)},
#endif
#if MGB_ENABLE_MEMORY_SWAP
          memory_swap_support{owner},
#endif
          eager_eval_manager{owner} {
}
ComputingGraphImpl::ComputingGraphImpl() {
    auto ptr = new (&m_components_storage) Components{this};
    mgb_assert(ptr == &components());
}
ComputingGraphImpl::~ComputingGraphImpl() {
    if (!is_finalized()) {
        cleanup();
    }
}
std::shared_ptr<void> ComputingGraphImpl::on_comp_node_finalize() {
    // hold a reference because the object itself may be deleted by user data or
    // oprs
    std::shared_ptr<void> ref = shared_from_this();
    cleanup();
    return ref;
}
void ComputingGraphImpl::cleanup() {
    if (m_recorded_seq_level2_dtor_chk) {
        m_recorded_seq_level2_dtor_chk->enable();
    }
    // clear device memory storage and return them to comp node
    clear_device_memory();
    // so opr dtors would incur no overhead when deleting vars
    m_var_node_pool.disable_freelist();
    // TODO: call this after each graph exec when we have faster impl
    CompNode::try_coalesce_all_free_memory();
    options().user_data.clear_all_user_data();
    components().~Components();
    m_var_receiver.clear();
    m_opr_refkeeper.clear();
}
void* ComputingGraphImpl::alloc_varnode_storage() {
    return m_var_node_pool.alloc_raw();
}
void ComputingGraphImpl::free_varnode_storage(void* ptr) {
    m_var_node_pool.free_raw(ptr);
}
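// Insert an operator into the graph. In imperative proxy-graph mode the opr
// is only initialized and kept alive; otherwise it goes through
// deduplication (graph_optimizer().insert_pre), output initialization,
// static-infer registration and receiver bookkeeping before being handed to
// insert_post() and the eager-evaluation manager.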
OperatorNodeBase* ComputingGraphImpl::insert_opr(
        std::unique_ptr<OperatorNodeBase> opr_uniqp) {
    auto opr = opr_uniqp.get();
    if (options().imperative_proxy_graph) {
        if (!opr->inserted_in_graph()) {
            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
            opr->set_inserted_in_graph();
            opr->init_output_comp_node();
            opr->init_output_dtype();
            opr->init_output_format();
            // register static infer
            {
                auto&& mgr = static_infer_manager_impl();
                auto old = mgr.set_register_allowed_opr(opr);
                opr->init_output_static_infer_desc();
                mgr.set_register_allowed_opr(old);
            }
        }
        return opr;
    }
    if (opr->inserted_in_graph()) {
        // FIXME: it's just a trick used for re-evaluation in eager evaluation
        // mode. Since comp_graph has already taken an ownership of the opr,
        // we can release it directly.
        mgb_throw_if(
#if MGB_BUILD_SLIM_SERVING
                true,
#else
                !options().eager_evaluation,
#endif
                GraphError,
                "an inserted opr %s is re-inserted into the graph "
                "with eager evaluation mode OFF.",
                opr->cname());
        opr_uniqp.release();
        // No need to do the insert_post under eager mode
        eager_eval_manager().on_opr_insert(opr);
        return opr;
    }
    auto&& infer_mgr = static_infer_manager_impl();
    auto cleanup = [&]() {
        infer_mgr.set_register_allowed_opr(nullptr);
        for (auto i : opr->output()) {
            infer_mgr.clear_tag_handler(i);
            var_node_mem_manager().remove_var_node_mem_trait(i);
        }
    };
    if (auto ret = graph_optimizer().insert_pre(opr)) {
        bool should_update_shape = true;
#if !MGB_BUILD_SLIM_SERVING
        // in normal mode, we update the shape in deduplication in case shape
        // changes; in eager evaluation mode, shape is set by EagerEvalManager
        // and should not be modified
        should_update_shape = !options().eager_evaluation;
#endif
        if (should_update_shape) {
            update_output_shapes(infer_mgr, ret, false);
        }
        cleanup();
        event().signal_inplace<cg::event::OprInserted>(true, ret, nullptr);
        ret = graph_optimizer().insert_post(ret);
        eager_eval_manager().on_opr_insert(ret);
        return ret;
    }
    // record opr early, since exceptions may refer to the opr
    m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
    MGB_TRY {
        mgb_assert(!opr->inserted_in_graph());
        mgb_assert(!opr->output().empty(), "operator must have at least one output");
        opr->set_inserted_in_graph();
        // basic init
        opr->init_output_comp_node();
        opr->init_output_dtype();
        opr->init_output_format();
        // check output initialized
        for (auto i : opr->output()) {
            mgb_assert(i->comp_node().valid() && i->dtype().valid());
        }
        // register static infer
        {
            auto old = infer_mgr.set_register_allowed_opr(opr);
            opr->init_output_static_infer_desc();
            infer_mgr.set_register_allowed_opr(old);
        }
        // more init
        opr->init_rt_force_dynamic_mem_alloc_imply_chain();
        // freeze output flag and static infer shape eagerly
        update_output_shapes(infer_mgr, opr, true);
        check_opr_not_cross_mem(opr);
    }
    MGB_CATCH(MegBrainError & exc, {
        cleanup();
        if (!exc.extra_info())
            OperatorNodeExcExtraInfo::record(opr, exc);
        event().signal_inplace<cg::event::OprInserted>(false, opr, &exc);
        throw;
    })
    // add to receiver list if above succeeds
    for (auto&& i : opr->input()) {
        auto iter = m_var_receiver.find(i);
        mgb_assert(iter != m_var_receiver.end());
        auto&& arr = iter->second;
        if (arr.empty() || arr.back() != opr) {
            // check if added, because opr may have identical inputs
            arr.push_back(opr);
        }
    }
    // alloc var receiver for the outputs
    for (auto&& i : opr->output()) {
        bool em = m_var_receiver[i].empty();
        mgb_assert(em);
    }
    event().signal_inplace<cg::event::OprInserted>(false, opr, nullptr);
    opr = graph_optimizer().insert_post(opr);
    eager_eval_manager().on_opr_insert(opr);
    return opr;
}
std::shared_ptr<ComputingGraph> ComputingGraph::make() {
    return std::make_shared<ComputingGraphImpl>();
}
std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile(
        const OutputSpec& out_spec) {
    return compile_commit(compile_prepare(out_spec));
}
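// A minimal usage sketch (illustrative only; the variable names are
// hypothetical and the authoritative API is declared in megbrain/graph/cg.h):
//
//     auto graph = ComputingGraph::make();
//     // ... build oprs on the graph, obtaining a SymbolVar y ...
//     HostTensorND host_y;
//     auto func = graph->compile(
//             {{y, [&](DeviceTensorND& dev) { host_y.copy_from(dev); }}});
//     func->execute();
//     func->wait();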
SmallVector<std::unique_ptr<AsyncExecutable>> ComputingGraphImpl::compile_multi_part(
        const SmallVector<OutputSpec>& out_specs) {
#if MGB_ENABLE_PARTIAL_EXECUTION
    return MultiPartCompiler{this}.compile(out_specs);
#else
    mgb_throw(MegBrainError, "partial execution disabled at compile time");
#endif
}
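// Adjust memory-allocation flags on the destination vars according to graph
// options: force_output_dynamic_alloc disables static allocation and memory
// reclaiming for the outputs, while force_output_use_user_specified_memory
// additionally disables all system allocation so the caller can provide the
// output storage.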
void ComputingGraphImpl::dest_var_optimize(VarNodeArray& dest_vars) {
    using F = VarNode::Flag;
    if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
        for (auto&& i : dest_vars) {
            if (!i->contain_flag(F::NO_SYS_MEM_ALLOC | F::NO_SYS_STATIC_MEM_ALLOC)) {
                mgb_assert(
                        !i->contain_flag(F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
                        "can not force graph output dynamic alloc with "
                        "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
                        i->cname());
                i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
            }
            i->add_flag(F::NO_MEM_RECLAIM);
        }
    }
    if (dest_vars[0]->owner_graph()->options().force_output_use_user_specified_memory) {
        for (auto&& i : dest_vars) {
            mgb_assert(
                    !i->contain_flag(F::RT_FORCE_DYNAMIC_MEM_ALLOC),
                    "var %s with RT_FORCE_DYNAMIC_MEM_ALLOC flag should not be "
                    "forced to write its output to user memory",
                    i->cname());
            i->add_flag(
                    F::NO_SYS_MEM_ALLOC | F::NO_SYS_STATIC_MEM_ALLOC |
                    F::NO_MEM_RECLAIM);
        }
    }
}
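// First phase of compilation: validate the output spec, run the graph
// optimizer passes selected by the current options (and, when enabled,
// TensorRT replacement, JIT fusion and sublinear/DTR/memory-swap rewriting),
// attach CallbackCaller oprs for the user callbacks, and topologically sort
// the operator sequence. The result is consumed by compile_commit().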
ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
        const OutputSpec& out_spec) {
    auto&& cmpnt = components();
    mgb_throw_if(
            m_recorded_seq_level2_dtor_chk, GraphError,
            "graphs with comp_node_seq_record_level==2 can only be "
            "compiled once");
    mgb_throw_if(
            out_spec.empty(), GraphError,
            "empty output spec given to ComputingGraph::compile");
    // topo sorter may have modified opr properties; restore them before this
    // new compiling
    topo_sorter().restore_opr_prop();
    cmpnt.seq_comp_node_opt.restore_comp_nodes();
    SpecialOprStat sopr_stat;
    auto dest_vars = get_dest_vars_from_out_spec(out_spec, sopr_stat);
#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        mgb_assert(!options().enable_dtr_memory_opt);
        if (!sopr_stat.has_virtual_grad) {
            mgb_log_debug(
                    "no virtual grad var; sublinear memory may produce "
                    "unsatisfying result");
        }
        seq_modifier_for_sublinear_memory().set_priority_before_opt(dest_vars);
    }
#else
    mgb_assert(!options().enable_sublinear_memory_opt);
#endif  // MGB_ENABLE_SUBLINEAR
#if MGB_ENABLE_DTR
    if (options().enable_dtr_memory_opt) {
        mgb_assert(!options().enable_sublinear_memory_opt);
        seq_modifier_for_dtr().set_priority_before_opt(dest_vars);
    }
#else
    mgb_assert(!options().enable_dtr_memory_opt);
#endif  // MGB_ENABLE_DTR
#if !MGB_BUILD_SLIM_SERVING
    mgb_assert(
            !options().eager_evaluation, "attempt to compile eager_evaluation graph");
    {
        bool need_opt = std::abs(options().graph_opt_level) >= 2;
        gopt::GraphOptimizer optimizer;
        optimizer.verbosity(options().log_level);
        optimizer.enable_check_result(options().graph_opt_level < 0);
        if (sopr_stat.has_virtual_grad) {
            if (need_opt) {
#if MGB_ENABLE_OPR_MM
                optimizer.add_pass<gopt::PackAllReduceScanPass>();
#endif
                optimizer.add_preset_passes(false, nullptr, &options());
            }
            optimizer.add_pass<gopt::ExpandVirtualGradPass>();
        }
        if (need_opt) {
            optimizer.add_preset_passes(true, nullptr, &options());
#if MGB_ENABLE_OPR_MM
            if (sopr_stat.has_virtual_grad) {
                optimizer.add_pass<gopt::PackAllReduceReplacePass>();
            }
#endif
        }
        optimizer.apply_inplace(dest_vars);
    }
#endif
#if MGB_ENABLE_TENSOR_RT
    if (options().graph_opt.tensorrt) {
        options().graph_opt.tensorrt = false;
        tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt);
    }
#endif
#if MGB_JIT
    if (std::abs(options().graph_opt_level) == 0 &&
        (options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
        // Deprecated usage added previously. It allows NVRTC JIT optimization
        // when graph_opt_level is 0. This usage is no longer recommended.
        mgb_log_warn(
                "It is not recommended to enable JIT optimization when "
                "graph_opt_level is 0.");
        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
        gopt::GraphOptimizer optimizer;
        optimizer.add_pass<gopt::JITFusionPass>(
                sopr_stat.has_virtual_grad, options().graph_opt.jit,
                options().graph_opt.jit_config);
        optimizer.apply_inplace(dest_vars);
    }
#endif
    gopt::GraphOptimizer optimizer;
    /**
     * \note We should reset the options when adding the passes indicated by
     * the optimize options, because the ParamFuse pass compiles a subgraph,
     * which could otherwise lead to recursive invocation; see
     * https://git-core.megvii-inc.com/brain-sdk/MegBrain/merge_requests/1717
     * for details.
     */
    optimizer.add_passes_for_optimize_options(options().graph_opt, true);
    optimizer.apply_inplace(dest_vars);
    if (sopr_stat.has_shape_hint) {
        // FIXME(zhangxuanrun): strictly speaking, ShapeHints could and have to
        // be removed even when they occur in a subgraph
        mgb_assert(!m_parent_graph, "can not use ShapeHint in subgraph");
        // always need to remove shape hints
        gopt::GraphOptimizer opt;
        opt.add_pass<gopt::RemoveShapeHintPass>();
        opt.apply_inplace(dest_vars);
    }
    const OprNodeArray* opr_seq = nullptr;
    CompSeqExtraInfo extra_info;
    cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
    bool init_flag = false;
    auto init_opr_seq = [&]() {
        mgb_assert(!init_flag);
        init_flag = true;
        ThinHashMap<VarNode*, size_t> var2idx;
        std::unordered_map<
                CallbackCallerKey, CallbackCallerVal, CallbackCallerKey::Hash>
                opr2vars;
        dest_var_optimize(dest_vars);
        for (size_t i = 0; i < out_spec.size(); ++i) {
            auto&& cb = out_spec[i].second;
            if (cb) {
                auto var = dest_vars[i];
                CallbackCallerKey key{var->owner_opr(), var->comp_node()};
                auto&& vals = opr2vars[key];
                auto&& var2idx_iter = var2idx.find(var);
                if (var2idx_iter == var2idx.end()) {
                    vals.vars.push_back(var);
                    vals.indexs.push_back({i});
                    var2idx[var] = vals.vars.size() - 1;
                } else {
                    vals.indexs[var2idx_iter->second].push_back(i);
                }
            }
        }
        for (auto& item : opr2vars) {
            auto&& val = item.second;
            auto dvar = CallbackCaller::make(val.vars);
            CallbackCaller* cb_caller =
                    &dvar.node()->owner_opr()->cast_final_safe<CallbackCaller>();
            ++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
            cb_caller->clear_callback();
            for (size_t i = 0; i < val.vars.size(); ++i) {
                for (auto&& idx : val.indexs[i]) {
                    cb_caller->add_callback(out_spec[idx].second, i);
                    dest_vars[idx] = cb_caller->output(0);
                }
            }
        }
        opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars);
    };
#if MGB_ENABLE_MEMORY_SWAP
    bool enable_swap_memory_after_sublinear =
            options().enable_sublinear_memory_opt && options().enable_memory_swap;
    bool enable_swap_memory_without_sublinear =
            !(options().enable_sublinear_memory_opt) && options().enable_memory_swap;
    if (enable_swap_memory_without_sublinear) {
        components().memory_swap_support.modify_dest_var_inplace(dest_vars);
    }
#else
    mgb_assert(!options().enable_memory_swap);
#endif
#if MGB_ENABLE_DTR
    if (options().enable_dtr_memory_opt) {
        MGB_TRY {
            seq_modifier_for_dtr().modify_endpoint_vars(dest_vars);
            init_opr_seq();
        }
        MGB_FINALLY(seq_modifier_for_dtr().restore_graph_option());
    }
#endif
#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        MGB_TRY {
            seq_modifier_for_sublinear_memory().modify_endpoint_vars(dest_vars);
#if MGB_ENABLE_MEMORY_SWAP
            if (enable_swap_memory_after_sublinear) {
                cmpnt.memory_swap_support.modify_dest_var_inplace(dest_vars);
            }
#endif
            init_opr_seq();
        }
        MGB_FINALLY(
                /*
                 * restore graph option immediately because it may be
                 * read/modified by user
                 */
                seq_modifier_for_sublinear_memory().restore_graph_option());
        seq_modifier_for_sublinear_memory().sanity_check(*opr_seq);
    }
#endif  // MGB_ENABLE_SUBLINEAR
    if (!init_flag) {
        init_opr_seq();
    }
    return {std::move(extra_info), opr_seq, std::move(dest_vars)};
}
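// Second phase of compilation: wrap the sorted operator sequence into a
// ComputingSequence, record the last device-value reader of every var, reset
// the memory manager and static-infer manager for the new sequence, and
// optionally pre-allocate static memory. With comp_node_seq_record_level > 1
// the sequence is converted into a recorded sequence before being returned.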
std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
        CompileState state) {
    auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
    comp_seq->extra_info = std::move(state.extra_info);
    comp_seq->set_output_vars(state.dest_vars);
    auto opr_seq = state.opr_seq;
    auto&& cmpnt = components();
    comp_seq->setup_opr_seq(opr_seq);
    for (auto&& i : *opr_seq) {
        for (auto&& j : i->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(j.second)) {
                comp_seq->extra_info.var2recvinfo.at(j.first).last_dev_value_reader = i;
            }
        }
    }
    comp_seq->attach_to_graph();
    MGB_TRY {
        var_node_mem_manager().reset_opr_seq(comp_seq->extra_info, opr_seq);
        static_infer_comp_seq_manager().reset_dest(comp_seq->extra_info);
        cmpnt.seq_comp_node_opt.init_ready_event(comp_seq->extra_info, *opr_seq);
        if (options().allocate_static_mem_after_graph_compile)
            var_node_mem_manager().alloc_var_node_mem_static();
    }
    MGB_FINALLY({ var_node_mem_manager().on_graph_compile_finished(); });
    event().signal_inplace<event::CompSeqOrderDetermined>(this, comp_seq.get());
    if (options().comp_node_seq_record_level > 1) {
        mgb_assert(
                options().comp_node_seq_record_level <= 2,
                "invalid comp_node_seq_record_level: %u",
                options().comp_node_seq_record_level);
        mgb_assert(
                !options().fake_next_exec && !options().var_sanity_check_first_run,
                "both fake_next_exec and var_sanity_check_first_run "
                "must be false when comp_node_seq_record_level is 2");
        return comp_seq->as_recorded_seq();
    }
    return comp_seq;
}
VarNodeArray ComputingGraphImpl::get_dest_vars_from_out_spec(
        const OutputSpec& spec, SpecialOprStat& sopr_stat) {
    SymbolVarArray sym_vars;
    for (auto&& i : spec) {
        sym_vars.push_back(i.first);
    }
    return to_var_node_array(get_dest_vars_with_extra_deps(sym_vars, &sopr_stat));
}
const ComputingGraph::VarReceiverInfo& ComputingGraphImpl::
        var_receiver_in_current_comp_seq(const VarNode* var) const {
    static VarReceiverInfo empty;
    if (auto ret = components().eager_eval_manager.var_receiver_info(var)) {
        return *ret;
    }
    if (!m_current_comp_seq)
        return empty;
    auto cseq = static_cast<ComputingSequence*>(m_current_comp_seq);
    auto iter = cseq->extra_info.var2recvinfo.find(var);
    if (iter == cseq->extra_info.var2recvinfo.end())
        return empty;
    return iter->second;
}
VarNode* ComputingGraphImpl::find_var_by_id(size_t id) const {
    for (auto&& i : m_opr_refkeeper) {
        for (auto j : i->output()) {
            if (j->id() == id)
                return j;
        }
    }
    for (auto&& i : m_subgraphs) {
        auto sub = i->find_var_by_id(id);
        if (sub)
            return sub;
    }
    return nullptr;
}
#if MGB_ENABLE_SUBLINEAR
SeqModifierForSublinearMemory& ComputingGraphImpl::seq_modifier_for_sublinear_memory() {
    return components().seq_modifier_for_sublinear_memory;
}
#endif
#if MGB_ENABLE_DTR
SeqModifierForDTR& ComputingGraphImpl::seq_modifier_for_dtr() {
    return components().seq_modifier_for_dtr;
}
#endif
void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
    mgb_assert(
            !m_current_comp_seq,
            "share_device_memory_with must be called before compiling graph");
    auto&& oimpl = *ComputingGraphImpl::downcast(&other);
    var_node_mem_manager().static_device_memory_manager(
            oimpl.var_node_mem_manager().static_device_memory_manager());
}
void ComputingGraphImpl::set_device_memory_allocator(
        std::shared_ptr<DeviceMemoryAllocator> allocator) {
    var_node_mem_manager().static_device_memory_manager()->set_allocator(
            std::move(allocator));
}
size_t ComputingGraphImpl::get_device_memory_size(CompNode cn) {
    return var_node_mem_manager().static_device_memory_manager()->get_size(cn);
}
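// Release the static device memory held by the graph; in eager evaluation
// mode the memory chunks of non-persistent vars (outputs of oprs other than
// SharedDeviceTensor and ImmutableTensor) are released first.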
size_t ComputingGraphImpl::clear_device_memory() {
#if !MGB_BUILD_SLIM_SERVING
    if (options().eager_evaluation) {
        for (auto& opr : m_opr_refkeeper) {
            if (!opr->same_type<mgb::opr::SharedDeviceTensor>() &&
                !opr->same_type<mgb::opr::ImmutableTensor>()) {
                for (auto& var : opr->output()) {
                    if (var->mem_plan().valid())
                        var->mem_plan().release_chunk();
                }
            }
        }
    }
#endif
    return var_node_mem_manager().clear_static_device_memory();
}
void ComputingGraphImpl::set_as_subgraph(ComputingGraph& par_graph) {
    m_parent_graph = ComputingGraphImpl::downcast(&par_graph);
    m_parent_graph->m_subgraphs.emplace_back(this);
    m_node_id_counter = m_parent_graph->m_node_id_counter;
    options().var_sanity_check_first_run =
            par_graph.options().var_sanity_check_first_run;
    par_graph.event().signal_inplace<event::SubgraphAssociated>(&par_graph, this);
}
void ComputingGraphImpl::record_async_error(std::unique_ptr<MegBrainError> async_exc) {
    mgb_assert(m_current_comp_seq);
    static_cast<ComputingSequence*>(m_current_comp_seq)
            ->set_async_error(std::move(async_exc));
}
const CompSeqExtraInfo& ComputingGraphImpl::current_comp_seq_extra_info() {
    if (auto ret = eager_eval_manager().comp_seq_extra_info()) {
        return *ret;
    }
    mgb_assert(m_current_comp_seq);
    return static_cast<ComputingSequence*>(m_current_comp_seq)->extra_info;
}
GraphExecutable::ExecEnv* ComputingGraphImpl::current_exec_env() {
    if (auto ret = eager_eval_manager().exec_env()) {
        return ret;
    }
    if (m_current_comp_seq) {
        return &static_cast<ComputingSequence*>(m_current_comp_seq)->exec_env();
    }
    return nullptr;
}
Maybe<size_t> ComputingGraphImpl::opr_step_num_in_cur_comp_seq(OperatorNodeBase* opr) {
    mgb_assert(m_current_comp_seq && opr->owner_graph() == this);
    return static_cast<ComputingSequence*>(m_current_comp_seq)->opr2stepnum(opr);
}
std::string ComputingGraphImpl::VarReceiverInfo::to_string() const {
    return mgb_ssprintf_log(
            "VarReceiverInfo("
            "nr_direct_comp_req=%zu dev_value=%zu, host_value=%zu, shape=%zu, "
            "allow_empty_value=%zu)",
            nr_direct_comp_req, dev_value, host_value, shape, allow_empty_value);
}
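// Dump the memory allocation of every operator output as a JSON string
// (name, chunk size and device pointer per var); when the build does not
// enable MGB_ENABLE_JSON, a warning is logged and an empty string is
// returned.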
std::string ComputingGraphImpl::get_mem_allocation_info() const {
#if MGB_ENABLE_JSON
    auto make_var_json = [](VarNode* single_var) {
        auto&& cur_mem_plan = single_var->mem_plan();
        if (cur_mem_plan.valid())
            return json::Object::make(
                    {{"name", json::String::make(single_var->name())},
                     {"memory", json::Number::make(cur_mem_plan.chunk().size())},
                     {"dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>(
                                         single_var->dev_tensor().raw_ptr()))}});
        else
            return json::Object::make(
                    {{"name", json::String::make(single_var->name())},
                     {"memory", json::Null::make()},
                     {"dev_ptr", json::Null::make()}});
    };
    auto objlist = json::Array::make();
    for (auto& opri : m_opr_refkeeper) {
        auto cur_opr = opri.get();
        auto objptr = json::Object::make();
        auto&& objbody = *objptr;
        objbody["name"] = json::String::make(cur_opr->name());
        auto jvars = json::Array::make();
        for (auto& outputi : cur_opr->output()) {
            jvars->add(make_var_json(outputi));
        }
        objbody["output"] = jvars;
        auto obj = json::Object::make({{std::to_string(cur_opr->id()), objptr}});
        objlist->add(obj);
    }
    return objlist->to_string();
#endif  // MGB_ENABLE_JSON
    mgb_log_warn(
            "target is not built with JSON support; "
            "get_mem_allocation_info returns an empty string");
    return std::string();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}