
interpreter_impl.cpp 57 kB

  1. /**
  2. * \file imperative/src/impl/interpreter/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. #include "range/v3/all.hpp"
  13. #include "megbrain/common.h"
  14. #include "megbrain/imperative/opr_utility.h"
  15. #include "megbrain/imperative/ops/autogen.h"
  16. #include "megbrain/imperative/ops/backward_graph.h"
  17. #include "megbrain/imperative/ops/opr_attr.h"
  18. #include "megbrain/imperative/ops/utility.h"
  19. #include "megbrain/imperative/utils/to_string.h"
  20. #include "../blob_manager_impl.h"
  21. #include "../event_pool.h"
  22. #include "../op_trait.h"
  23. using namespace mgb;
  24. using namespace imperative;
  25. using namespace interpreter;
  26. using namespace interpreter::intl;
  27. namespace {
  28. auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
  29. SmallVector<uint64_t> tid;
  30. for (auto* ptinfo: tinfo) {
  31. tid.push_back(ptinfo->id);
  32. }
  33. return tid;
  34. };
  35. }
  36. namespace mgb {
  37. using namespace profiler;
  38. }
  39. #if defined(_WIN32) || defined(_WIN64)
  40. #define SYMBOL_EXPORT __declspec(dllexport)
  41. #else
  42. #define SYMBOL_EXPORT __attribute__((visibility("default")))
  43. #endif
  44. namespace mgb {
  45. /**
  46. * USAGE
  47. *
  48. * header:
  49. * namespace mgb { void imperative_log_profile(const char* message); }
  50. *
  51. * code:
  52. * mgb::imperative_log_profile("MY MESSAGE");
  53. *
  54. **/
  55. SYMBOL_EXPORT
  56. void imperative_log_profile_begin(const char* message) {
  57. MGB_RECORD_EVENT(CustomEvent, std::string{message});
  58. }
  59. SYMBOL_EXPORT
  60. void imperative_log_profile_end(const char* message) {
  61. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
  62. }
  63. SYMBOL_EXPORT
  64. void imperative_log_profile(const char* message){
  65. imperative_log_profile_begin(message);
  66. imperative_log_profile_end(message);
  67. }
  68. SYMBOL_EXPORT
  69. void imperative_log_profile_begin(const char* message, const char* device) {
  70. auto comp_node = CompNode::load(device);
  71. MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
  72. MGB_RECORD_EVENT(RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  73. }
  74. SYMBOL_EXPORT
  75. void imperative_log_profile_end(const char* message, const char* device) {
  76. auto comp_node = CompNode::load(device);
  77. MGB_RECORD_EVENT(RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  78. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
  79. }
  80. }
  81. std::thread::id ChannelImpl::get_worker_tid() {
  82. return m_worker_state.tid;
  83. }
  84. ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
  85. assert_in_channel();
  86. return m_channel_state;
  87. }
  88. ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
  89. assert_in_worker();
  90. return m_worker_state;
  91. }
  92. void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
  93. sys::set_thread_name("worker");
  94. m_owner->m_worker_state.tid = std::this_thread::get_id();
  95. OpDef::set_allocator([&](CompNode device, size_t size) {
  96. auto blob = Blob::make(device, size);
  97. m_owner->alloc_tensor_with_evict(blob.get());
  98. return blob->storage();
  99. });
  100. }
  101. // Do not use m_xxx_state directly
  102. #define m_channel_state
  103. #define m_worker_state
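// These two empty macro definitions poison any direct use of m_channel_state /
// m_worker_state in the code below, so later code must go through
// get_channel_state() / get_worker_state(), which assert the calling thread.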
  104. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  105. return std::make_unique<ChannelImpl>();
  106. }
  107. Interpreter& Interpreter::inst() {
  108. static InterpreterImpl inst_;
  109. return inst_;
  110. }
  111. Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
  112. MGB_LOCK_GUARD(m_spin);
  113. mgb_assert(check_available(), "Channel already closed");
  114. auto& state = get_channel_state();
  115. auto _ = StackManager::Guard{"Put", &state.stack_manager};
  116. auto info = put_impl(value, no_cache);
  117. return reinterpret_cast<Handle>(info);
  118. }
  119. TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
  120. if (value.empty()) {
  121. auto layout = value.layout();
  122. layout.init_contiguous_stride();
  123. const_cast<HostTensorND&>(value).reset(value.storage(), layout);
  124. }
  125. auto info = alloc();
  126. init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
  127. info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
  128. info->h_value = value;
  129. m_buffer.enqueue(Put{info, value, no_cache});
  130. if (m_async_level == 0) {
  131. sync_impl();
  132. info->desc.comp_node.sync();
  133. }
  134. return info;
  135. }
  136. Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
  137. MGB_LOCK_GUARD(m_spin);
  138. mgb_assert(check_available(), "Channel already closed");
  139. return reinterpret_cast<Handle>(put_impl(data, hvalue));
  140. }
  141. TensorInfo* ChannelImpl::put_impl(const DeviceTensorND& data, const HostTensorND& hvalue) {
  142. auto& state = get_channel_state();
  143. auto _ = StackManager::Guard{"Put", &state.stack_manager};
  144. auto info = alloc();
  145. MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
  146. init(info, {data.layout(), data.comp_node()});
  147. info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
  148. info->ptr = Tensor::make(data, hvalue);
  149. MGB_RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
  150. info->status = TensorInfo::Produced;
  151. MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
  152. return info;
  153. }
  154. void ChannelImpl::del(Handle handle) {
  155. MGB_LOCK_GUARD(m_spin);
  156. if (!check_available()){
  157. return;
  158. }
  159. del_impl(handle);
  160. }
  161. void ChannelImpl::del_impl(Handle handle) {
  162. mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
  163. auto* info = reinterpret_cast<TensorInfo*>(handle);
  164. m_valid_handle.erase(handle);
  165. m_buffer.enqueue(Del{info});
  166. }
  167. void ChannelImpl::swap_in(Handle handle) {
  168. MGB_LOCK_GUARD(m_spin);
  169. mgb_assert(check_available(), "Channel already closed");
  170. auto& state = get_channel_state();
  171. if (state.options.enable_swap) {
  172. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  173. "invalid handle: %p", handle);
  174. auto* info = reinterpret_cast<TensorInfo*>(handle);
  175. m_buffer.enqueue(SwapIn{info});
  176. }
  177. }
  178. void ChannelImpl::swap_out(Handle handle) {
  179. MGB_LOCK_GUARD(m_spin);
  180. mgb_assert(check_available(), "Channel already closed");
  181. auto& state = get_channel_state();
  182. if (state.options.enable_swap) {
  183. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  184. "invalid handle: %p", handle);
  185. auto* info = reinterpret_cast<TensorInfo*>(handle);
  186. m_buffer.enqueue(SwapOut{info});
  187. }
  188. }
  189. void ChannelImpl::drop(Handle handle) {
  190. MGB_LOCK_GUARD(m_spin);
  191. mgb_assert(check_available(), "Channel already closed");
  192. auto& state = get_channel_state();
  193. if (state.options.enable_drop) {
  194. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  195. "invalid handle: %p", handle);
  196. auto* info = reinterpret_cast<TensorInfo*>(handle);
  197. m_buffer.enqueue(Drop{info});
  198. }
  199. }
  200. void ChannelImpl::dispatch_default_cpu(
  201. std::shared_ptr<OpDef> op,
  202. const SmallVector<TensorInfo*>& input_infos,
  203. const SmallVector<LogicalTensorDesc>& input_descs,
  204. SmallVector<Handle>* outputs) {
  205. auto& state = get_channel_state();
  206. auto name = op->trait()->make_name(*op);
  207. auto _ = StackManager::Guard(name, &state.stack_manager);
  208. auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
  209. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  210. SmallVector<DeviceTensorND> input_tensornds;
  211. input_tensornds.reserve(input_descs.size());
  212. CompNode output_cn;
  213. {
  214. MGB_LOCK_GUARD(m_mutex);
  215. for (auto&& info : input_infos) {
  216. auto input_cn = info->desc.comp_node;
  217. if (!output_cn.valid()) {
  218. output_cn = input_cn;
  219. } else {
  220. mgb_assert(output_cn == input_cn, "cannot decide output comp node");
  221. }
  222. if (info->ptr && info->ptr->try_get_value()) {
  223. input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
  224. } else {
  225. // It's OK for SwapOut: h_value is assigned before ptr is dropped
  226. mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
  227. input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
  228. }
  229. }
  230. }
  231. outputs->reserve(output_descs.size());
  232. SmallVector<DeviceTensorND> output_tensornds;
  233. output_tensornds.reserve(output_descs.size());
  234. for (auto&& desc : output_descs) {
  235. // TODO: may conflict with condtake, which needs alloc inside
  236. mgb_assert(!desc.layout.is_empty());
  237. // use HostTensorND alloc_host for cuda pinned memory
  238. output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
  239. }
  240. uint64_t op_id = Profiler::next_id();
  241. OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
  242. SmallVector<TensorInfo*> output_infos;
  243. output_infos.reserve(output_descs.size());
  244. for (auto&& tensornd : output_tensornds) {
  245. HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
  246. .proxy_to_comp_node(output_cn);
  247. // use `put` for consistency
  248. auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
  249. mgb_assert(info->desc.layout.ndim != 0);
  250. output_infos.push_back(info);
  251. outputs->push_back(reinterpret_cast<Handle>(info));
  252. }
  253. auto op_info_getter = [op]{
  254. std::unordered_map<std::string, std::string> op_info;
  255. auto props = OpDef::props(*op);
  256. for (auto&& [key, value]: props) {
  257. op_info[key] = value;
  258. }
  259. return op_info;
  260. };
  261. MGB_RECORD_EVENT(OpDispatchEvent, op_id, name, op_info_getter,
  262. tinfo_to_tid(input_infos), tinfo_to_tid(output_infos),
  263. state.stack_manager.dump());
  264. }
  265. void ChannelImpl::dispatch_kernel(
  266. std::shared_ptr<OpDef> op,
  267. const SmallVector<TensorInfo*>& input_infos,
  268. const SmallVector<LogicalTensorDesc>& input_descs,
  269. SmallVector<Handle>* outputs) {
  270. auto& state = get_channel_state();
  271. auto& options = state.options;
  272. auto name = op->trait()->make_name(*op);
  273. auto _ = StackManager::Guard{name, &state.stack_manager};
  274. auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
  275. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  276. ApplyOp cmd{Profiler::next_id(), std::move(op)};
  277. cmd.inputs = std::move(input_infos);
  278. cmd.outputs.reserve(output_descs.size());
  279. outputs->reserve(output_descs.size());
  280. for (size_t i = 0; i < output_descs.size(); ++i) {
  281. auto&& desc = output_descs[i];
  282. auto info = alloc();
  283. init(info, desc);
  284. // make sure desc's value is consistent with h_value
  285. if (!info->desc.value.empty()) {
  286. info->h_value = HostTensorND::make_proxy(desc.value)
  287. .proxy_to_comp_node(desc.comp_node);
  288. }
  289. cmd.outputs.push_back(info);
  290. outputs->push_back(reinterpret_cast<Handle>(info));
  291. }
  292. auto op_info_getter = [op=cmd.op]{
  293. std::unordered_map<std::string, std::string> op_info;
  294. auto props = OpDef::props(*op);
  295. for (auto&& [key, value]: props) {
  296. op_info[key] = value;
  297. }
  298. return op_info;
  299. };
  300. MGB_RECORD_EVENT(OpDispatchEvent, cmd.id, name, op_info_getter,
  301. tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs),
  302. state.stack_manager.dump());
  303. m_buffer.enqueue(std::move(cmd));
  304. if (!validated && options.async_level == 1) {
  305. sync_impl();
  306. } else if (options.async_level == 0) {
  307. sync_impl();
  308. // check device error
  309. for (auto&& oup : *outputs) {
  310. auto info = reinterpret_cast<TensorInfo*>(oup);
  311. info->ptr->comp_node().sync();
  312. }
  313. }
  314. }
  315. SmallVector<Handle> ChannelImpl::apply_op(
  316. std::shared_ptr<OpDef> op,
  317. const SmallVector<Handle>& inputs) {
  318. MGB_LOCK_GUARD(m_spin);
  319. mgb_assert(check_available(), "Channel already closed");
  320. return apply_op_impl(std::move(op), inputs);
  321. }
  322. SmallVector<Handle> ChannelImpl::apply_op_impl(
  323. std::shared_ptr<OpDef> op,
  324. const SmallVector<Handle>& inputs) {
  325. auto& state = get_channel_state();
  326. for (auto i : inputs) {
  327. mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
  328. "invalid handle: %p", i);
  329. }
  330. SmallVector<TensorInfo*> input_infos;
  331. input_infos.reserve(inputs.size());
  332. SmallVector<LogicalTensorDesc> input_descs;
  333. input_descs.reserve(inputs.size());
  334. {
  335. MGB_LOCK_GUARD(m_mutex);
  336. for (auto i : inputs) {
  337. auto info = reinterpret_cast<TensorInfo*>(i);
  338. mgb_assert(!info->invalid, "an input tensor is unusable due to previous error");
  339. input_infos.push_back(info);
  340. input_descs.push_back(info->desc);
  341. }
  342. }
  343. SmallVector<Handle> outputs;
  344. DispatchMode dispatch_mode = state.options.enable_host_compute
  345. ? OpDef::decide_dispatch_mode(*op, input_descs)
  346. : DispatchMode::KERNEL;
  347. switch (dispatch_mode) {
  348. case DEFAULT_CPU: {
  349. dispatch_default_cpu(op, input_infos, input_descs, &outputs);
  350. break;
  351. }
  352. case KERNEL: {
  353. dispatch_kernel(op, input_infos, input_descs, &outputs);
  354. break;
  355. }
  356. }
  357. return outputs;
  358. }
  359. HostTensorND ChannelImpl::get_value(Handle handle) {
  360. MGB_LOCK_GUARD(m_spin);
  361. mgb_assert(check_available(), "Channel already closed");
  362. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  363. "invalid handle: %p", handle);
  364. auto info = reinterpret_cast<TensorInfo*>(handle);
  365. // do not use info->value_fetched, it's unsafe
  366. mgb_assert(!info->invalid, "tensor is unusable due to previous error");
  367. return wait_tensor(info, TensorProp::HostValue)->get_value();
  368. }
  369. TensorShape ChannelImpl::get_shape(Handle handle) {
  370. MGB_LOCK_GUARD(m_spin);
  371. mgb_assert(check_available(), "Channel already closed");
  372. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  373. "invalid handle: %p", handle);
  374. auto info = reinterpret_cast<TensorInfo*>(handle);
  375. if (info->desc.layout.ndim != 0) {
  376. return info->desc.layout;
  377. }
  378. TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
  379. mgb_assert(ret.ndim != 0);
  380. return ret;
  381. }
  382. DType ChannelImpl::get_dtype(Handle handle) {
  383. MGB_LOCK_GUARD(m_spin);
  384. mgb_assert(check_available(), "Channel already closed");
  385. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  386. "invalid handle: %p", handle);
  387. auto info = reinterpret_cast<TensorInfo*>(handle);
  388. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
  389. auto ret = info->desc.layout.dtype;
  390. mgb_assert(ret.valid());
  391. return ret;
  392. }
  393. CompNode ChannelImpl::get_device(Handle handle) {
  394. MGB_LOCK_GUARD(m_spin);
  395. mgb_assert(check_available(), "Channel already closed");
  396. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  397. "invalid handle: %p", handle);
  398. auto info = reinterpret_cast<TensorInfo*>(handle);
  399. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
  400. auto ret = info->desc.comp_node;
  401. mgb_assert(ret.valid());
  402. return ret;
  403. }
  404. DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
  405. MGB_LOCK_GUARD(m_spin);
  406. mgb_assert(check_available(), "Channel already closed");
  407. mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
  408. "invalid handle: %p", handle);
  409. auto info = reinterpret_cast<TensorInfo*>(handle);
  410. return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
  411. }
  412. void ChannelImpl::sync() {
  413. MGB_LOCK_GUARD(m_spin);
  414. mgb_assert(check_available(), "Channel already closed");
  415. sync_impl();
  416. }
  417. void ChannelImpl::sync_impl() {
  418. m_buffer.flush();
  419. m_worker.wait_all_task_finish();
  420. MGB_LOCK_GUARD(m_mutex);
  421. check_worker_exc_unsafe();
  422. }
  423. void ChannelImpl::close() {
  424. MGB_LOCK_GUARD(m_spin);
  425. if (!check_available()) {
  426. return;
  427. }
  428. std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
  429. for (auto* handle: valid_handles) {
  430. del_impl(handle);
  431. }
  432. mgb_assert(m_valid_handle.empty());
  433. mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
  434. sync_impl();
  435. m_closed = true;
  436. }
  437. size_t ChannelImpl::get_option(std::string name) {
  438. MGB_LOCK_GUARD(m_spin);
  439. mgb_assert(check_available(), "Channel already closed");
  440. auto& state = get_channel_state();
  441. return state.options.get_option(name);
  442. }
  443. void ChannelImpl::set_option(std::string name, size_t value) {
  444. MGB_LOCK_GUARD(m_spin);
  445. mgb_assert(check_available(), "Channel already closed");
  446. auto& state = get_channel_state();
  447. state.options.set_option(name, value);
  448. m_buffer.enqueue(SetOption{name, value});
  449. }
  450. TensorInfo* ChannelImpl::alloc() {
  451. auto& state = get_channel_state();
  452. auto info = [this]{
  453. MGB_LOCK_GUARD(m_mutex);
  454. return m_pool.alloc();
  455. }();
  456. info->id = Profiler::next_id();
  457. if (Profiler::is_profiling()) {
  458. size_t tensor_id = state.stack_manager.current()->next_id("tensor");
  459. info->name = state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
  460. }
  461. return info;
  462. }
  463. void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
  464. m_valid_handle.insert(reinterpret_cast<Handle>(info));
  465. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  466. info->status = TensorInfo::Allocated;
  467. info->desc = std::move(desc);
  468. info->mem_desc.layout = info->desc.layout;
  469. info->mem_desc.cn = info->desc.comp_node;
  470. info->mem_desc.offset = 0;
  471. }
  472. void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
  473. if (!ptr->producer) {
  474. if (user) {
  475. mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", ptr);
  476. }
  477. return;
  478. }
  479. if (ptr->evict_type != EvictType::NONE) {
  480. return;
  481. }
  482. ptr->evict_type = EvictType::DROP;
  483. ptr->status = TensorInfo::Dropped;
  484. release_tensor(ptr);
  485. }
  486. void ChannelImpl::free(TensorInfo* ptr) {
  487. auto& state = get_worker_state();
  488. if (state.options.enable_dtr_auto_drop) {
  489. // Evicting a tensor, rather than freeing it, can avoid pinning
  490. // potentially exploding amounts of memory and allow us to save
  491. // more memory.
  492. ptr->allow_delete = true;
  493. if (!ptr->ref_cnt) {
  494. recursive_free(ptr);
  495. } else {
  496. do_drop(ptr);
  497. }
  498. } else {
  499. real_free(ptr);
  500. }
  501. }
  502. void ChannelImpl::recursive_free(TensorInfo* ptr) {
  503. MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
  504. SmallVector<TensorInfo*> inps;
  505. if (ptr->producer) {
  506. for (auto i : ptr->producer->inputs) {
  507. if (i && --i->ref_cnt == 0) {
  508. inps.push_back(i);
  509. }
  510. }
  511. }
  512. real_free(ptr);
  513. for (auto i : inps) {
  514. if (i->allow_delete) {
  515. recursive_free(i);
  516. }
  517. }
  518. MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
  519. }
  520. void ChannelImpl::real_free(TensorInfo* ptr) {
  521. auto& state = get_worker_state();
  522. if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  523. m_dtr.erase_candidate(ptr);
  524. }
  525. detach_users(ptr);
  526. ptr->detach_producer();
  527. bool has_value = ptr->ptr != nullptr;
  528. if (has_value) {
  529. MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
  530. }
  531. MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
  532. ptr->status = TensorInfo::Deleted;
  533. MGB_LOCK_GUARD(m_mutex);
  534. m_pool.free(ptr);
  535. }
  536. ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}
  537. ChannelImpl::~ChannelImpl() {
  538. close();
  539. }
  540. void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
  541. auto& state = get_worker_state();
  542. MGB_LOCK_GUARD(m_mutex);
  543. m_dtr.update_used_time(dest);
  544. MGB_RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
  545. // update tensor desc for static infer
  546. dest->desc.layout = ptr->layout();
  547. dest->desc.comp_node = ptr->comp_node();
  548. dest->memory = ptr->blob()->size();
  549. dest->ptr = std::move(ptr);
  550. dest->evict_type = EvictType::NONE;
  551. dest->status = TensorInfo::Produced;
  552. if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  553. m_dtr.insert_candidate(dest);
  554. }
  555. notify_tensor_unsafe(dest);
  556. }
  557. void ChannelImpl::release_tensor(TensorInfo* dest) {
  558. MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
  559. MGB_LOCK_GUARD(m_mutex);
  560. dest->ptr.reset();
  561. }
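// Recover an evicted tensor: for DROP, replay the producing op through the
// apply stack; for SWAP, rebuild the device tensor from the saved host value.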
  562. void ChannelImpl::regenerate(TensorInfo* dest) {
  563. if (dest->evict_type == EvictType::DROP) {
  564. auto &&path = dest->producer;
  565. m_apply_stack.push({ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest, "dtr"});
  566. if (!m_applying) flush_apply_stack();
  567. } else if (dest->evict_type == EvictType::SWAP) {
  568. MGB_RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandKind::ReGen);
  569. produce_tensor(dest, Tensor::make(dest->h_value));
  570. MGB_RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandKind::ReGen);
  571. }
  572. }
  573. void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
  574. using namespace ranges;
  575. using namespace ranges::views;
  576. auto& state = get_worker_state();
  577. bool profiling_device = Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
  578. uint64_t apply_id = cmd.id;
  579. struct TensorWithDesc {
  580. TensorPtr tensor;
  581. MemoryDesc desc;
  582. };
  583. SmallVector<TensorWithDesc> inputs;
  584. inputs.reserve(cmd.inputs.size());
  585. // refcnt == 1, owners: [TensorInfo::ptr]
  586. for (auto i : cmd.inputs) {
  587. mgb_assert(i->ptr, "Invalid input tensor ptr!");
  588. // refcnt ++, owners: [i->ptr, tensor_inputs]
  589. // tensor_inputs.push_back(i->ptr);
  590. inputs.push_back({i->ptr, i->mem_desc});
  591. }
  592. if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
  593. auto_evict(0);
  594. }
  595. auto apply_on_physical_tensor = [&](auto&& self, const OpDef& def, SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
  596. auto apply_functor = [&](std::shared_ptr<OpDef> op, SmallVector<TensorWithDesc> inputs, size_t nr_outputs) -> SmallVector<TensorWithDesc> {
  597. auto opname = op->trait()->make_name(*op);
  598. imperative_log_profile_begin(opname.c_str());
  599. auto outputs = self(self, *op, inputs);
  600. imperative_log_profile_end(opname.c_str());
  601. return outputs;
  602. };
  603. auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
  604. return {value, MemoryDesc{value->layout(), 0, value->comp_node(), StorageIdentifier::make()}};
  605. };
  606. if (def.trait()->make_forward_graph) {
  607. // apply recursively
  608. SmallVector<LogicalTensorDesc> input_descs;
  609. for (auto&& input: inputs) {
  610. input_descs.push_back({{{}, input.tensor->dtype()}, input.tensor->comp_node()});
  611. }
  612. auto forward_graph = OpDef::make_forward_graph(def, input_descs);
  613. auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
  614. return outputs;
  615. }
  616. SmallVector<TensorPtr> input_tensors;
  617. SmallVector<MemoryDesc> input_descs;
  618. for (auto&& input: inputs) {
  619. input_tensors.push_back(input.tensor);
  620. input_descs.push_back(input.desc);
  621. }
  622. auto [output_descs, output_tensors, workspaces] = init_output_and_workspace(def, input_tensors, input_descs);
  623. if (!output_descs.empty()) {
  624. OpDef::execute(def, input_tensors, output_tensors, workspaces);
  625. } else {
  626. output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
  627. for (auto&& output_tensor: output_tensors) {
  628. output_descs.push_back(MemoryDesc{output_tensor->layout(), 0, output_tensor->comp_node(), StorageIdentifier::make()});
  629. }
  630. }
  631. SmallVector<TensorWithDesc> outputs;
  632. for (auto&& [output_tensor, output_desc]: ranges::zip_view(output_tensors, output_descs)) {
  633. outputs.push_back({output_tensor, output_desc});
  634. }
  635. return outputs;
  636. };
  637. MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
  638. // Begin profiling operator
  639. SmallVector<std::pair<CompNode, uint64_t>> kernels;
  640. if (profiling_device) {
  641. // Collecting devices
  642. SmallVector<CompNode> devices;
  643. for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
  644. if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
  645. devices.push_back(i->desc.comp_node);
  646. kernels.push_back({i->desc.comp_node, Profiler::next_id()});
  647. }
  648. }
  649. }
  650. for (auto* input: cmd.inputs) {
  651. auto input_id = input->id;
  652. MGB_RECORD_EVENT(OpInputEvent, input_id);
  653. MGB_RECORD_EVENT(TensorUsageEvent, input_id);
  654. MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
  655. }
  656. // Fused by command buffer. @see: CommandBuffer::fuse_del
  657. // Now if dest is inplaceable, its refcnt would be decreased to 1 and owned by tensor_inputs after Del.
  658. // Note for exprs like 'y = x op x', inplace is unsupported yet but Del would also be fused.
  659. for (auto* del : cmd.dels) {
  660. // refcnt --, owners: [tensor_inputs]
  661. // if it drops to 1, it will be detected at @see: proxy_graph_detail::apply_on_physical_tensor
  662. uint64_t del_id = del->id;
  663. MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
  664. free(del);
  665. MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
  666. }
  667. // Before wait
  668. //TODO: split operator wait and execute so that OpWait could be correctly recorded.
  669. // Before execute
  670. for (auto&& [device, kernel_id]: kernels) {
  671. MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
  672. MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
  673. }
  674. // Apply op
  675. // Here std::move is REQUIRED for removing duplicated references.
  676. auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
  677. // After execute
  678. for (auto&& [device, kernel_id]: kernels) {
  679. MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
  680. MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
  681. }
  682. // End profiling operator
  683. mgb_assert(outputs.size() == cmd.outputs.size());
  684. for (size_t i = 0; i < outputs.size(); ++i) {
  685. auto output = cmd.outputs[i];
  686. if (output == nullptr) {
  687. MGB_RECORD_EVENT(OpOutputEvent, 0);
  688. MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
  689. } else if (output->ptr != nullptr) {
  690. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  691. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  692. } else {
  693. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  694. produce_tensor(output, outputs[i].tensor);
  695. output->mem_desc = outputs[i].desc;
  696. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  697. sample_on_device(output->desc.comp_node, false);
  698. }
  699. }
  700. if (state.options.enable_dtr_auto_drop) {
  701. double estimate_compute_time = 0;
  702. for (auto i : cmd.inputs) {
  703. estimate_compute_time += i->memory;
  704. }
  705. for (auto i : outputs) {
  706. estimate_compute_time += i.tensor->blob()->size();
  707. }
  708. m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
  709. for (auto i : cmd.outputs) {
  710. if (i != nullptr) {
  711. i->compute_time = estimate_compute_time;
  712. }
  713. }
  714. m_dtr.unpin(cmd.inputs);
  715. }
  716. MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
  717. // End profiling operator
  718. }
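// Drain m_apply_stack iteratively. If an input of the command on top has been
// evicted, regenerate() pushes the op that recomputes it onto the stack first,
// so dependencies are restored without recursive calls into do_apply_op.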
  719. void ChannelImpl::flush_apply_stack() {
  720. m_applying = true;
  721. auto& state = get_worker_state();
  722. while (!m_apply_stack.empty()) {
  723. auto& [cmd, idx, recomp, reason] = m_apply_stack.top(); // cmd.inputs[0~idx-1] is in memory
  724. if (idx == 0) {
  725. if (state.options.enable_dtr_auto_drop) {
  726. m_dtr.pin(cmd.inputs);
  727. }
  728. if (recomp) {
  729. MGB_RECORD_EVENT(TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
  730. }
  731. }
  732. bool regen = false;
  733. for (size_t i = idx; i < cmd.inputs.size(); i ++) {
  734. auto&& p = cmd.inputs[i];
  735. if (state.options.enable_dtr_auto_drop) {
  736. m_dtr.update_used_time(p);
  737. }
  738. if (!p->ptr && p->evict_type != EvictType::NONE) {
  739. idx = i + 1;
  740. regenerate(p); // add ApplyOp to the stack
  741. regen = true;
  742. break;
  743. }
  744. }
  745. if (regen) continue;
  746. // the required input tensors are already in memory
  747. auto [cmd_backup, recomp_backup, reason_backup] = std::make_tuple(cmd, recomp, reason);
  748. m_apply_stack.pop();
  749. do_apply_op(cmd_backup, reason_backup);
  750. if (recomp_backup) {
  751. MGB_RECORD_EVENT(TensorCommandFinishEvent, recomp_backup->id, TensorCommandKind::ReGen);
  752. for (auto o : cmd_backup.outputs) {
  753. if (o) {
  754. m_dtr.update_dsu_after_recompute(o);
  755. }
  756. }
  757. }
  758. }
  759. m_applying = false;
  760. }
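// DTR eviction loop: keep dropping the best candidate while device memory usage
// stays above dtr_eviction_threshold, or until force_num tensors have actually
// been released; returns whether any memory was reclaimed.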
  761. bool ChannelImpl::auto_evict(size_t force_num) {
  762. auto& state = get_worker_state();
  763. if (!m_dtr.comp_node.valid()) {
  764. return false;
  765. }
  766. size_t current_memory = m_dtr.comp_node.get_used_memory();
  767. bool flag = false;
  768. while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) {
  769. MGB_RECORD_EVENT(AutoEvictEvent);
  770. sample_on_device(m_dtr.comp_node, false);
  771. auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling && !force_num);
  772. if (!best) {
  773. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  774. break;
  775. }
  776. if (best->ptr.unique() && best->ptr->blob().unique()) {
  777. current_memory -= best->memory;
  778. if (force_num > 0) {
  779. force_num --;
  780. }
  781. flag = true;
  782. }
  783. do_drop(best);
  784. if (best->evict_type == EvictType::DROP) {
  785. m_dtr.update_dsu_after_evict(best);
  786. }
  787. sample_on_device(m_dtr.comp_node, false);
  788. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  789. }
  790. return flag;
  791. }
  792. void ChannelImpl::detach_users(TensorInfo* dest) {
  793. SmallVector<TensorInfo::ComputePath*> users = dest->users;
  794. for (auto* user: users) {
  795. SmallVector<TensorInfo*> outputs = user->outputs;
  796. SmallVector<TensorInfo*> inputs = user->inputs;
  797. for (auto* output: outputs) {
  798. // When a `ComputePath` is detached from its input,
  799. // there is no need to keep it,
  800. // so we detach all outputs of this path
  801. // to decrease its `ref_cnt` to zero.
  802. if (output == nullptr) {
  803. continue;
  804. }
  805. regenerate(output);
  806. output->detach_producer();
  807. for (auto* input: inputs) {
  808. input->ref_cnt --;
  809. }
  810. }
  811. // now user is dead
  812. }
  813. mgb_assert(dest->users.empty(), "ComputePath leaking");
  814. }
  815. bool ChannelImpl::check_available() {
  816. return !m_closed;
  817. }
  818. TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
  819. m_buffer.flush();
  820. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  821. mgb_assert(!m_waitee, "duplicate waitee");
  822. m_waitee = info;
  823. m_waitee_id = Profiler::next_id();
  824. MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
  825. bool require_host = prop == TensorProp::HostValue;
  826. auto host_available = [&]{
  827. return info->ptr && info->ptr->value_fetched();
  828. };
  829. if (require_host && !host_available()) {
  830. // avoid deadlock
  831. lock.unlock();
  832. m_buffer.enqueue(GetValue{info});
  833. m_buffer.flush();
  834. lock.lock();
  835. }
  836. m_cv.wait(lock, [&]() {
  837. check_worker_exc_unsafe();
  838. return require_host ? host_available() : static_cast<bool>(info->ptr);
  839. });
  840. MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
  841. m_waitee = nullptr;
  842. return info->ptr;
  843. }
  844. void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
  845. if (info == m_waitee) {
  846. MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
  847. m_cv.notify_all();
  848. }
  849. }
  850. std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
  851. std::unordered_set<TensorInfo*> valid_tensors;
  852. for (auto* handle: m_valid_handle) {
  853. auto* info = reinterpret_cast<TensorInfo*>(handle);
  854. valid_tensors.insert(info);
  855. }
  856. return valid_tensors;
  857. }
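// Allocate storage for a blob, evicting DTR candidates on failure: retry the
// allocation after each successful eviction and, as a last resort, defragment
// the blob manager before a final attempt.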
  858. void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
  859. auto reserve_size = [&](size_t size) {
  860. if (!m_dtr.comp_node.valid()) {
  861. return false;
  862. }
  863. while (size > m_dtr.comp_node.get_max_block_size_available()) {
  864. bool evict_suc = auto_evict(1);
  865. if (!evict_suc) return false;
  866. }
  867. return true;
  868. };
  869. auto pre_level = set_log_level(LogLevel::NO_LOG);
  870. reserve_size(x->size());
  871. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  872. MGB_CATCH(MemAllocError&, {
  873. bool suc = false;
  874. while (!suc) {
  875. if (!auto_evict(1)) {
  876. break;
  877. }
  878. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  879. MGB_CATCH(MemAllocError&, { continue; });
  880. suc = true;
  881. }
  882. if (!suc) {
  883. set_log_level(pre_level);
  884. mgb_log_warn("reallocating all cuda memory to alleviate fragmentation, the performance may be affected");
  885. set_log_level(LogLevel::NO_LOG);
  886. imperative_log_profile_begin("defrag");
  887. BlobManager::inst()->defrag(x->comp_node());
  888. imperative_log_profile_end("defrag");
  889. BlobManager::inst()->alloc_direct(x, x->size());
  890. }
  891. });
  892. set_log_level(pre_level);
  893. }
  894. std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> ChannelImpl::init_output_and_workspace(
  895. const OpDef& def,
  896. SmallVector<TensorPtr> inputs,
  897. SmallVector<MemoryDesc> inputs_mem_desc) {
  898. auto [outputs_desc, workspaces_desc] = OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
  899. if (!outputs_desc.size()) {
  900. // failed to infer memplan
  901. return {{}, {}, {}};
  902. }
  903. // refine storage id to make it unique
  904. for (auto&& desc : outputs_desc) {
  905. if (desc.id->is_sys_alloc()) {
  906. // TODO: there may be some outputs sharing the same storage id
  907. desc.id->id = ++ m_storage_id;
  908. }
  909. }
  910. auto& state = get_worker_state();
  911. auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
  912. SmallVector<TensorPtr> tensors;
  913. for (size_t i = 0; i < desc.size(); i ++) {
  914. if (desc[i].id->is_sys_alloc()) {
  915. tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
  916. if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
  917. alloc_tensor_with_evict(tensors.back()->blob().get());
  918. }
  919. } else if (desc[i].id->is_from_other()) {
  920. for (size_t j = 0; j < inputs_mem_desc.size();j ++) {
  921. if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
  922. tensors.push_back(inputs[j]->sub(desc[i].offset, desc[i].layout));
  923. break;
  924. }
  925. }
  926. } else if (desc[i].id->is_device_ptr()) {
  927. tensors.push_back(desc[i].id->ptr);
  928. } else {
  929. mgb_assert(0, "not implemented");
  930. }
  931. }
  932. return tensors;
  933. };
  934. return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
  935. }
  936. void ChannelImpl::process_one_task(Command& icmd) {
  937. using namespace ranges;
  938. using namespace ranges::views;
  939. auto& state = get_worker_state();
  940. auto& options = state.options;
  941. //TODO: remove std::visit to support osx 10.12
  942. auto cmd_visitor = [&](const auto& cmd) {
  943. using T = std::decay_t<decltype(cmd)>;
  944. if constexpr (std::is_same_v<T, Put>) {
  945. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
  946. MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
  947. auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
  948. MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(cmd.value.comp_node()));
  949. produce_tensor(cmd.dest, std::move(value));
  950. MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
  951. sample_on_device(cmd.dest->desc.comp_node, false);
  952. } else if constexpr (std::is_same_v<T, ApplyOp>) {
  953. for (auto& i : cmd.inputs) {
  954. if (i->invalid) {
  955. MGB_LOCK_GUARD(m_mutex);
  956. for (auto& i : cmd.outputs) {
  957. i->invalid = true;
  958. }
  959. return;
  960. }
  961. }
  962. m_apply_stack.push({cmd, 0, nullptr, "cmd"});
  963. flush_apply_stack();
  964. for (size_t i = 0; i < cmd.outputs.size(); ++i) {
  965. auto output = cmd.outputs[i];
  966. if (output == nullptr) {
  967. continue;
  968. }
  969. if (state.options.enable_dtr_auto_drop) {
  970. output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
  971. }
  972. }
  973. if (state.options.enable_drop && state.options.record_computing_path) {
  974. auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
  975. auto& input = std::get<0>(tuple2);
  976. auto& output = std::get<1>(tuple2);
  977. if (!input->ptr || !output->ptr) {
  978. return false;
  979. }
  980. return input->ptr->blob()->storage() == output->ptr->blob()->storage();
  981. };
  982. // FIXME: do not use opname as identifier
  983. auto get_name = [](const OpDef& opdef) {
  984. if (auto attr = opdef.try_cast_final<OprAttr>()) {
  985. return attr->type.c_str();
  986. }
  987. return opdef.dyn_typeinfo()->name;
  988. };
  989. auto is_cross_cn = [comp_node=m_dtr.comp_node](TensorInfo* info){
  990. return info->desc.comp_node != comp_node;
  991. };
  992. bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
  993. bool inplace = any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);
  994. if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
  995. TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs);
  996. size_t detach_cnt = 0;
  997. if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 5) {
  998. cmd.outputs[0]->detach_producer(); // detach running_mean
  999. cmd.outputs[1]->detach_producer(); // detach running_var
  1000. for (auto input : cmd.inputs) {
  1001. input->ref_cnt -= 2;
  1002. }
  1003. }
  1004. for (auto output : cmd.outputs) {
  1005. if (output->producer && !output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  1006. output->detach_producer();
  1007. detach_cnt ++;
  1008. }
  1009. }
  1010. for (auto input : cmd.inputs) {
  1011. input->ref_cnt -= detach_cnt;
  1012. }
  1013. }
  1014. }
  1015. } else if constexpr (std::is_same_v<T, Del>) {
  1016. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
  1017. CompNode device = cmd.dest->desc.comp_node;
  1018. uint64_t tensor_id = cmd.dest->id;
  1019. free(cmd.dest);
  1020. MGB_RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
  1021. sample_on_device(device, false);
  1022. } else if constexpr (std::is_same_v<T, GetValue>) {
  1023. if (cmd.dest->invalid) return;
  1024. imperative_log_profile_begin("GetValue");
  1025. if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
  1026. regenerate(cmd.dest);
  1027. }
  1028. cmd.dest->ptr->fetch_value();
  1029. MGB_LOCK_GUARD(m_mutex);
  1030. notify_tensor_unsafe(cmd.dest);
  1031. imperative_log_profile_end("GetValue");
  1032. } else if constexpr (std::is_same_v<T, SwapIn>) {
  1033. if (cmd.dest->invalid) return;
  1034. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::SwapIn);
  1035. produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
  1036. MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::SwapIn);
  1037. sample_on_device(cmd.dest->desc.comp_node, false);
  1038. } else if constexpr (std::is_same_v<T, SwapOut>) {
  1039. if (cmd.dest->invalid) return;
  1040. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::SwapOut);
  1041. cmd.dest->h_value = cmd.dest->ptr->get_value();
  1042. if (cmd.dest->evict_type == EvictType::NONE) {
  1043. cmd.dest->evict_type = EvictType::SWAP;
  1044. cmd.dest->status = TensorInfo::Swapped;
  1045. release_tensor(cmd.dest);
  1046. }
  1047. MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::SwapOut);
  1048. sample_on_device(cmd.dest->desc.comp_node, false);
  1049. } else if constexpr (std::is_same_v<T, Drop>) {
  1050. if (cmd.dest->invalid) return;
  1051. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
  1052. do_drop(cmd.dest, true);
  1053. MGB_RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
  1054. } else if constexpr (std::is_same_v<T, SetOption>) {
  1055. options.set_option(cmd.key, cmd.value);
  1056. } else if constexpr (std::is_same_v<T, StartProfile>) {
  1057. MGB_RECORD_EVENT(StartProfileEvent);
  1058. CompNode::sync_all();
  1059. for (auto* info: cmd.capture_tensors) {
  1060. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  1061. if (info->status == TensorInfo::Produced) {
  1062. // TODO: handle swap/drop
  1063. MGB_RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
  1064. }
  1065. }
  1066. CompNode::foreach([&](CompNode device){
  1067. sample_on_device(device, true);
  1068. MGB_RECORD_EVENT_IF((Profiler::get_option("profile_device", 0)), RecordDeviceEvent, Timer::record_device(device));
  1069. });
  1070. MGB_RECORD_EVENT(StartProfileFinishEvent);
  1071. } else if constexpr (std::is_same_v<T, StopProfile>) {
  1072. MGB_RECORD_EVENT(StopProfileEvent);
  1073. for (auto* info: cmd.escape_tensors) {
  1074. bool has_value = info->status == TensorInfo::Produced;
  1075. if (has_value) {
  1076. MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
  1077. }
  1078. MGB_RECORD_EVENT(TensorEraseEvent, info->id);
  1079. }
  1080. CompNode::foreach([&](CompNode device){
  1081. sample_on_device(device, true);
  1082. });
  1083. MGB_RECORD_EVENT(StopProfileFinishEvent);
  1084. } else if constexpr (std::is_same_v<T, PushScope>) {
  1085. MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
  1086. } else if constexpr (std::is_same_v<T, PopScope>) {
  1087. MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
  1088. } else {
  1089. static_assert(!std::is_same_v<T, T>);
  1090. }
  1091. };
  1092. std::visit([&](const auto& cmd){
  1093. using T = std::decay_t<decltype(cmd)>;
  1094. if (!options.catch_worker_execption) {
  1095. cmd_visitor(cmd);
  1096. return;
  1097. }
  1098. try {
  1099. cmd_visitor(cmd);
  1100. } catch (...) {
  1101. MGB_LOCK_GUARD(m_mutex);
  1102. if constexpr (std::is_same_v<T, ApplyOp>) {
  1103. for (auto oup : cmd.outputs) {
  1104. oup->invalid = true;
  1105. }
  1106. } else if constexpr (std::is_same_v<T, Put>) {
  1107. cmd.dest->invalid = true;
  1108. }
  1109. m_worker_exc = std::current_exception();
  1110. MGB_RECORD_EVENT(WorkerExceptionEvent);
  1111. if (m_waitee) {
  1112. notify_tensor_unsafe(m_waitee);
  1113. }
  1114. }
  1115. }, icmd.data);
  1116. }
  1117. void ChannelImpl::check_worker_exc_unsafe() {
  1118. if (m_worker_exc) {
  1119. // to allow reusing interpreter_for_py after some exception tests
  1120. m_waitee = nullptr;
  1121. std::exception_ptr exc;
  1122. std::swap(exc, m_worker_exc);
  1123. try {
  1124. std::rethrow_exception(exc);
  1125. } catch (...) {
  1126. throw AsyncError();
  1127. }
  1128. }
  1129. }
  1130. void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
  1131. auto& state = m_owner->get_channel_state();
  1132. if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
  1133. return;
  1134. }
  1135. // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
  1136. m_commands.push_back({Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
  1137. auto flush_pos = flush_pos_for(m_commands.back());
  1138. flush(flush_pos);
  1139. }
  1140. void ChannelImpl::CommandBuffer::flush() {
  1141. flush(m_commands.end());
  1142. }
  1143. void ChannelImpl::CommandBuffer::flush(Handle pos) {
  1144. for (auto iter = m_commands.begin(); iter != pos; ++iter) {
  1145. if (Profiler::is_profiling()) {
  1146. mgb_log_debug("%s Flushed", to_string(*iter).c_str());
  1147. }
  1148. m_owner->m_worker.add_task(std::move(*iter));
  1149. }
  1150. m_commands.erase(m_commands.begin(), pos);
  1151. }
  1152. auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
  1153. auto& state = m_owner->get_channel_state();
  1154. return std::visit([this, &state](const auto& cmd) {
  1155. using T = std::decay_t<decltype(cmd)>;
  1156. if constexpr (std::is_same_v<T, ApplyOp>) {
  1157. auto* op_type = cmd.op->dyn_typeinfo();
  1158. if (op_type == RemoteRecv::typeinfo() ||
  1159. op_type == RemoteSend::typeinfo() ||
  1160. op_type == CollectiveComm::typeinfo() ||
  1161. op_type == opr::InputCallback::typeinfo() ||
  1162. op_type == opr::OutputCallback::typeinfo()) {
  1163. return m_commands.end();
  1164. }
  1165. } else if constexpr (std::is_same_v<T, GetValue>) {
  1166. return m_commands.end();
  1167. }
  1168. size_t buffer_length = state.options.buffer_length;
  1169. if (m_commands.size() > buffer_length) {
  1170. return m_commands.begin() + (m_commands.size() - buffer_length);
  1171. }
  1172. return m_commands.begin();
  1173. }, cmd.data);
  1174. }
  1175. /**
  1176. * 1. Find ApplyOp(dest) in buffered commands
  1177. * 2. Check if there are other usages between ApplyOp and Del, return false if not
  1178. * 3. Fuse Del into ApplyOp, return true
  1179. */
  1180. bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
  1181. auto* dest = cmd.dest;
  1182. // TODO: eliminate Puts
  1183. auto begin = m_commands.begin(), end = m_commands.end();
  1184. auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
  1185. if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
  1186. return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
  1187. }
  1188. return false;
  1189. });
  1190. if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
  1191. return false;
  1192. }
  1193. // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
  1194. std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
  1195. return true;
  1196. }
  1197. auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
  1198. -> Handle {
  1199. auto found = range[1];
  1200. for (auto iter = range[0]; iter != range[1]; ++iter) {
  1201. std::visit([&](const auto& cmd) {
  1202. using T = std::decay_t<decltype(cmd)>;
  1203. if constexpr (std::is_same_v<T, ApplyOp>) {
  1204. if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
  1205. dest) > 0) {
  1206. found = iter;
  1207. }
  1208. } else if constexpr (std::is_same_v<T, GetValue>) {
  1209. if (cmd.dest == dest) {
  1210. found = iter;
  1211. }
  1212. } else if constexpr (std::is_same_v<T, SwapIn> ||
  1213. std::is_same_v<T, SwapOut> ||
  1214. std::is_same_v<T, Drop>) {
  1215. //TODO: ignore swap-like commands, just remove them from buffer
  1216. if (cmd.dest == dest) {
  1217. found = iter;
  1218. }
  1219. }
  1220. }, iter->data);
  1221. };
  1222. return found;
  1223. }
  1224. auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
  1225. -> Handle {
  1226. return std::find_if(range[0], range[1], [dest](auto& cmd) {
  1227. return std::visit([dest](const auto& cmd){
  1228. using T = std::decay_t<decltype(cmd)>;
  1229. if constexpr (std::is_same_v<T, ApplyOp>) {
  1230. return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
  1231. } else if constexpr (std::is_same_v<T, Put>) {
  1232. return cmd.dest == dest;
  1233. }
  1234. return false;
  1235. }, cmd.data);
  1236. });
  1237. }
  1238. void ChannelImpl::start_profile() {
  1239. MGB_LOCK_GUARD(m_spin);
  1240. mgb_assert(check_available(), "Channel already closed");
  1241. auto capture_tensors = collect_valid_tensors();
  1242. if (capture_tensors.size() > 0) {
  1243. m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
  1244. }
  1245. }
  1246. void ChannelImpl::stop_profile() {
  1247. MGB_LOCK_GUARD(m_spin);
  1248. mgb_assert(check_available(), "Channel already closed");
  1249. m_buffer.flush();
  1250. auto escape_tensors = collect_valid_tensors();
  1251. if (escape_tensors.size() > 0) {
  1252. m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
  1253. }
  1254. }
  1255. void ChannelImpl::push_scope(std::string name) {
  1256. MGB_LOCK_GUARD(m_spin);
  1257. mgb_assert(check_available(), "Channel already closed");
  1258. auto& state = get_channel_state();
  1259. state.stack_manager.enter(name);
  1260. MGB_RECORD_EVENT(ScopeEvent, name);
  1261. m_buffer.enqueue(PushScope{name});
  1262. }
  1263. void ChannelImpl::pop_scope(std::string name) {
  1264. MGB_LOCK_GUARD(m_spin);
  1265. mgb_assert(check_available(), "Channel already closed");
  1266. auto& state = get_channel_state();
  1267. state.stack_manager.exit(name);
  1268. MGB_RECORD_EVENT(ScopeFinishEvent, name);
  1269. m_buffer.enqueue(PopScope{name});
  1270. }
  1271. void ChannelImpl::assert_in_channel() {
  1272. mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
  1273. }
  1274. void ChannelImpl::assert_in_worker() {
  1275. mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
  1276. }
  1277. void ChannelImpl::sample_on_device(CompNode device, bool force) {
  1278. if (!force) {
  1279. thread_local int last_sample_id = 0;
  1280. int sample_rate = Profiler::is_profiling() ? Profiler::get_option("sample_rate", 0) : 0;
  1281. if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
  1282. return;
  1283. }
  1284. }
  1285. MGB_RECORD_EVENT(SampleDeviceEvent, device);
  1286. auto [total, free] = device.get_mem_status_bytes();
  1287. MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
  1288. }
  1289. void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
  1290. for (auto i : vec) {
  1291. i->pin();
  1292. }
  1293. }
  1294. void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) {
  1295. for (auto i : vec) {
  1296. i->unpin();
  1297. }
  1298. }
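// DsuNode forms a disjoint-set union over dropped tensors; each root
// accumulates the recompute time of its component so estimate_neighbor_cost()
// can charge the cost of regenerating a whole chain of evicted neighbors.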
  1299. void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
  1300. auto&& dsu_fa = find_father(ptr->dsu_ptr);
  1301. dsu_fa->t -= ptr->compute_time;
  1302. ptr->dsu_ptr->parent.reset();
  1303. ptr->dsu_ptr->t = ptr->compute_time;
  1304. }
  1305. void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
  1306. for (auto i : ptr->producer->inputs) {
  1307. if (i->evict_type == EvictType::DROP) {
  1308. merge(i->dsu_ptr, ptr->dsu_ptr);
  1309. }
  1310. }
  1311. for (auto i : ptr->producer->outputs) {
  1312. if (i && i->evict_type == EvictType::DROP) {
  1313. merge(ptr->dsu_ptr, i->dsu_ptr);
  1314. }
  1315. }
  1316. }
  1317. double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
  1318. double cost = 0;
  1319. for (auto i : ptr->producer->inputs) {
  1320. if (i->evict_type == EvictType::DROP) {
  1321. double t = find_father(i->dsu_ptr)->t;
  1322. if (t < i->compute_time) {
  1323. t = i->compute_time;
  1324. }
  1325. cost += t;
  1326. }
  1327. }
  1328. for (auto i : ptr->producer->outputs) {
  1329. if (i && i->evict_type == EvictType::DROP) {
  1330. double t = find_father(i->dsu_ptr)->t;
  1331. if (t < i->compute_time) {
  1332. t = i->compute_time;
  1333. }
  1334. cost += t;
  1335. }
  1336. }
  1337. return cost;
  1338. }
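// Choose the eviction victim with the smallest heuristic score, computed by
// eval_func from the neighbor recompute cost, the free memory adjacent to the
// tensor's storage and the current estimated timestamp. With sqrt sampling,
// only about sqrt(|candidates|) entries are inspected per call.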
  1339. TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(bool enable_dtr_sqrt_sampling=false) {
  1340. double min_msps = -1;
  1341. TensorInfo* best = nullptr;
  1342. size_t sz = 1;
  1343. if (enable_dtr_sqrt_sampling) {
  1344. while (sz * sz <= candidates.size()) sz ++;
  1345. } else {
  1346. sz = candidates.size();
  1347. }
  1348. for (auto i : candidates) {
  1349. if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
  1350. double neighbor_cost = estimate_neighbor_cost(i);
  1351. size_t begin_ptr = reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
  1352. auto side_info = i->ptr->comp_node().get_free_left_and_right(begin_ptr, begin_ptr + i->ptr->blob()->size());
  1353. double free_mem = side_info.first + side_info.second;
  1354. double msps = i->eval_func(neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
  1355. if (min_msps < 0 || msps < min_msps) {
  1356. min_msps = msps;
  1357. best = i;
  1358. }
  1359. }
  1360. if (--sz == 0) break;
  1361. }
  1362. return best;
  1363. }
  1364. void ChannelImpl::DynamicSublinear::merge(std::shared_ptr<DsuNode> &x, std::shared_ptr<DsuNode> &y) {
  1365. auto&& f_x = find_father(x);
  1366. auto&& f_y = find_father(y);
  1367. if (f_x.get() == f_y.get()) {
  1368. return;
  1369. }
  1370. f_y->t += f_x->t;
  1371. f_x->parent = f_y;
  1372. }
  1373. std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(std::shared_ptr<DsuNode>& x) {
  1374. if (x->is_root()) {
  1375. return x;
  1376. } else {
  1377. auto&& fa = find_father(x->parent);
  1378. return x->parent = fa;
  1379. }
  1380. }
  1381. void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
  1382. candidates.insert(ptr);
  1383. if (!comp_node.valid()) {
  1384. comp_node = ptr->ptr->comp_node();
  1385. }
  1386. }
  1387. void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
  1388. candidates.erase(ptr);
  1389. }
  1390. void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
  1391. ptr->last_used_time = estimate_timestamp;
  1392. }

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no need to distinguish between CPU and GPU builds. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.