
interpreter_impl.cpp 43 kB

  1. /**
  2. * \file imperative/src/impl/interpreter/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. #include "megbrain/common.h"
  13. #include "megbrain/imperative/opr_utility.h"
  14. #include "megbrain/imperative/ops/autogen.h"
  15. #include "megbrain/imperative/ops/backward_graph.h"
  16. #include "megbrain/imperative/ops/opr_attr.h"
  17. #include "megbrain/imperative/utils/to_string.h"
  18. using namespace mgb;
  19. using namespace imperative;
  20. using namespace interpreter;
  21. using namespace interpreter::intl;
  22. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  23. return std::make_unique<ChannelImpl>();
  24. }
  25. Interpreter& Interpreter::inst() {
  26. static InterpreterImpl inst_;
  27. return inst_;
  28. }
  29. Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
  30. mgb_assert(check_available(), "Channel already closed");
  31. auto info = alloc();
  32. info->desc.layout = value.layout();
  33. info->desc.comp_node = value.comp_node();
  34. info->desc.value = value.proxy_to_default_cpu();
  35. info->h_value = value;
  36. m_buffer.enqueue(Put{info, value, no_cache});
  37. if (m_async_level == 0) {
  38. sync();
  39. info->desc.comp_node.sync();
  40. }
  41. return info;
  42. }
  43. Handle ChannelImpl::put(const DeviceTensorND& data) {
  44. mgb_assert(check_available(), "Channel already closed");
  45. auto info = alloc();
  46. info->desc.layout = data.layout();
  47. info->desc.comp_node = data.comp_node();
  48. info->ptr = Tensor::make(data);
  49. if (m_channel_state.profiler->is_profiling()) {
  50. m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
  51. }
  52. return info;
  53. }
  54. void ChannelImpl::del(Handle handle) {
  55. if (!check_available()){
  56. return;
  57. }
  58. mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
  59. auto* info = reinterpret_cast<TensorInfo*>(handle);
  60. m_valid_handle.erase(handle);
  61. m_buffer.enqueue(Del{info});
  62. }
  63. void ChannelImpl::swap_in(Handle handle) {
  64. mgb_assert(check_available(), "Channel already closed");
  65. if (m_worker_state.options.enable_swap) {
  66. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  67. "invalid handle: %p", handle);
  68. auto* info = reinterpret_cast<TensorInfo*>(handle);
  69. m_buffer.enqueue(SwapIn{info});
  70. }
  71. }
  72. void ChannelImpl::swap_out(Handle handle) {
  73. mgb_assert(check_available(), "Channel already closed");
  74. if (m_worker_state.options.enable_swap) {
  75. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  76. "invalid handle: %p", handle);
  77. auto* info = reinterpret_cast<TensorInfo*>(handle);
  78. m_buffer.enqueue(SwapOut{info});
  79. }
  80. }
  81. void ChannelImpl::drop(Handle handle) {
  82. mgb_assert(check_available(), "Channel already closed");
  83. if (m_worker_state.options.enable_drop) {
  84. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  85. "invalid handle: %p", handle);
  86. auto* info = reinterpret_cast<TensorInfo*>(handle);
  87. m_buffer.enqueue(Drop{info});
  88. }
  89. }
  90. void ChannelImpl::dispatch_default_cpu(
  91. std::shared_ptr<OpDef> op,
  92. const SmallVector<TensorInfo*>& input_infos,
  93. const SmallVector<LogicalTensorDesc>& input_descs,
  94. SmallVector<Handle>* outputs) {
  95. auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
  96. MGB_MARK_USED_VAR(validated);
  97. SmallVector<DeviceTensorND> input_tensornds;
  98. input_tensornds.reserve(input_descs.size());
  99. CompNode output_cn;
  100. {
  101. MGB_LOCK_GUARD(m_mutex);
  102. for (auto&& info : input_infos) {
  103. auto input_cn = info->desc.comp_node;
  104. if (!output_cn.valid()) {
  105. output_cn = input_cn;
  106. } else {
  107. mgb_assert(output_cn == input_cn, "cannot decide output comp node");
  108. }
  109. if (info->ptr && info->ptr->try_get_value()) {
  110. input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
  111. } else {
  112. mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
  113. input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
  114. }
  115. }
  116. }
  117. outputs->reserve(output_descs.size());
  118. SmallVector<DeviceTensorND> output_tensornds;
  119. output_tensornds.reserve(output_descs.size());
  120. for (auto&& desc : output_descs) {
  121. // TODO: may conflict with condtake, which needs to alloc inside
  122. mgb_assert(!desc.layout.is_empty());
  123. // use HostTensorND alloc_host for cuda pinned memory
  124. output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
  125. }
  126. auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
  127. SmallVector<uint64_t> tid;
  128. for (auto* ptinfo: tinfo) {
  129. tid.push_back(ptinfo->id);
  130. }
  131. return tid;
  132. };
  133. OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
  134. if (m_channel_state.profiler->is_profiling()) {
  135. m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);
  136. }
  137. OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
  138. SmallVector<TensorInfo*> output_infos;
  139. output_infos.reserve(output_descs.size());
  140. for (auto&& tensornd : output_tensornds) {
  141. HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
  142. .proxy_to_comp_node(output_cn);
  143. // use `put` for consistency
  144. auto info = reinterpret_cast<TensorInfo*>(put(host_tensornd, false));
  145. mgb_assert(info->desc.layout.ndim != 0);
  146. output_infos.push_back(info);
  147. outputs->push_back(info);
  148. }
  149. event_data.outputs = tinfo_to_tid(output_infos);
  150. if (m_channel_state.profiler->is_profiling()) {
  151. m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
  152. }
  153. }
  154. void ChannelImpl::dispatch_kernel(
  155. std::shared_ptr<OpDef> op,
  156. const SmallVector<TensorInfo*>& input_infos,
  157. const SmallVector<LogicalTensorDesc>& input_descs,
  158. SmallVector<Handle>* outputs) {
  159. auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
  160. ApplyOp cmd{std::move(op)};
  161. cmd.inputs = std::move(input_infos);
  162. cmd.outputs.reserve(output_descs.size());
  163. outputs->reserve(output_descs.size());
  164. for (auto&& desc : output_descs) {
  165. auto info = alloc();
  166. info->desc = desc;
  167. // make sure desc's value is consistent with h_value
  168. if (!info->desc.value.empty()) {
  169. info->h_value = HostTensorND::make_proxy(desc.value)
  170. .proxy_to_comp_node(desc.comp_node);
  171. }
  172. cmd.outputs.push_back(info);
  173. outputs->push_back(info);
  174. }
  175. m_buffer.enqueue(std::move(cmd));
  176. if (!validated && m_channel_state.options.async_level == 1) {
  177. sync();
  178. } else if (m_channel_state.options.async_level == 0) {
  179. sync();
  180. // check device error
  181. for (auto&& oup : *outputs) {
  182. auto info = reinterpret_cast<TensorInfo*>(oup);
  183. info->ptr->comp_node().sync();
  184. }
  185. }
  186. }
  187. SmallVector<Handle> ChannelImpl::apply_op(
  188. std::shared_ptr<OpDef> op,
  189. const SmallVector<Handle>& inputs) {
  190. mgb_assert(check_available(), "Channel already closed");
  191. for (auto i : inputs) {
  192. mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
  193. "invalid handle: %p", i);
  194. }
  195. SmallVector<TensorInfo*> input_infos;
  196. input_infos.reserve(inputs.size());
  197. SmallVector<LogicalTensorDesc> input_descs;
  198. input_descs.reserve(inputs.size());
  199. {
  200. MGB_LOCK_GUARD(m_mutex);
  201. for (auto i : inputs) {
  202. auto info = reinterpret_cast<TensorInfo*>(i);
  203. mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
  204. input_infos.push_back(info);
  205. input_descs.push_back(info->desc);
  206. }
  207. }
  208. SmallVector<Handle> outputs;
  209. DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute
  210. ? OpDef::decide_dispatch_mode(*op, input_descs)
  211. : DispatchMode::KERNEL;
  212. switch (dispatch_mode) {
  213. case DEFAULT_CPU: {
  214. dispatch_default_cpu(op, input_infos, input_descs, &outputs);
  215. break;
  216. }
  217. case KERNEL: {
  218. dispatch_kernel(op, input_infos, input_descs, &outputs);
  219. break;
  220. }
  221. }
  222. return outputs;
  223. }
  224. HostTensorND ChannelImpl::get_value(Handle handle) {
  225. mgb_assert(check_available(), "Channel already closed");
  226. // TODO: maybe get_value should be done on host. i.e. delete GetValue
  227. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  228. "invalid handle: %p", handle);
  229. auto info = reinterpret_cast<TensorInfo*>(handle);
  230. mgb_assert(!m_waitee);
  231. // do not use info->value_fetched, it's unsafe
  232. mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
  233. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  234. TensorPtr tensor_ptr = info->ptr;
  235. auto value_fetched = [&]() {
  236. return tensor_ptr && tensor_ptr->value_fetched();
  237. };
  238. if (!value_fetched()) {
  239. m_waitee = info;
  240. m_buffer.enqueue(GetValue{info});
  241. if (m_channel_state.profiler->is_profiling()) {
  242. m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
  243. }
  244. m_cv.wait(lock, [&]() {
  245. check_worker_exc_unsafe();
  246. tensor_ptr = info->ptr;
  247. return value_fetched();
  248. });
  249. if (m_channel_state.profiler->is_profiling()) {
  250. m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
  251. }
  252. m_waitee = nullptr;
  253. }
  254. return tensor_ptr->get_value();
  255. }
  256. TensorShape ChannelImpl::get_shape(Handle handle) {
  257. mgb_assert(check_available(), "Channel already closed");
  258. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  259. "invalid handle: %p", handle);
  260. auto info = reinterpret_cast<TensorInfo*>(handle);
  261. if (info->desc.layout.ndim != 0) {
  262. return info->desc.layout;
  263. }
  264. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  265. mgb_assert(!m_waitee);
  266. m_waitee = info;
  267. m_buffer.flush();
  268. if (m_channel_state.profiler->is_profiling()) {
  269. m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
  270. }
  271. m_cv.wait(lock, [&]() {
  272. check_worker_exc_unsafe();
  273. return static_cast<bool>(info->ptr);
  274. });
  275. if (m_channel_state.profiler->is_profiling()) {
  276. m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
  277. }
  278. m_waitee = nullptr;
  279. TensorShape ret = info->ptr->layout();
  280. mgb_assert(ret.ndim != 0);
  281. return ret;
  282. }
  283. DType ChannelImpl::get_dtype(Handle handle) {
  284. mgb_assert(check_available(), "Channel already closed");
  285. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  286. "invalid handle: %p", handle);
  287. auto info = reinterpret_cast<TensorInfo*>(handle);
  288. if (m_channel_state.profiler->is_profiling()) {
  289. m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
  290. }
  291. auto ret = info->desc.layout.dtype;
  292. mgb_assert(ret.valid());
  293. return ret;
  294. }
  295. CompNode ChannelImpl::get_device(Handle handle) {
  296. mgb_assert(check_available(), "Channel already closed");
  297. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  298. "invalid handle: %p", handle);
  299. auto info = reinterpret_cast<TensorInfo*>(handle);
  300. if (m_channel_state.profiler->is_profiling()) {
  301. m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
  302. }
  303. auto ret = info->desc.comp_node;
  304. mgb_assert(ret.valid());
  305. return ret;
  306. }
  307. DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
  308. mgb_assert(check_available(), "Channel already closed");
  309. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  310. "invalid handle: %p", handle);
  311. auto info = reinterpret_cast<TensorInfo*>(handle);
  312. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  313. mgb_assert(!m_waitee);
  314. m_waitee = info;
  315. m_buffer.flush();
  316. if (m_channel_state.profiler->is_profiling()) {
  317. m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
  318. }
  319. m_cv.wait(lock, [&]() {
  320. check_worker_exc_unsafe();
  321. return static_cast<bool>(info->ptr);
  322. });
  323. if (m_channel_state.profiler->is_profiling()) {
  324. m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
  325. }
  326. m_waitee = nullptr;
  327. return info->ptr->dev_tensor();
  328. }
  329. void ChannelImpl::sync() {
  330. mgb_assert(check_available(), "Channel already closed");
  331. m_buffer.flush();
  332. if (m_channel_state.profiler->is_profiling()) {
  333. m_channel_state.profiler->record_host<SyncStartEvent>();
  334. }
  335. m_worker.wait_all_task_finish();
  336. CompNode::sync_all();
  337. if (m_channel_state.profiler->is_profiling()) {
  338. m_channel_state.profiler->record_host<SyncFinishEvent>();
  339. }
  340. MGB_LOCK_GUARD(m_mutex);
  341. check_worker_exc_unsafe();
  342. }
  343. void ChannelImpl::close() {
  344. if (!check_available()) {
  345. return;
  346. }
  347. std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
  348. for (auto* handle: valid_handles) {
  349. del(handle);
  350. }
  351. mgb_assert(m_valid_handle.empty());
  352. mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
  353. sync();
  354. m_closed = true;
  355. }
  356. size_t ChannelImpl::get_option(std::string name) {
  357. mgb_assert(check_available(), "Channel already closed");
  358. return m_channel_state.options.get_option(name);
  359. }
  360. void ChannelImpl::set_option(std::string name, size_t value) {
  361. mgb_assert(check_available(), "Channel already closed");
  362. m_channel_state.options.set_option(name, value);
  363. m_buffer.enqueue(SetOption{name, value});
  364. }
  365. TensorInfo* ChannelImpl::alloc() {
  366. MGB_LOCK_GUARD(m_mutex);
  367. auto info = m_pool.alloc();
  368. m_valid_handle.insert(info);
  369. info->id = m_last_id++;
  370. if (m_channel_state.profiler->is_profiling()) {
  371. m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
  372. }
  373. return info;
  374. }
  375. void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
  376. if (!ptr->producer) {
  377. if (user) {
  378. mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", ptr);
  379. }
  380. return;
  381. }
  382. if (ptr->evict_type != EvictType::NONE) {
  383. return;
  384. }
  385. ptr->evict_type = EvictType::DROP;
  386. release_tensor(ptr);
  387. }
  388. void ChannelImpl::free(TensorInfo* ptr) {
  389. if (m_worker_state.options.enable_dtr_auto_drop) {
  390. // Evicting a tensor, rather than freeing it, can avoid pinning
  391. // potentially exploding amounts of memory and allow us to save
  392. // more memory.
  393. ptr->allow_delete = true;
  394. if (!ptr->ref_cnt) {
  395. recursive_free(ptr);
  396. } else {
  397. do_drop(ptr);
  398. }
  399. } else {
  400. real_free(ptr);
  401. }
  402. }
  403. void ChannelImpl::recursive_free(TensorInfo* ptr) {
  404. SmallVector<TensorInfo*> inps(0);
  405. if (ptr->producer) {
  406. for (auto i : ptr->producer->inputs) {
  407. if (i && --i->ref_cnt == 0) {
  408. inps.push_back(i);
  409. }
  410. }
  411. }
  412. real_free(ptr);
  413. for (auto i : inps) {
  414. if (i->allow_delete) {
  415. recursive_free(i);
  416. }
  417. }
  418. }
  419. void ChannelImpl::real_free(TensorInfo* ptr) {
  420. MGB_LOCK_GUARD(m_mutex);
  421. if (m_channel_state.profiler->is_profiling()) {
  422. m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
  423. }
  424. if (ptr->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
  425. m_dtr.erase_candidate(ptr);
  426. }
  427. detach_users(ptr);
  428. ptr->detach_producer();
  429. m_pool.free(ptr);
  430. }
  431. ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}
  432. ChannelImpl::~ChannelImpl() {
  433. close();
  434. }
  435. void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
  436. auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
  437. : std::unique_lock<std::mutex>();
  438. m_dtr.update_used_time(dest);
  439. if (notice && m_worker_state.profiler->is_profiling()) {
  440. m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
  441. }
  442. dest->value_fetched = ptr->value_fetched();
  443. // update tensor desc for static infer
  444. dest->desc.layout = ptr->layout();
  445. dest->desc.comp_node = ptr->comp_node();
  446. dest->memory = ptr->blob()->size();
  447. dest->ptr = std::move(ptr);
  448. dest->evict_type = EvictType::NONE;
  449. if (notice && dest->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
  450. m_dtr.insert_candidate(dest);
  451. }
  452. if (notice && m_waitee == dest) {
  453. m_cv.notify_all();
  454. }
  455. }
  456. void ChannelImpl::release_tensor(TensorInfo* dest) {
  457. MGB_LOCK_GUARD(m_mutex);
  458. dest->ptr.reset();
  459. }
  460. void ChannelImpl::regenerate(TensorInfo* dest) {
  461. if (dest->evict_type == EvictType::DROP) {
  462. recompute(dest->producer);
  463. } else if (dest->evict_type == EvictType::SWAP) {
  464. produce_tensor(dest, Tensor::make(dest->h_value));
  465. }
  466. }
  467. void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
  468. SmallVector<TensorPtr> inputs;
  469. inputs.reserve(path->inputs.size());
  470. m_dtr.pin(path->inputs);
  471. for (auto i : path->inputs) {
  472. if (!i->ptr) {
  473. regenerate(i);
  474. }
  475. inputs.push_back(i->ptr);
  476. m_dtr.update_used_time(i);
  477. }
  478. if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
  479. auto_evict();
  480. }
  481. auto outputs = OpDef::apply_on_physical_tensor(*path->op, inputs);
  482. m_dtr.estimate_timestamp += path->compute_time / 1e8;
  483. m_dtr.unpin(path->inputs);
  484. for (size_t i = 0;i < outputs.size();i ++) {
  485. auto&& o = path->outputs[i];
  486. if (o) {
  487. o->recompute_times ++;
  488. if (!o->ptr) {
  489. produce_tensor(o, std::move(outputs[i]), false);
  490. if (m_worker_state.options.enable_dtr_auto_drop) {
  491. m_dtr.update_dsu_after_recompute(o);
  492. }
  493. }
  494. }
  495. }
  496. }
  497. void ChannelImpl::auto_evict() {
  498. if (!m_dtr.comp_node.valid()) {
  499. return;
  500. }
  501. size_t current_memory = m_dtr.comp_node.get_used_memory();
  502. while (current_memory > m_worker_state.options.dtr_eviction_threshold) {
  503. auto best = m_dtr.find_best_tensor();
  504. if (!best) {
  505. if (!m_dtr.warn_printed) {
  506. m_dtr.warn_printed = true;
  507. mgb_log_warn("No tensors on %s can be evicted automatically "
  508. "when memory usage is %.0lfMB. Maybe memory "
  509. "budget is too small.",
  510. m_dtr.comp_node.to_string().c_str(),
  511. current_memory / 1024.0 / 1024.0);
  512. }
  513. break;
  514. }
  515. if (best->ptr.unique() && best->ptr->blob().unique()) {
  516. current_memory -= best->memory;
  517. }
  518. do_drop(best);
  519. if (best->evict_type == EvictType::DROP) {
  520. m_dtr.update_dsu_after_evict(best);
  521. }
  522. }
  523. }
  524. void ChannelImpl::detach_users(TensorInfo* dest) {
  525. SmallVector<TensorInfo::ComputePath*> users = dest->users;
  526. for (auto* user: users) {
  527. SmallVector<TensorInfo*> outputs = user->outputs;
  528. SmallVector<TensorInfo*> inputs = user->inputs;
  529. for (auto* output: outputs) {
  530. if (output == nullptr) {
  531. continue;
  532. }
  533. regenerate(output);
  534. output->detach_producer();
  535. for (auto* input: inputs) {
  536. input->ref_cnt --;
  537. }
  538. }
  539. }
  540. mgb_assert(dest->users.size() == 0);
  541. //dest->users.clear();
  542. }
  543. bool ChannelImpl::check_available() {
  544. return !m_closed;
  545. }
  546. void ChannelImpl::sync_device_scope(CompNode device) {
  547. auto& prev = m_worker_state.device_scope_map[device];
  548. auto& current = m_worker_state.scopes;
  549. auto push_scope = [&](std::string name) {
  550. m_worker_state.profiler->record_device<DeviceBeginScope>(device, name);
  551. };
  552. auto pop_scope = [&](std::string name) {
  553. m_worker_state.profiler->record_device<DeviceEndScope>(device, name);
  554. };
  555. size_t similarity = 0;
  556. for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
  557. if (prev[i] == current[i]) {
  558. similarity++;
  559. } else {
  560. break;
  561. }
  562. }
  563. while (prev.size() > similarity) {
  564. pop_scope(prev.back());
  565. prev.pop_back();
  566. }
  567. while (prev.size() < current.size()) {
  568. prev.push_back(current[prev.size()]);
  569. push_scope(prev.back());
  570. }
  571. }
  572. void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
  573. if (m_worker_state.profiler->is_profiling()) {
  574. m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
  575. }
  576. bool finished = false;
  577. auto do_finish_command = [&]{
  578. if (finished) {
  579. return;
  580. }
  581. if (m_worker_state.profiler->is_profiling()) {
  582. m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
  583. }
  584. finished = true;
  585. };
  586. //TODO: remove std::visit to support osx 10.12
  587. auto cmd_visitor = [&](const auto& cmd) {
  588. using T = std::decay_t<decltype(cmd)>;
  589. if constexpr (std::is_same_v<T, Put>) {
  590. auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
  591. produce_tensor(cmd.dest, std::move(value));
  592. } else if constexpr (std::is_same_v<T, ApplyOp>) {
  593. uint64_t apply_id = ++m_last_id;
  594. SmallVector<TensorPtr> tensor_inputs;
  595. SmallVector<CompNode> devices;
  596. if (m_worker_state.options.enable_dtr_auto_drop) {
  597. m_dtr.pin(cmd.inputs);
  598. }
  599. for (auto i : cmd.inputs) {
  600. if (!i->ptr && i->evict_type != EvictType::NONE) {
  601. regenerate(i);
  602. }
  603. m_dtr.update_used_time(i);
  604. }
  605. tensor_inputs.reserve(cmd.inputs.size());
  606. // refcnt == 1, owners: [TensorInfo::ptr]
  607. for (auto i : cmd.inputs) {
  608. mgb_assert(i->ptr, "Invalid input tensor ptr!");
  609. // refcnt ++, owners: [i->ptr, tensor_inputs]
  610. tensor_inputs.push_back(i->ptr);
  611. }
  612. // Begin profiling operator
  613. OpEvent event_data;
  614. if (m_worker_state.profiler->is_profiling()) {
  615. auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
  616. SmallVector<uint64_t> tid;
  617. for (auto* ptinfo: tinfo) {
  618. tid.push_back(ptinfo->id);
  619. }
  620. return tid;
  621. };
  622. event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
  623. // Collecting devices
  624. for (auto i : cmd.inputs) {
  625. devices.push_back(i->desc.comp_node);
  626. }
  627. for (auto i : cmd.outputs) {
  628. devices.push_back(i->desc.comp_node);
  629. }
  630. devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
  631. }
  632. // Fused by command buffer. @see: CommandBuffer::fuse_del
  633. // Now if dest is inplacable, its refcnt will be decreased to 1 and owned by tensor_inputs after Del.
  634. // Note for exprs like 'y = x op x', inplace is not supported yet, but Del will still be fused.
  635. for (auto* del : cmd.dels) {
  636. // refcnt --, owners: [tensor_inputs]
  637. // if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
  638. free(del);
  639. }
  640. // Before wait
  641. //TODO: split operator wait and execute so that OpWait can be correctly recorded.
  642. // Before execute
  643. if (m_worker_state.profiler->is_profiling()) {
  644. m_worker_state.profiler->record_host<HostOpExecuteEvent>(event_data);
  645. for (auto&& device: devices) {
  646. sync_device_scope(device);
  647. m_worker_state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
  648. }
  649. }
  650. if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
  651. auto_evict();
  652. }
  653. // Apply op
  654. // Here std::move is REQUIRED for removing duplicated references.
  655. auto tensor_outputs = OpDef::apply_on_physical_tensor(
  656. *cmd.op, std::move(tensor_inputs));
  657. // After execute
  658. if (m_worker_state.profiler->is_profiling()) {
  659. m_worker_state.profiler->record_host<HostOpFinishEvent>(event_data);
  660. for (auto&& device: devices) {
  661. m_worker_state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
  662. }
  663. }
  664. // End profiling operator
  665. double estimate_compute_time = 0;
  666. if (m_worker_state.options.enable_dtr_auto_drop) {
  667. for (auto i : cmd.inputs) {
  668. estimate_compute_time += i->memory;
  669. }
  670. for (auto i : tensor_outputs) {
  671. estimate_compute_time += i->blob()->size();
  672. }
  673. m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
  674. for (auto i : cmd.outputs) {
  675. i->compute_time = estimate_compute_time;
  676. m_dtr.update_used_time(i);
  677. }
  678. if (cmd.outputs[0]->producer) {
  679. cmd.outputs[0]->producer->compute_time = estimate_compute_time;
  680. }
  681. m_dtr.unpin(cmd.inputs);
  682. }
  683. mgb_assert(tensor_outputs.size() == cmd.outputs.size());
  684. for (size_t i = 0; i < tensor_outputs.size(); ++i) {
  685. if (cmd.outputs[i] == nullptr) {
  686. continue;
  687. }
  688. produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
  689. if (m_worker_state.options.enable_dtr_auto_drop) {
  690. cmd.outputs[i]->dsu_ptr = std::make_shared<DsuNode>(estimate_compute_time);
  691. }
  692. }
  693. if (m_worker_state.options.enable_drop == 1
  694. && m_worker_state.options.record_computing_path == 1){
  695. bool is_inplace = false;
  696. bool cross_cn = false;
  697. for (auto input : cmd.inputs) {
  698. for (auto output : cmd.outputs) {
  699. if (input->ptr->blob()->storage() == output->ptr->blob()->storage()) {
  700. is_inplace = true;
  701. break;
  702. }
  703. }
  704. }
  705. for (auto input : cmd.inputs) {
  706. if (input->ptr->comp_node() != m_dtr.comp_node) {
  707. cross_cn = true;
  708. break;
  709. }
  710. }
  711. for (auto output : cmd.outputs) {
  712. if (output->ptr->comp_node() != m_dtr.comp_node) {
  713. cross_cn = true;
  714. break;
  715. }
  716. }
  717. // FIXME: do not use opname as identifier
  718. auto get_name = [](const OpDef& opdef) {
  719. if (auto attr = opdef.try_cast_final<OprAttr>()) {
  720. return attr->type.c_str();
  721. }
  722. return opdef.dyn_typeinfo()->name;
  723. };
  724. if (!is_inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
  725. TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
  726. size_t detach_cnt = 0;
  727. for (auto output : cmd.outputs) {
  728. if (!output->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
  729. output->detach_producer();
  730. detach_cnt ++;
  731. }
  732. }
  733. for (auto input : cmd.inputs) {
  734. input->ref_cnt -= detach_cnt;
  735. }
  736. }
  737. }
  738. } else if constexpr (std::is_same_v<T, Del>) {
  739. free(cmd.dest);
  740. } else if constexpr (std::is_same_v<T, GetValue>) {
  741. if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
  742. regenerate(cmd.dest);
  743. }
  744. mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
  745. cmd.dest->ptr->fetch_value();
  746. MGB_LOCK_GUARD(m_mutex);
  747. cmd.dest->value_fetched = true;
  748. if (m_waitee == cmd.dest) {
  749. m_cv.notify_all();
  750. }
  751. } else if constexpr (std::is_same_v<T, SwapIn>) {
  752. produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
  753. } else if constexpr (std::is_same_v<T, SwapOut>) {
  754. cmd.dest->h_value = cmd.dest->ptr->get_value();
  755. if (cmd.dest->evict_type == EvictType::NONE) {
  756. release_tensor(cmd.dest);
  757. cmd.dest->evict_type = EvictType::SWAP;
  758. }
  759. } else if constexpr (std::is_same_v<T, Drop>) {
  760. do_drop(cmd.dest, true);
  761. } else if constexpr (std::is_same_v<T, SetOption>) {
  762. m_worker_state.options.set_option(cmd.key, cmd.value);
  763. } else if constexpr (std::is_same_v<T, StartProfile>) {
  764. CompNode::sync_all();
  765. m_worker_state.profiler.reset(cmd.profiler);
  766. } else if constexpr (std::is_same_v<T, StopProfile>) {
  767. for (auto&& [device, scopes]: m_worker_state.device_scope_map) {
  768. MGB_MARK_USED_VAR(scopes);
  769. sync_device_scope(device);
  770. }
  771. do_finish_command();
  772. auto profiler = std::make_unique<InterpreterProfiler>();
  773. std::swap(profiler, m_worker_state.profiler);
  774. auto records = profiler->stop();
  775. auto host_map = [this](std::thread::id tid) {
  776. if (tid == m_worker_state.tid) {
  777. return "worker";
  778. } else {
  779. return "unknown";
  780. }
  781. };
  782. InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map);
  783. } else if constexpr (std::is_same_v<T, PushScope>) {
  784. m_worker_state.scopes.push_back(cmd.scope_name);
  785. do_finish_command();
  786. m_worker_state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
  787. } else if constexpr (std::is_same_v<T, PopScope>) {
  788. mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch");
  789. m_worker_state.scopes.pop_back();
  790. do_finish_command();
  791. m_worker_state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
  792. } else {
  793. static_assert(!std::is_same_v<T, T>);
  794. }
  795. };
  796. std::visit([&](const auto& cmd){
  797. using T = std::decay_t<decltype(cmd)>;
  798. if (!m_worker_state.options.catch_worker_execption) {
  799. cmd_visitor(cmd);
  800. return;
  801. }
  802. try {
  803. cmd_visitor(cmd);
  804. } catch (...) {
  805. MGB_LOCK_GUARD(m_mutex);
  806. if constexpr (std::is_same_v<T, ApplyOp>) {
  807. for (auto oup : cmd.outputs) {
  808. oup->invalid = true;
  809. }
  810. } else if constexpr (std::is_same_v<T, Put>) {
  811. cmd.dest->invalid = true;
  812. }
  813. m_worker_exc = std::current_exception();
  814. m_cv.notify_all();
  815. }
  816. }, icmd.second);
  817. do_finish_command();
  818. }
  819. void ChannelImpl::check_worker_exc_unsafe() {
  820. if (m_worker_exc) {
  821. // for reuse interpreter_for_py after some exception tests
  822. m_waitee = nullptr;
  823. std::exception_ptr exc;
  824. std::swap(exc, m_worker_exc);
  825. std::rethrow_exception(exc);
  826. }
  827. }
  828. void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
  829. if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
  830. return;
  831. }
  832. // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
  833. m_commands.push_back(std::move(cmd));
  834. auto flush_pos = flush_pos_for(m_commands.back());
  835. flush(flush_pos);
  836. }
  837. void ChannelImpl::CommandBuffer::flush() {
  838. flush(m_commands.end());
  839. }
  840. void ChannelImpl::CommandBuffer::flush(Handle pos) {
  841. for (auto iter = m_commands.begin(); iter != pos; ++iter) {
  842. // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
  843. IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
  844. if (m_owner->m_channel_state.profiler->is_profiling()) {
  845. m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
  846. }
  847. m_owner->m_worker.add_task(std::move(icmd));
  848. }
  849. m_commands.erase(m_commands.begin(), pos);
  850. }
  851. auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
  852. return std::visit([this](const auto& cmd) {
  853. using T = std::decay_t<decltype(cmd)>;
  854. if constexpr (std::is_same_v<T, ApplyOp>) {
  855. auto* op_type = cmd.op->dyn_typeinfo();
  856. if (op_type == RemoteRecv::typeinfo() ||
  857. op_type == RemoteSend::typeinfo() ||
  858. op_type == CollectiveComm::typeinfo() ||
  859. op_type == opr::InputCallback::typeinfo() ||
  860. op_type == opr::OutputCallback::typeinfo() ||
  861. op_type == BackwardGraph::typeinfo()) {
  862. return m_commands.end();
  863. }
  864. } else if constexpr (std::is_same_v<T, GetValue>) {
  865. return m_commands.end();
  866. }
  867. size_t buffer_length = m_owner->m_channel_state.options.buffer_length;
  868. if (m_commands.size() > buffer_length) {
  869. return m_commands.begin() + (m_commands.size() - buffer_length);
  870. }
  871. return m_commands.begin();
  872. }, cmd);
  873. }
  874. /**
  875. * 1. Find ApplyOp(dest) in buffered commands
  876. * 2. Check if there are other usages between ApplyOp and Del, return false if not
  877. * 3. Fuse Del into ApplyOp, return true
  878. */
  879. bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
  880. auto* dest = cmd.dest;
  881. // TODO: eliminate Puts
  882. auto begin = m_commands.begin(), end = m_commands.end();
  883. auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
  884. if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
  885. return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
  886. }
  887. return false;
  888. });
  889. if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
  890. return false;
  891. }
  892. // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
  893. std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
  894. return true;
  895. }
  896. auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
  897. -> Handle {
  898. auto found = range[1];
  899. for (auto iter = range[0]; iter != range[1]; ++iter) {
  900. std::visit([&](const auto& cmd) {
  901. using T = std::decay_t<decltype(cmd)>;
  902. if constexpr (std::is_same_v<T, ApplyOp>) {
  903. if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
  904. dest) > 0) {
  905. found = iter;
  906. }
  907. } else if constexpr (std::is_same_v<T, GetValue>) {
  908. if (cmd.dest == dest) {
  909. found = iter;
  910. }
  911. } else if constexpr (std::is_same_v<T, SwapIn> ||
  912. std::is_same_v<T, SwapOut> ||
  913. std::is_same_v<T, Drop>) {
  914. //TODO: ignore swap-like commands, just remove them from buffer
  915. if (cmd.dest == dest) {
  916. found = iter;
  917. }
  918. }
  919. }, *iter);
  920. };
  921. return found;
  922. }
  923. auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
  924. -> Handle {
  925. return std::find_if(range[0], range[1], [dest](auto& cmd) {
  926. return std::visit([dest](const auto& cmd){
  927. using T = std::decay_t<decltype(cmd)>;
  928. if constexpr (std::is_same_v<T, ApplyOp>) {
  929. return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
  930. } else if constexpr (std::is_same_v<T, Put>) {
  931. return cmd.dest == dest;
  932. }
  933. return false;
  934. }, cmd);
  935. });
  936. }
  937. void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
  938. mgb_assert(check_available(), "Channel already closed");
  939. auto profiler_option = InterpreterProfiler::Option::from_dict(option);
  940. auto profiler = std::make_unique<InterpreterProfiler>();
  941. profiler->set_option(profiler_option);
  942. profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic));
  943. std::swap(profiler, m_channel_state.profiler);
  944. m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()});
  945. }
  946. void ChannelImpl::stop_profile(std::string basename, std::string format) {
  947. mgb_assert(check_available(), "Channel already closed");
  948. m_buffer.flush();
  949. auto profiler = std::make_unique<InterpreterProfiler>();
  950. std::swap(profiler, m_channel_state.profiler);
  951. profiler.release();
  952. m_buffer.enqueue(StopProfile{basename, format});
  953. }
  954. void ChannelImpl::push_scope(std::string name) {
  955. mgb_assert(check_available(), "Channel already closed");
  956. if (m_channel_state.profiler->is_profiling()) {
  957. m_channel_state.profiler->record_host<ChannelBeginScope>(name);
  958. m_channel_state.scopes.push_back(name);
  959. m_buffer.enqueue(PushScope{name});
  960. }
  961. }
  962. void ChannelImpl::pop_scope(std::string name) {
  963. mgb_assert(check_available(), "Channel already closed");
  964. if (m_channel_state.profiler->is_profiling()) {
  965. mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
  966. m_channel_state.scopes.pop_back();
  967. m_channel_state.profiler->record_host<ChannelEndScope>(name);
  968. m_buffer.enqueue(PopScope{name});
  969. }
  970. }
  971. void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
  972. for (auto i : vec) {
  973. i->pin();
  974. }
  975. }
  976. void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) {
  977. for (auto i : vec) {
  978. i->unpin();
  979. }
  980. }
  981. void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
  982. auto&& dsu_fa = find_father(ptr->dsu_ptr);
  983. dsu_fa->t -= ptr->compute_time;
  984. ptr->dsu_ptr->parent.reset();
  985. ptr->dsu_ptr->t = ptr->compute_time;
  986. }
  987. void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
  988. for (auto i : ptr->producer->inputs) {
  989. if (i->evict_type == EvictType::DROP) {
  990. merge(i->dsu_ptr, ptr->dsu_ptr);
  991. }
  992. }
  993. for (auto i : ptr->producer->outputs) {
  994. if (i && i->evict_type == EvictType::DROP) {
  995. merge(ptr->dsu_ptr, i->dsu_ptr);
  996. }
  997. }
  998. }
  999. double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
  1000. double cost = 0;
  1001. for (auto i : ptr->producer->inputs) {
  1002. if (i->evict_type == EvictType::DROP) {
  1003. double t = find_father(i->dsu_ptr)->t;
  1004. if (t < i->compute_time) {
  1005. t = i->compute_time;
  1006. }
  1007. cost += t;
  1008. }
  1009. }
  1010. for (auto i : ptr->producer->outputs) {
  1011. if (i && i->evict_type == EvictType::DROP) {
  1012. double t = find_father(i->dsu_ptr)->t;
  1013. if (t < i->compute_time) {
  1014. t = i->compute_time;
  1015. }
  1016. cost += t;
  1017. }
  1018. }
  1019. return cost;
  1020. }
  1021. TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() {
  1022. double min_msps = -1;
  1023. TensorInfo* best = nullptr;
  1024. for (auto i : candidates) {
  1025. if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
  1026. double neighbor_cost = estimate_neighbor_cost(i);
  1027. size_t begin_ptr = reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
  1028. auto side_info = i->ptr->comp_node().get_free_left_and_right(begin_ptr, begin_ptr + i->ptr->blob()->size());
  1029. double free_mem = side_info.first + side_info.second;
  1030. double msps = i->eval_func(neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
  1031. if (min_msps < 0 || msps < min_msps) {
  1032. min_msps = msps;
  1033. best = i;
  1034. }
  1035. }
  1036. }
  1037. return best;
  1038. }
  1039. void ChannelImpl::DynamicSublinear::merge(std::shared_ptr<DsuNode> &x, std::shared_ptr<DsuNode> &y) {
  1040. auto&& f_x = find_father(x);
  1041. auto&& f_y = find_father(y);
  1042. if (f_x.get() == f_y.get()) {
  1043. return;
  1044. }
  1045. f_y->t += f_x->t;
  1046. f_x->parent = f_y;
  1047. }
  1048. std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(std::shared_ptr<DsuNode>& x) {
  1049. if (x->is_root()) {
  1050. return x;
  1051. } else {
  1052. auto&& fa = find_father(x->parent);
  1053. return x->parent = fa;
  1054. }
  1055. }
  1056. void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
  1057. candidates.insert(ptr);
  1058. if (!comp_node.valid()) {
  1059. comp_node = ptr->ptr->comp_node();
  1060. }
  1061. }
  1062. void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
  1063. candidates.erase(ptr);
  1064. }
  1065. void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
  1066. ptr->last_used_time = estimate_timestamp;
  1067. }

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU device and that its driver is properly installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
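For a quick check that the bundled CUDA environment is usable on your machine, a minimal sketch along the following lines can help. It assumes the commonly documented megengine Python entry points (is_cuda_available, set_default_device); adjust the names to the version you have installed.

import megengine as mge

# Assumption: is_cuda_available() reports whether a CUDA device and driver were found.
if mge.is_cuda_available():
    mge.set_default_device("gpu0")   # run subsequent ops on the first GPU
else:
    print("No usable GPU/driver detected; falling back to CPU.")
    mge.set_default_device("cpu0")

# A tiny tensor op to confirm the selected device works end to end.
x = mge.tensor([1.0, 2.0, 3.0])
print((x * 2).numpy())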