
var_node.cpp

/**
 * \file src/core/impl/graph/var_node.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied.
 */

#include "megbrain/graph/var_node.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/helper.h"
#include "./cg_impl.h"

using namespace mgb;
using namespace cg;

/* ===================== MemAllocPlan ===================== */

std::mutex MemAllocPlan::ReadonlyFwdList::list_mutex;
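
// ReadonlyFwdList chains mem plans that readonly-forward from the same root
// plan into an intrusive doubly-linked list; all list updates are serialized
// by the static list_mutex.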
void MemAllocPlan::ReadonlyFwdList::reset() {
    MGB_LOCK_GUARD(list_mutex);
    m_prev = m_next = nullptr;
}

void MemAllocPlan::ReadonlyFwdList::insert_after(const MemAllocPlan& prev,
                                                 MemAllocPlan* self) {
    MGB_LOCK_GUARD(list_mutex);
    mgb_assert(!m_prev && !m_next);
    auto next = prev.m_readonly_fwd_list.m_next;
    prev.m_readonly_fwd_list.m_next = self;
    m_prev = const_cast<MemAllocPlan*>(&prev);
    m_next = next;
    if (next) {
        next->m_readonly_fwd_list.m_prev = self;
    }
}

void MemAllocPlan::ReadonlyFwdList::remove_self() {
    MGB_LOCK_GUARD(list_mutex);
    if (m_prev) {
        if (m_next) {
            m_prev->m_readonly_fwd_list.m_next = m_next;
            m_next->m_readonly_fwd_list.m_prev = m_prev;
        } else {
            m_prev->m_readonly_fwd_list.m_next = nullptr;
        }
        m_prev = m_next = nullptr;
    }
}

MemAllocPlan::Chunk MemAllocPlan::sm_chunk_invalid_cond_exec_marker{nullptr};

MemAllocPlan::MemAllocPlan(VarNode *owner_var):
    m_chunk_storage(owner_var)
{
}
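
// Share src's chunk: copy its layout and byte offset and take another
// reference on the underlying chunk.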
MemAllocPlan& MemAllocPlan::assign(const MemAllocPlan &src) {
    mgb_assert(src.valid());
    m_layout = src.m_layout;
    m_layout.dtype = dtype();
    m_offset_byte = src.m_offset_byte;
    m_chunk = src.m_chunk;
    ++ m_chunk->m_refcnt;
    return *this;
}
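
// Set up readonly forwarding from src restricted to the sub-tensor described
// by sub: reuse src's chunk, adjust layout and offset, and register this plan
// in the readonly forwarding list rooted at src.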
MemAllocPlan& MemAllocPlan::assign_for_forward(
        const MemAllocPlan &src, const SubTensorSpec &sub) {
    mgb_assert(valid() && src.valid() && m_layout.eq_shape(sub.layout()));
    ++ (m_chunk = src.m_chunk)->m_refcnt;
    m_layout = sub.layout();
    // make layout strong-contig
    for (int i = static_cast<int>(m_layout.ndim) - 1; i >= 0; -- i) {
        if (m_layout.shape[i] == 1) {
            m_layout.stride[i] = i + 1 < static_cast<int>(m_layout.ndim) ?
                m_layout.stride[i + 1] * m_layout.shape[i + 1] : 1;
        }
    }
    m_layout.dtype = dtype();
    m_offset_byte = src.m_offset_byte + sub.offset_byte();
    auto &&span = sub.layout().span();
    mgb_assert(m_offset_byte + span.high_byte <= m_chunk->size() &&
               static_cast<ptrdiff_t>(m_offset_byte) + span.low_byte >= 0);
    // Note: Multiple mem plans may be forwarded from the same mem plan. Here we
    // do not need to find the root mem plan. Instead, we just insert this node
    // to the linked list headed at the root node, obeying topological order,
    // but note that new nodes may be inserted into the middle of the list.
    m_readonly_fwd_list.insert_after(src, this);
    return *this;
}
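
// Reset this plan to a contiguous layout backed by the owner var's own chunk
// storage (refcount 1, allocation status reset to invalid).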
MemAllocPlan& MemAllocPlan::reset_from_owner_var() {
    auto owner_var = m_chunk_storage.owner_var;
    m_layout.dtype = dtype();
    m_layout.format = format();
    m_layout.init_contiguous_stride(owner_var->shape());
    m_offset_byte = 0;
    m_chunk = &m_chunk_storage;
    auto chk = m_chunk;
    chk->m_refcnt = 1;
    chk->m_size = m_layout.span().dist_byte();
    chk->mem_alloc_status.set_invalid();
    mgb_assert(chk->m_refcnt.is_lock_free());
    // check that the size does not overflow
    mgb_assert(m_layout.total_nr_elems() <= m_layout.dtype.max_elements(),
               "var too large: %s", cg::dump_var_info({owner_var}).c_str());
    return *this;
}
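
// Drop this plan's reference to its chunk; when the last reference is
// released, the owner var's device storage is freed as well.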
MemAllocPlan& MemAllocPlan::release_chunk() {
    mgb_assert(valid());
    auto chk = m_chunk;
    bool need_consider = chk != &sm_chunk_invalid_cond_exec_marker;
    m_readonly_fwd_list.remove_self();
    if (need_consider && (!--chk->m_refcnt)) {
        auto&& dv = chk->owner_var->m_dev_tensor;
        mgb_assert(dv.storage().comp_node_valid());
        if (chk->size()) {
            mgb_assert(chk->mem_alloc_status.is_from_owner_var());
            chk->m_size = 0;
        }
        chk->mem_alloc_status.set_invalid();
        dv.storage({});
    }
    m_chunk = nullptr;
    return *this;
}

MemAllocPlan& MemAllocPlan::layout(const TensorLayout& dest,
                                   bool allow_shape_change) {
    mgb_assert(allow_shape_change || m_layout.eq_shape(dest),
               "disallowed shape change: %s vs %s",
               m_layout.TensorShape::to_string().c_str(),
               dest.TensorShape::to_string().c_str());
    m_layout = dest;
    m_layout.dtype = dtype();
    return *this;
}

#if MGB_ENABLE_JSON
std::shared_ptr<json::Value> MemAllocPlan::to_json() const {
    auto cvt_layout = [](const TensorLayout &layout) {
        auto shape = json::Array::make(),
             stride = json::Array::make();
        for (size_t i = 0; i < layout.ndim; i ++) {
            shape->add(json::Number::make(layout.shape[i]));
            stride->add(json::Number::make(layout.stride[i]));
        }
        return json::Object::make({
            {"shape", shape},
            {"stride", stride},
            {"dtype", json::String::make(layout.dtype.name())}
        });
    };
    return json::Object::make({
        {"mem_chunk_id", json::String::make(m_chunk->id_str())},
        {"layout", cvt_layout(m_layout)},
        {"offset_byte", json::Number::make(m_offset_byte)}
    });
}
#endif

std::string MemAllocPlan::Chunk::id_str() const {
    return "chk" + std::to_string(owner_var->id());
}

/* ===================== MemAllocPlan::Chunk ===================== */

#if MGB_ENABLE_JSON
std::shared_ptr<json::Value> MemAllocPlan::Chunk::to_json() const {
    std::shared_ptr<json::Value> dev_ptr;
    if (owner_var->dev_tensor_valid()) {
        dev_ptr = json::NumberInt::make(
                reinterpret_cast<size_t>(owner_var->dev_tensor().raw_ptr()));
    } else {
        dev_ptr = json::Null::make();
    }
    return json::Object::make({
        {"node_type", json::String::make("mem_chunk")},
        {"id", json::String::make(id_str())},
        {"size", json::Number::make(size())},
        {"owner_var", json::String::make(owner_var->id_str())},
        {"dev_ptr", dev_ptr}
    });
}
#endif

/* ===================== VarNode ===================== */

const std::string& VarNode::name() const {
    return m_name.valid() ? m_name.val() : owner_opr()->name();
}

VarNode& VarNode::name(std::string name) {
    m_name = std::move(name);
    m_has_name_set = true;
    return *this;
}

const DeviceTensorND& VarNode::dev_tensor() const {
    mgb_assert(dev_tensor_valid());
    return m_dev_tensor;
}

DeviceTensorND& VarNode::mutable_dev_tensor() {
    mgb_assert(dev_tensor_valid() && contain_flag(Flag::NO_SYS_MEM_ALLOC));
    return m_dev_tensor;
}

VarNode& VarNode::dtype(DType dtype) {
    mgb_assert(dtype.valid() && !m_dev_tensor.dtype().valid());
    m_dev_tensor.dtype(dtype);
    return *this;
}

VarNode& VarNode::format(TensorFormat format) {
    mgb_assert(format == m_dev_tensor.format() ||
               m_dev_tensor.format().is_default());
    m_dev_tensor.format(format);
    return *this;
}

bool VarNode::set_fwd_in2out_readonly(
        VarNode *input, const SubTensorSpec &sub) {
    if (owner_graph()->options().imperative_proxy_graph) {
        return false;
    }
    return ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager().fwd_in2out_readonly(input, sub, this);
}

VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
    ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager().fwd_in2out_writable(input, this);
    return *this;
}

VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) {
    mgb_assert(!owner_graph()->options().imperative_proxy_graph);
    ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager().fwd_in2out_writable_force(input, this);
    return *this;
}

VarNode& VarNode::add_layout_constraint(LayoutConstraintCallback callback) {
    ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager().add_layout_constraint(
                    this, std::move(callback));
    return *this;
}

VarNode& VarNode::add_layout_constraint_contiguous() {
    ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager()
            .add_layout_constraint_level(
                    this, VarNodeMemManager::LayoutConstraintLevel::CONTIG);
    return *this;
}

VarNode& VarNode::add_layout_constraint_monotone() {
    ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager()
            .add_layout_constraint_level(
                    this, VarNodeMemManager::LayoutConstraintLevel::MONOTONE);
    return *this;
}
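
// Update the shape and invoke any registered shape-update callbacks; when
// MGB_ENABLE_DEBUG_UTIL is enabled, vars whose byte size exceeds the
// MGB_LOG_VAR_SIZE_MB threshold (given in MiB) are logged.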
VarNode& VarNode::shape(const TensorShape &shape) {
    if (!m_shape.eq_shape(shape)) {
        mgb_assert(m_allow_shape_change, "invalid var shape change: "
                   "dest=%s var=%s", shape.to_string().c_str(),
                   dump_var_info({this}).c_str());
        m_shape = shape;
        for (auto &&i: m_shape_update_callback)
            i.second(this);
    }
#if MGB_ENABLE_DEBUG_UTIL
    static size_t log_limit = MGB_GETENV("MGB_LOG_VAR_SIZE_MB") ?
            std::stold(MGB_GETENV("MGB_LOG_VAR_SIZE_MB")) * (1024 * 1024) : 0;
    if (log_limit) {
        auto size = dtype().size(shape.total_nr_elems());
        static size_t max_size = 0;
        if (size >= log_limit) {
            bool updated = false;
            if (size > max_size) {
                max_size = size;
                updated = true;
            }
            mgb_log("var exceeds log limit: %s; size=%.3fMiB%s",
                    cg::dump_var_info({this}).c_str(),
                    size / (1024.0 * 1024),
                    updated ? " (with maxsize updated)" : "");
        }
    }
#endif
    return *this;
}

VarNode& VarNode::shape_alloc(const TensorShape &shape, size_t size_req) {
    mgb_assert(shape.ndim, "got empty shape in shape_alloc: "
               "var=%s owner_opr=%s{%s}", cname(), owner_opr()->cname(),
               owner_opr()->dyn_typeinfo()->name);
    mgb_assert(contain_flag(Flag::NO_SYS_MEM_ALLOC),
               "shape_alloc() could only be used for vars with"
               " NO_SYS_MEM_ALLOC flag; actual var: %s",
               cg::dump_var_info({this}).c_str());
    ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager().var_alloc_with_shape(this, shape, size_req);
    return *this;
}
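
// Forward the device tensor of src_var into this var. Returns true if the
// storage could be shared directly; returns false if a copy had to be made
// because the source is a seq-force-update destination or its layout is not
// contiguous.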
bool VarNode::reset_dev_tensor_from_other_var(VarNode* src_var) {
    mgb_assert(contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
    if (src_var->owner_graph() == owner_graph()) {
        // this is actually readonly forwarding in the same graph
        mgb_assert(
                src_var->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) ||
                        !is_static_var_storage(src_var),
                "dynamic storage on src is required for dynamic readonly "
                "forwarding: vars=%s",
                dump_var_info({src_var, this}).c_str());
        auto&& trait = ComputingGraphImpl::downcast(owner_graph())
                               ->var_node_mem_manager()
                               .get_var_node_mem_trait_at(src_var);
        if (trait.seq_force_update_dest ||
            !src_var->dev_tensor().layout().is_contiguous()) {
            shape_alloc(src_var->shape())
                    .dev_tensor()
                    .copy_from_fixlayout(src_var->dev_tensor());
            return false;
        }
    }
    shape(src_var->shape());
    m_mem_plan.assign(src_var->m_mem_plan);
    assign_dev_tensor_from_tensor(src_var->dev_tensor());
    return true;
}

VarNode& VarNode::reset_dev_tensor_from_tensor(const DeviceTensorND& value) {
    mgb_assert(contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
    mgb_assert(value.comp_node() == comp_node(),
               "attempt to reset var on %s from a value on %s",
               comp_node().to_string().c_str(),
               value.comp_node().to_string().c_str());
    shape(value.shape());
    auto&& chk = m_mem_plan.reset_from_owner_var().chunk();
    assign_dev_tensor_from_tensor(value);
    chk.mem_alloc_status.set_from_owner_var();
    return *this;
}

void VarNode::assign_dev_tensor_from_tensor(const DeviceTensorND& value) {
    mgb_assert(value.layout().is_contiguous() &&
               m_dev_tensor.dtype() == value.dtype() &&
               m_dev_tensor.format() == value.format());
    if (cg::is_static_var_shape(this)) {
        mgb_assert(shape().eq_shape(value.shape()),
                   "shape mismatch for static inferrable var when setting dev "
                   "tensor: var=%s new_shape=%s",
                   cg::dump_var_info({this}).c_str(),
                   value.shape().to_string().c_str());
    }
    m_dev_tensor.reset(value.storage(), value.layout());
    m_dev_tensor.comp_node(comp_node());
    m_prev_dev_ptr = value.raw_ptr();
    mgb_assert(dev_tensor_valid());
}

VarNode& VarNode::add_rt_force_dynamic_mem_alloc_imply_chain(VarNode *dest) {
    mgb_assert(dest && dest->owner_graph() == owner_graph() &&
               (!contain_flag(Flag::FLAG_FREEZED) ||
                !dest->contain_flag(Flag::FLAG_FREEZED)));
    m_rt_force_dynamic_mem_alloc_imply_chain.push_back(dest);
    return *this;
}

VarNode& VarNode::comp_node(const CompNode &cn) {
    mgb_assert(cn.valid() && (!m_comp_node.valid() ||
               m_comp_node.mem_node() == cn.mem_node()));
    m_comp_node = cn;
    if (m_cn_sync_manager) {
        m_cn_sync_manager->comp_node(cn);
    }
    return *this;
}

#if MGB_ENABLE_JSON
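// Dump the static-infer state of this var as a JSON object: for both the
// SHAPE and VALUE dependency types, record the deps, the infer type, and
// (when available) the inferred shape or value.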
std::shared_ptr<json::Value>
VarNode::dump_static_infer_info_to_json() const {
    using namespace cg::static_infer;
    auto&& mgr = static_cast<cg::ComputingGraphImpl*>(
            owner_graph())->static_infer_manager_impl();
    auto get_dep_type = [](const DepType& type) -> std::string {
        switch (type) {
#define cb(name)        \
    case DepType::name: \
        return #name;
            cb(SHAPE)
            cb(VALUE)
#undef cb
            default:
                mgb_throw(MegBrainError, "unknown dep type");
        }
    };
    auto get_infer_type = [](const InferType::Flag& type) {
        switch (type) {
#define cb(name)                \
    case InferType::Flag::name: \
        return json::String::make(#name);
            cb(NO_DESC)
            cb(CONST)
            cb(RT_STATIC)
            cb(MISSING_INP)
#undef cb
            default:
                mgb_throw(MegBrainError, "unknown infer type");
        }
    };
    auto make_tag = [&](const DepType& type) {
        VarNode* self = const_cast<VarNode*>(this);
        auto c_deps = mgr.get_deps({self, type});
        auto deps = json::Array::make();
        for (auto&& i : c_deps) {
            mgb_assert(i.dest);
            deps->add(json::Object::make({
                {"var", json::String::make(i.dest->id_str())},
                {"dep_type", json::String::make(get_dep_type(i.type))}
            }));
        }
        auto infer_type_handle = mgr.get_infer_type(self);
        std::shared_ptr<json::Value> inferred_result = json::Null::make();
        auto infer_type = type == DepType::SHAPE ? infer_type_handle.shape
                                                 : infer_type_handle.value;
        if (infer_type != InferType::Flag::NO_DESC) {
            if (type == DepType::SHAPE) {
                if (auto shape = mgr.infer_shape_fallible(self)) {
                    auto inferred_shape = json::Array::make();
                    for (size_t i = 0; i < shape->ndim; ++ i) {
                        inferred_shape->add(json::Number::make((*shape)[i]));
                    }
                    inferred_result = inferred_shape;
                }
            } else {
                if (auto p = mgr.infer_value_fallible(self)) {
                    auto&& dev = *p;
                    if (dev.shape().ndim == 1 &&
                        dev.shape(0) < TensorShape::MAX_NDIM &&
                        mgb_likely(dev.comp_node() == CompNode::default_cpu())) {
                        MGB_TRY {
                            size_t nr_elems = dev.shape(0);
                            auto&& dtype = dev.dtype();
                            void* vptr = dev.raw_ptr();
                            double data[nr_elems];
                            HostTensorND contig;
                            if (!dev.layout().is_contiguous()) {
                                // both src and dst are placed on default cpu,
                                // no need for sync
                                contig.copy_from(dev);
                                mgb_assert(contig.layout().is_contiguous());
                                vptr = contig.raw_ptr();
                            }
                            static_cast_dtype(data, dtype, vptr, nr_elems);
                            auto inferred_value = json::Array::make();
                            for (size_t i = 0; i < nr_elems; ++ i) {
                                inferred_value->add(json::Number::make(data[i]));
                            }
                            inferred_result = inferred_value;
                        }
                        MGB_CATCH(ConversionError&, {});
                    } else {
                        inferred_result = json::String::make("Large Array");
                    }
                }
            }
        }
        return json::Object::make({
            {"node_type", json::String::make("static_infer_tag")},
            {"infer_type", get_infer_type(infer_type)},
            {"inferred_result", inferred_result},
            {"deps", deps}
        });
    };
    return json::Object::make({
#define TAG(type) {get_dep_type(type), make_tag(type)}
            TAG(DepType::SHAPE), TAG(DepType::VALUE)
#undef TAG
    });
}

std::shared_ptr<json::Value> VarNode::to_json() const {
    auto get_var = [](VarNode *p) -> std::shared_ptr<json::Value> {
        if (p)
            return json::String::make(p->id_str());
        return json::Null::make();
    };
    auto &&trait = ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager().get_var_node_mem_trait(this);
    auto flag = json::Array::make();
    {
        // add flags
        size_t flag_checked = static_cast<size_t>(Flag::FLAG_FREEZED);
#define CHK(v)                                            \
    do {                                                  \
        if (contain_flag(Flag::v)) {                      \
            flag->add(json::String::make(#v));            \
            flag_checked |= static_cast<size_t>(Flag::v); \
        }                                                 \
    } while (0)
        CHK(NO_SYS_MEM_ALLOC);
        CHK(NO_ALLOC_IF_UNUSED);
        CHK(NO_SYS_STATIC_MEM_ALLOC);
        CHK(NO_MEM_RECLAIM);
        CHK(RT_FORCE_DYNAMIC_MEM_ALLOC);
        CHK(VOLATILE_CONTENT);
        CHK(ALLOW_EMPTY_SHAPE);
        CHK(PERSISTENT_DEVICE_VALUE);
        CHK(DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
        CHK(DISALLOW_VAR_SANITY_CHECK);
#undef CHK
        mgb_assert(flag_checked == static_cast<size_t>(m_flag));
    }
    auto rst = json::Object::make({
        {"node_type", json::String::make("var")},
        {"id", json::String::make(id_str())},
        {"name", json::String::make(name())},
        {"mem_readonly_fwd_src", get_var(trait.readonly_src)},
        {"force_update_src", get_var(trait.force_update_src)},
        {"mem_plan", m_mem_plan.valid() ?
            m_mem_plan.to_json() : json::Null::make()},
        {"comp_node", json::String::make(comp_node().to_string())},
        {"dev_ptr", json::Null::make()},
        {"prev_dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>(
                m_prev_dev_ptr))},
        {"flag", flag},
        {"static_infer_tags", dump_static_infer_info_to_json()}
    });
    if (m_prev_dev_ptr) {
        (*rst)["prev_dev_ptr_end"] = json::NumberInt::make(
                reinterpret_cast<size_t>(m_prev_dev_ptr) +
                m_mem_plan.layout().span().high_byte);
    }
    if (dev_tensor_valid()) {
        (*rst)["dev_ptr"] = json::NumberInt::make(reinterpret_cast<size_t>(
                m_dev_tensor.raw_ptr()));
    }
    return rst;
}
#endif

MemAllocPlan& VarNode::init_mem_plan(const DeviceTensorND* fixed_alloc) {
    ComputingGraphImpl::downcast(owner_graph())
            ->var_node_mem_manager()
            .init_single_var_mem_plan(this, fixed_alloc);
    return m_mem_plan;
}

VarNode& VarNode::add_flag(Flag flag) {
    modify_flag(flag, m_flag | flag);
    return *this;
}
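
// Apply new_flag, where delta is the set of flag bits being modified. Once
// FLAG_FREEZED is set, only a restricted subset of flags may still be
// changed, and conflicting flag combinations are rejected.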
void VarNode::modify_flag(Flag delta, Flag new_flag) {
    if (contain_flag(Flag::FLAG_FREEZED)) {
        mgb_assert(
                (delta & (Flag::NO_MEM_RECLAIM | Flag::NO_SYS_STATIC_MEM_ALLOC |
                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta ||
                (new_flag & Flag::MEMORY_NO_NEED));
        mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
                           var_node_mem_manager().optimize_started(),
                   "could not modify var flags after optimization started");
    }
    mgb_assert(!(new_flag & Flag::RT_FORCE_DYNAMIC_MEM_ALLOC) ||
               !(new_flag & Flag::NO_SYS_MEM_ALLOC),
               "RT_FORCE_DYNAMIC_MEM_ALLOC conflicts with NO_SYS_MEM_ALLOC");
    mgb_assert(!(new_flag & Flag::NO_ALLOC_IF_UNUSED) ||
               !(new_flag & Flag::NO_SYS_MEM_ALLOC),
               "NO_ALLOC_IF_UNUSED conflicts with NO_SYS_MEM_ALLOC");
    mgb_assert(!(new_flag & Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC) ||
               (new_flag & Flag::NO_MEM_RECLAIM),
               "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC must be added after "
               "NO_MEM_RECLAIM");
    m_flag = new_flag;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that its driver is properly installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.