/**
 * \file src/core/impl/graph/helper.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/graph/helper.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "./cg_impl.h"

using namespace mgb;
using namespace cg;

/* =================== global functions =================== */

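// Gather the comp nodes of all outputs of opr; an opr flagged
// SINGLE_COMP_NODE must yield exactly one comp node.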
CompNode::UnorderedSet cg::get_opr_comp_node_set(OperatorNodeBase* opr) {
    CompNode::UnorderedSet rst;
    for (auto i : opr->output())
        rst.insert(i->comp_node());
    if (opr->node_prop().contain(
                OperatorNodeBase::NodeProp::Flag::SINGLE_COMP_NODE))
        mgb_assert(rst.size() == 1);
    return rst;
}

bool cg::is_all_input_static_storage(OperatorNodeBase* opr) {
    for (auto&& i : opr->node_prop().dep_map())
        if (i.second != OperatorNodeBase::NodeProp::DepType::DEV_COMP_ORDER &&
            !is_static_var_storage(i.first))
            return false;
    return true;
}

VarNodeArray cg::to_var_node_array(const SymbolVarArray& symbol_var_array) {
    VarNodeArray var_node_array(symbol_var_array.size());
    for (size_t i = 0; i < symbol_var_array.size(); ++i) {
        var_node_array[i] = symbol_var_array[i].node();
    }
    return var_node_array;
}

SymbolVarArray cg::to_symbol_var_array(const VarNodeArray& var_node_array) {
    SymbolVarArray symbol_var_array(var_node_array.size());
    for (size_t i = 0; i < var_node_array.size(); ++i) {
        symbol_var_array[i] = var_node_array[i];
    }
    return symbol_var_array;
}

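// Format a human-readable summary of each var for logs and error messages:
// id, shape (or layout when the dev tensor is valid), dtype, owner opr,
// name, output slot, comp node, static/dynamic storage, and infer types.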
std::string cg::dump_var_info(const VarNodeArrayView& vars) {
    std::string rst;
    int idx = 0;
    for (auto i : vars) {
        if (!rst.empty())
            rst.append(" ");
        auto opr = i->owner_opr();
        if (vars.size() > 1)
            rst.append(ssprintf("%d=", idx++));
        bool valid = i->dev_tensor_valid();
        auto slot = find(opr->output(), i) - opr->output().begin();
        auto&& it = i->owner_graph()->static_infer_manager().get_infer_type(i);
        rst.append(ssprintf(
                "{id:%zu, %s:%s, %s, "
                "owner:%s{%s}, name:%s, slot:%td, %s, %c, %d, %d}",
                i->id(), valid ? "layout" : "shape",
                valid ? i->layout().to_string().c_str()
                      : i->shape().to_string().c_str(),
                i->dtype().name(), opr->cname(), opr->dyn_typeinfo()->name,
                i->cname(), slot, i->comp_node().to_string().c_str(),
                cg::is_static_var_storage(i) ? 's' : 'd',
                static_cast<int>(it.shape), static_cast<int>(it.value)));
    }
    return rst;
}

SymbolVar cg::grad(SymbolVar target, SymbolVar wrt, bool warn_mid_wrt,
                   bool return_zero_for_nodep) {
    return grad(target, SymbolVarArray{wrt}, warn_mid_wrt,
                return_zero_for_nodep)[0];
}

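// Compute the gradient of target w.r.t. each var in wrts_. In eager-eval
// graphs this records the newly created grad oprs and flushes them at the
// end, unless recording was already active.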
SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_,
                        bool warn_mid_wrt, bool return_zero_for_nodep) {
#if MGB_ENABLE_GRAD
    auto target = target_.node();
    SymbolVarArray grads;
    grads.reserve(wrts_.size());
    VarNodeArray dest_vars;
    auto&& graph = target->owner_graph();
    auto&& eager_mgr = ComputingGraphImpl::downcast(graph)->eager_eval_manager();
    auto&& grad_mgr = ComputingGraphImpl::downcast(graph)->grad_manager();
    bool already_recorded = eager_mgr.enter_record_mode();
    for (auto&& wrt_ : wrts_) {
        auto wrt = wrt_.node();
        if (warn_mid_wrt && wrt->owner_opr()->input().size()) {
            mgb_log_warn(
                    "taking gradient with respect to an intermediate node may "
                    "produce incorrect results (for example, when it is "
                    "produced by subtensor); node: %s",
                    cg::dump_var_info({wrt}).c_str());
        }
        mgb_throw_if(graph != wrt->owner_graph(), GraphError,
                     "target and wrt must belong to the same graph");
        auto rst = grad_mgr.grad(target, wrt);
        if (!rst && return_zero_for_nodep) {
            mgb_log_warn(
                    "target node (%s) does not depend on wrt node (%s), "
                    "return zeros as grad",
                    cg::dump_var_info({target}).c_str(),
                    cg::dump_var_info({wrt}).c_str());
            rst = (wrt_ * 0).node();
        }
        if (rst)
            dest_vars.push_back(rst);
        grads.emplace_back(rst);
    }
    if (!already_recorded && eager_mgr.enabled()) {
        eager_mgr.flush_record_oprs(dest_vars);
        grad_mgr.clean_cache();
    }
    return grads;
#else
    MGB_MARK_USED_VAR(target_);
    MGB_MARK_USED_VAR(wrts_);
    MGB_MARK_USED_VAR(warn_mid_wrt);
    MGB_MARK_USED_VAR(return_zero_for_nodep);
    mgb_throw(MegBrainError, "grad disabled at compile time");
#endif
}

SymbolVar cg::current_grad_target(ComputingGraph& graph) {
#if MGB_ENABLE_GRAD
    auto var = ComputingGraphImpl::downcast(&graph)
                       ->grad_manager()
                       .current_grad_target();
    mgb_throw_if(!var, GraphError,
                 "current_grad_target() called outside "
                 "grad computing environment");
    return var;
#else
    MGB_MARK_USED_VAR(graph);
    mgb_throw(MegBrainError, "grad disabled at compile time");
#endif
}

SymbolVarArray cg::get_dest_vars_with_extra_deps(
        const SymbolVarArray& dest_vars, SpecialOprStat* sopr_stat) {
    return ExtraDependencyMerger{sopr_stat}.add(dest_vars);
}

namespace {

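// Common driver for the replace_* helpers: rewrite the subgraph reachable
// from dest (endpoints extended with extra_vardeps) via on_opr, then migrate
// extra_vardeps entries from replaced vars to their replacements.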
SymbolVarArray replace_vars_internal(
        const SymbolVarArray& dest,
        thin_function<void(OperatorNodeBase*, gopt::SubGraph::Rewriter&)>
                on_opr) {
    if (dest.empty()) {
        return dest;
    }

    // check that they belong to the same graph
    mgb_assert(dest[0].node());
    auto og = dest[0].node()->owner_graph();
    for (auto i : dest) {
        mgb_assert(i.node() && i.node()->owner_graph() == og);
    }
    auto dest_with_extra_deps = get_dest_vars_with_extra_deps(dest);

    // do the replace
    gopt::SubGraph graph{dest_with_extra_deps};
    auto rewriter = graph.make_rewriter();
    graph.iter([&](OperatorNodeBase* opr) { on_opr(opr, rewriter); });

    auto new_og = rewriter.get_var(dest[0].node())->owner_graph();
    auto &&old_extra_vardeps = og->options().extra_vardeps,
         &&new_extra_vardeps = new_og->options().extra_vardeps;
    auto on_opr_replace_dep = [&](OperatorNodeBase* opr) {
        for (auto i : opr->output()) {
            auto new_node = rewriter.get_var(i);
            auto iter = old_extra_vardeps.find(i);
            if (iter == old_extra_vardeps.end())
                continue;
            if (new_node == i) {
                for (const auto& dep : iter->second) {
                    auto new_dep = rewriter.get_var(dep);
                    mgb_assert(dep == new_dep,
                               "var %s is not replaced, but its extra "
                               "dependency %s is replaced by %s ",
                               cg::dump_var_info({i}).c_str(),
                               cg::dump_var_info({dep}).c_str(),
                               cg::dump_var_info({new_dep}).c_str());
                }
            } else {
                auto& new_deps = new_extra_vardeps[new_node];
                for (const auto& dep : iter->second) {
                    new_deps.push_back(rewriter.get_var(dep));
                }
            }
        }
    };
    if (dest_with_extra_deps.size() != dest.size())
        graph.iter(on_opr_replace_dep);
    rewriter.apply_inplace();
    auto ret = graph.endpoint_vars();
    ret.resize(dest.size());
    return ret;
}
} // namespace

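// Replace whole operators: map each usable output of an old opr to the
// corresponding output of its replacement and delegate to replace_vars().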
SymbolVarArray cg::replace_oprs(
        const SymbolVarArray& dest,
        const ThinHashMap<OperatorNodeBase*, OperatorNodeBase*>& oprmap) {
    if (oprmap.empty() || dest.empty()) {
        return dest;
    }
    mgb_assert(dest[0].node());
    auto graph = dest[0].node()->owner_graph();
    for (auto i : dest) {
        mgb_assert(i.node() && i.node()->owner_graph() == graph,
                   "Dest should all be in same graph");
    }
    for (auto&& i : oprmap) {
        mgb_assert(i.first->owner_graph() == graph &&
                           i.second->owner_graph() == graph,
                   "Original and dest operators in oprmap should all be in "
                   "same graph");
    }
    ThinHashMap<SymbolVar, SymbolVar> varmap;
    for (auto&& p : oprmap) {
        const auto& outputs0 = p.first->usable_output();
        const auto& outputs1 = p.second->usable_output();
        mgb_assert(outputs0.size() == outputs1.size(),
                   "Number of outputs differ: old operator %s has %zu outputs, "
                   "while new operator %s has %zu outputs.",
                   p.first->name().c_str(), outputs0.size(),
                   p.second->name().c_str(), outputs1.size());
        for (size_t i = 0; i < outputs0.size(); i++) {
            varmap[outputs0[i]] = outputs1[i];
        }
    }
    return replace_vars(dest, varmap);
}

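// Replace individual vars according to varmap; outputs not in the map are
// forwarded automatically by the rewriter.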
SymbolVarArray cg::replace_vars(
        const SymbolVarArray& dest,
        const ThinHashMap<SymbolVar, SymbolVar>& varmap) {
    if (varmap.empty())
        return dest;
    auto og = dest[0].node()->owner_graph();
    for (auto&& i : varmap) {
        mgb_assert(i.first.node() && i.second.node() &&
                   i.first.node()->owner_graph() == og &&
                   i.second.node()->owner_graph() == og);
    }
    auto on_opr = [&](OperatorNodeBase* opr,
                      gopt::SubGraph::Rewriter& rewriter) {
        for (auto i : opr->output()) {
            auto viter = varmap.find(i);
            if (viter != varmap.end()) {
                rewriter.replace_var(i, viter->second.node(), nullptr);
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    return replace_vars_internal(dest, on_opr);
}

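// Rebuild an expression in another graph: source oprs (those without
// inputs) are shallow-copied into new_graph, and all downstream oprs are
// reconstructed on top of the copies by the rewriter.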
SymbolVarArray cg::replace_vars_comp_graph(const SymbolVarArray& dest,
                                           ComputingGraph* new_graph) {
    ComputingGraph* orig_graph = dest[0].node()->owner_graph();
    mgb_assert(new_graph != orig_graph);
    auto on_opr = [&](OperatorNodeBase* opr,
                      gopt::SubGraph::Rewriter& rewriter) {
        OperatorNodeBase* new_opr;
        if (opr->input().size()) {
            rewriter.auto_replace_outputs(opr);
        } else {
            mgb_assert(opr->owner_graph() != new_graph);
            new_opr = serialization::copy_opr_shallow(*opr, {}, opr->config(),
                                                      {new_graph});
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                rewriter.replace_var(out0[i], out1[i], "replace comp graph.");
            }
        }
    };
    return replace_vars_internal(dest, on_opr);
}

SymbolVarArray cg::find_h2d(const SymbolVarArray& dest) {
    mgb_assert(!dest.empty());
    SymbolVarArray h2d;
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (opr->same_type<opr::Host2DeviceCopy>()) {
            h2d.emplace_back(opr->output(0));
        }
    };
    // check that they belong to the same graph
    mgb_assert(dest[0].node());
    auto og = dest[0].node()->owner_graph();
    for (auto i : dest) {
        mgb_assert(i.node() && i.node()->owner_graph() == og);
    }
    auto dest_with_extra_deps = get_dest_vars_with_extra_deps(dest);
    gopt::SubGraph graph{dest_with_extra_deps};
    graph.iter([&](OperatorNodeBase* opr) { on_opr(opr); });
    return h2d;
}

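// Follow the src_opr attribute chain to its root, caching the result on
// each visited opr (path compression, as in union-find).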
OperatorNodeBase* cg::get_opr_root_source_opr(OperatorNodeBase* opr) {
    auto&& attr = opr->node_prop().attribute();
    if (!attr.src_opr)
        return opr;
    auto orig = attr.src_opr;
    mgb_assert(orig != opr);
    return attr.src_opr = get_opr_root_source_opr(orig);
}

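// Classify how the mem plans of two vars relate: different chunks are
// DISJOINT; within the same chunk, compare byte spans shifted by each
// plan's offset in the chunk.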
cg::MemPlanIntersectionType cg::get_mem_plan_intersection_type(VarNode* a,
                                                               VarNode* b) {
    auto &&m0 = a->mem_plan(), &&m1 = b->mem_plan();
    if (&m0.chunk() != &m1.chunk())
        return MemPlanIntersectionType::DISJOINT;
    auto get_real_span = [](const MemAllocPlan& p) {
        auto span = p.layout().span();
        return std::make_pair(span.low_byte + p.offset_in_chunk_byte(),
                              span.high_byte + p.offset_in_chunk_byte());
    };
    auto s0 = get_real_span(m0), s1 = get_real_span(m1);
    if (s0.first == s1.first && s0.second == s1.second)
        return MemPlanIntersectionType::IDENTICAL;
    if (s0.second <= s1.first || s1.second <= s0.first)
        return MemPlanIntersectionType::DISJOINT;
    return MemPlanIntersectionType::OVERLAP;
}

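// Request input->output memory forwarding only when provably safe: both
// vars use the same storage kind, the input layout is contiguous, and no
// other device-value input overlaps it in memory.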
void cg::request_fwd_in2out_writable_if_no_mem_ovelap(OperatorNodeBase* opr,
                                                      size_t inp, size_t out) {
    auto ivar = opr->input(inp), ovar = opr->output(out);
    if (is_static_var_storage(ivar) != is_static_var_storage(ovar)) {
        // If ovar is dynamic but there are other outputs of opr with static
        // storage, this function would be called during the static allocation
        // phase, and get_mem_plan_intersection_type() would fail.
        // So we just return here
        return;
    }
    auto&& dep_map = opr->node_prop().dep_map();
    using NP = OperatorNodeBase::NodeProp;
    mgb_assert(NP::is_device_value_dep(dep_map.at(ivar)));
    if (!ivar->layout().is_contiguous())
        return;
    using IT = MemPlanIntersectionType;
    for (size_t i = 0; i < opr->input().size(); ++i) {
        auto iv = opr->input()[i];
        if (i != inp && NP::is_device_value_dep(dep_map.at(iv)) &&
            get_mem_plan_intersection_type(iv, ivar) != IT::DISJOINT) {
            return;
        }
    }
    ovar->set_fwd_in2out_writable(ivar);
}

void cg::add_workspace_output(OperatorNodeBase* opr) {
    opr->add_output("workspace")
            ->add_flag(VarNode::Flag::VOLATILE_CONTENT)
            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
            .dtype(dtype::Byte());
}

void cg::copy_shape_to_tensor_value(DeviceTensorND& dest,
                                    const TensorShape& shp) {
    dest.comp_node(CompNode::default_cpu())
            .dtype(dtype::Int32())
            .resize({std::max<size_t>(1, shp.ndim)});
    auto ptr = dest.ptr<dt_int32>();
    if (!shp.ndim)
        ptr[0] = 0;
    else {
        for (size_t i = 0; i < shp.ndim; i++)
            ptr[i] = shp.shape[i];
    }
}

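// Inverse of copy_shape_to_tensor_value(): read a 1-dim tensor on a CPU
// comp node back into a TensorShape, compacting it into a local buffer
// first if its stride is not contiguous.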
void cg::copy_tensor_value_to_shape(TensorShape& dest,
                                    const DeviceTensorND& val) {
    constexpr size_t MAX_DT_SIZE = 4;
    mgb_assert(val.dtype().size() <= MAX_DT_SIZE);
    mgb_assert(val.shape().ndim == 1, "shape tensor must be 1-dim, got %s",
               val.shape().to_string().c_str());
    mgb_assert(val.comp_node().device_type() == CompNode::DeviceType::CPU);
    dest.ndim = val.shape(0);
    mgb_assert(dest.ndim <= TensorShape::MAX_NDIM);
    auto vptr = val.raw_ptr();
    dt_byte contig[MAX_DT_SIZE * TensorShape::MAX_NDIM];
    if (val.layout().stride[0] != 1) {
        auto dst = contig;
        auto dst_strd = val.dtype().size();
        auto src = val.raw_ptr();
        auto src_strd = val.layout().stride[0] * dst_strd;
        for (size_t i = 0; i < dest.ndim; ++i) {
            memcpy(dst, src, dst_strd);
            dst += dst_strd;
            src += src_strd;
        }
        vptr = contig;
    }
    static_cast_dtype_safe(dest.shape, val.dtype(), vptr, dest.ndim);
}

SymbolVar cg::var_from_tensor_shape(ComputingGraph& graph,
                                    const OperatorNodeConfig& config,
                                    const char* opr_name,
                                    const TensorShape& shape) {
    auto cn = config.get_single_comp_node();
    mgb_throw_if(!cn.valid(), GraphError,
                 "must specify comp node in %s config", opr_name);
    DeviceTensorND dv;
    copy_shape_to_tensor_value(dv, shape);
    HostTensorND hv{cn};
    hv.copy_from(dv);
    return opr::ImmutableTensor::make(graph, hv);
}

/* =================== DepOprIter =================== */

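// DepOprIter traverses the dependency graph in post-order with an explicit
// stack instead of recursion, so deep graphs cannot overflow the call
// stack; each frame tracks how many inputs and extra deps were expanded.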
void cg::DepOprIter::push_stack(OperatorNodeBase* opr) {
    if (m_visited.insert(opr).second) {
        if (m_extra_dep) {
            auto it = m_extra_dep->find(opr);
            if (it != m_extra_dep->end()) {
                m_stack.push_back({opr, opr->input().data(), it->second.data(),
                                   0, opr->input().size(), it->second.size()});
                return;
            }
        }
        m_stack.push_back(
                {opr, opr->input().data(), nullptr, 0, opr->input().size(), 0});
    }
}

void cg::DepOprIter::add(OperatorNodeBase* dest) {
    if (!m_owner_graph) {
        m_owner_graph = dest->owner_graph();
    } else {
        mgb_assert(m_owner_graph == dest->owner_graph(),
                   "dest oprs belong to different graphs");
    }
    push_stack(dest);
    while (!m_stack.empty()) {
        auto&& frame = m_stack.back();
        if (frame.inp_idx == frame.nr_input + frame.nr_extra_dep) {
            m_cb(frame.opr);
            m_stack.pop_back();
        } else {
            VarNode* inp = nullptr;
            if (frame.inp_idx < frame.nr_input) {
                inp = frame.inputs[frame.inp_idx++];
            } else {
                inp = frame.extra_deps[frame.inp_idx - frame.nr_input];
                frame.inp_idx++;
            }
            push_stack(inp->owner_opr());
        }
    }
}

/* =================== InterGraphVarTransformer =================== */

MGB_TYPEINFO_OBJ_IMPL(InterGraphVarTransformer);

void InterGraphVarTransformer::register_to(ComputingGraph* dest,
                                           const ComputingGraph* src,
                                           const TransFunc& trans) {
    mgb_assert(dest && src && trans);
    mgb_assert(dest->id() > src->id(),
               "inter-graph trans only allowed from old graph to new graph");
    auto mk = []() {
        return std::shared_ptr<InterGraphVarTransformer>(
                new InterGraphVarTransformer);
    };
    auto ptr = dest->options().user_data
                       .get_user_data_or_create<InterGraphVarTransformer>(mk);
    mgb_assert(!ptr->m_trans_func,
               "InterGraphVarTransformer on graph #%zu{%p} already registered",
               dest->id(), dest);
    ptr->m_graph_dest = dest;
    ptr->m_graph_src = src;
    ptr->m_trans_func = trans;
}

const InterGraphVarTransformer* InterGraphVarTransformer::get(
        const ComputingGraph& graph) {
    auto ret = graph.options()
                       .user_data.get_user_data<InterGraphVarTransformer>();
    if (!ret.second)
        return nullptr;
    mgb_assert(ret.second == 1);
    return ret.first[0];
}

VarNode* InterGraphVarTransformer::trans(VarNode* src) const {
    if (src->owner_graph() != m_graph_src) {
        auto strans = get(*m_graph_src);
        mgb_throw_if(!strans, GraphError,
                     "no InterGraphVarTransformer registered for var %s, "
                     "which belongs to graph #%zu{%p}",
                     dump_var_info({src}).c_str(), src->owner_graph()->id(),
                     src->owner_graph());
        src = strans->trans(src);
    }
    auto ret = m_trans_func(src);
    mgb_assert(ret && ret->owner_graph() == m_graph_dest);
    return ret;
}

/* =================== ExtraDependencyMerger =================== */

ExtraDependencyMerger::ExtraDependencyMerger(SpecialOprStat* sopr_stat)
        : m_sopr_stat{sopr_stat}, m_opr_iter{[this](OperatorNodeBase* opr) {
              on_opr(opr);
          }} {}

ExtraDependencyMerger::~ExtraDependencyMerger() = default;

void ExtraDependencyMerger::on_opr(OperatorNodeBase* opr) {
    if (!m_owner_graph) {
        m_owner_graph = opr->owner_graph();
    }
    mgb_assert(m_owner_graph == opr->owner_graph(),
               "owner graph changes in ExtraDependencyMerger; opr: %s{%s}",
               opr->cname(), opr->dyn_typeinfo()->name);
    auto&& extra_deps = m_owner_graph->options().extra_vardeps;
    auto sopr_stat = m_sopr_stat;
    MGB_MARK_USED_VAR(sopr_stat);
    auto&& new_deps = m_new_deps;
    for (auto i : opr->output()) {
        auto&& iter = extra_deps.find(i);
        if (iter != extra_deps.end()) {
            new_deps.insert(new_deps.end(), iter->second.begin(),
                            iter->second.end());
        }
#if !MGB_BUILD_SLIM_SERVING && MGB_ENABLE_GRAD
        if (sopr_stat && opr->same_type<opr::VirtualGrad>()) {
            sopr_stat->has_virtual_grad = true;
        }
#endif
    }
}

SymbolVarArray& ExtraDependencyMerger::add(const SymbolVarArray& vars) {
    m_result.reserve(m_result.size() + vars.size());
    for (auto&& i : vars) {
        m_result.push_back(i);
        m_opr_iter.add(i);
    }
    while (!m_new_deps.empty()) {
        auto opr = m_new_deps.back()->owner_opr();
        m_new_deps.pop_back();
        if (!m_opr_iter.visited(opr)) {
            m_opr_iter.add(opr);
            m_result.push_back(opr->output(0));
        }
    }
    return m_result;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine actually has GPU hardware and that the driver is properly installed. If you would like to try deep-learning development on cloud GPUs, you are welcome to visit the MegStudio platform.
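
For example, a program can probe for a usable GPU at runtime and fall back to the CPU otherwise. Below is a minimal Python sketch; it assumes the megengine.is_cuda_available() and megengine.set_default_device() helpers available in recent MegEngine releases:

    import megengine as mge
    import megengine.functional as F

    # Use the first GPU when hardware and driver are present; otherwise CPU.
    device = "gpu0" if mge.is_cuda_available() else "cpu0"
    mge.set_default_device(device)

    x = mge.tensor([[1.0, 2.0], [3.0, 4.0]])
    print(device, F.matmul(x, x).numpy())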