You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

comp_node.h 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799
  1. /**
  2. * \file src/core/include/megbrain/comp_node.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megbrain/utils/hash.h"
  13. #include "megbrain/utils/metahelper.h"
  14. #include "megbrain/utils/thin/function.h"
  15. #include "megbrain/utils/thin/hash_table.h"
  16. #include "megbrain/utils/thread.h"
  17. #include "megdnn/thin/function.h"
  18. #include <cstddef>
  19. #include <memory>
  20. #include <string>
  21. namespace mgb {
  22. // forward declaration; defined in comp_node_env.h
  23. class CompNodeEnv;
  24. namespace cg {
  25. class ComputingGraph;
  26. }
  27. class CompNodeSeqRecorder;
  28. /*!
  29. * \brief identifier for a memory node
  30. *
  31. * MemNode is comparable. CompNodes with the same MemNode can access memory of
  32. * each other directly
  33. */
  34. class MemNode {
  35. const void* m_id = nullptr;
  36. public:
  37. MemNode() = default;
  38. explicit MemNode(const void* id) : m_id{id} {}
  39. bool operator==(const MemNode& rhs) const { return m_id == rhs.m_id; }
  40. bool operator!=(const MemNode& rhs) const { return m_id != rhs.m_id; }
  41. operator bool() const { return m_id != nullptr; }
  42. };
  43. /*!
  44. * \brief abstraction of a streaming computing resource on localhost (a
  45. * thread on CPU, a cuda stream, etc.)
  46. *
  47. * Note that most of the operations are asynchronous with respect to the caller
  48. * thread
  49. */
  50. class CompNode {
  51. public:
  52. //! computing device type
  53. enum class DeviceType {
  54. //! for "xpu" comp node that would mapped to available cn on
  55. //! current system
  56. UNSPEC = 0,
  57. CUDA = 1,
  58. CPU = 2,
  59. CAMBRICON = 3,
  60. ROCM = 8,
  61. ATLAS = 9,
  62. MULTITHREAD = 11,
  63. MAX_DEVICE_ID,
  64. };
  65. static constexpr size_t NR_DEVICE_TYPE =
  66. static_cast<size_t>(DeviceType::MAX_DEVICE_ID);
  67. /*!
  68. * \brief an identifier to specify a computing node
  69. *
  70. * Note: logical locator is directly parsed from a string identifier
  71. * given by user; it should be translated to physical locator by calling
  72. * to_physical() before actual use.
  73. *
  74. * Unless explicitly specified otherwise, all locators are physical
  75. * locators.
  76. */
  77. struct Locator {
  78. /*!
  79. * \brief special device number for the "cpu default" comp node,
  80. * which dispatches all tasks in the caller thread
  81. */
  82. static constexpr int DEVICE_CPU_DEFAULT = -1024;
  83. /*!
  84. * \brief special device number for the "multithread_default"
  85. * comp node, which dispatches all tasks to thread pool and the
  86. * caller thread is the main thread of thread pool
  87. */
  88. static constexpr int DEVICE_MULTITHREAD_DEFAULT = -1025;
  89. DeviceType type = DeviceType::UNSPEC;
  90. /*!
  91. * corresponding to a physical computing device; memories between
  92. * different devices are not shared.
  93. *
  94. * device == -1 means logical default device (maps to 0 by default,
  95. * and can be changed by set_device_map)
  96. *
  97. */
  98. int device = -1;
  99. //! multiple streams can execute on one computing device and share
  100. //! memory, when compnode type is multithread the field also stand
  101. //! for nr_threads
  102. union {
  103. int stream = 0;
  104. int nr_threads;
  105. };
  106. /*!
  107. * \brief parse a string identifier
  108. *
  109. * currently supported ID format: (gpu|cpu)<n>[:m] where n is the
  110. * device number, possibly with m as the stream id.
  111. */
  112. MGE_WIN_DECLSPEC_FUC static Locator parse(const std::string& id);
  113. /*!
  114. * \brief set mapping between device numbers of a device type
  115. */
  116. MGE_WIN_DECLSPEC_FUC static void set_device_map(
  117. DeviceType type, int from, int to);
  118. /*!
  119. * \brief set the actual device type to be used for
  120. * DeviceType::UNSPEC
  121. */
  122. MGE_WIN_DECLSPEC_FUC static void set_unspec_device_type(DeviceType type);
  123. /*!
  124. * \brief get corresponding physical Locator
  125. *
  126. * DeviceType::UNSPEC would be resolved, and device map would be
  127. * applied on device number
  128. */
  129. MGE_WIN_DECLSPEC_FUC Locator to_physical() const;
  130. /*!
  131. * \brief get string description of this locator that can be parsed
  132. * again
  133. */
  134. MGE_WIN_DECLSPEC_FUC std::string to_string() const;
  135. bool operator==(const Locator& rhs) const {
  136. return type == rhs.type && device == rhs.device && stream == rhs.stream;
  137. }
  138. };
  139. struct LocatorPairHashKey {
  140. Locator locator, locator_logical;
  141. bool operator==(const LocatorPairHashKey& rhs) const {
  142. return locator == rhs.locator && locator_logical == rhs.locator_logical;
  143. }
  144. struct Hash {
  145. size_t operator()(const LocatorPairHashKey& k) const {
  146. return hash_pair_combine(
  147. mgb::hash(k.locator), mgb::hash(k.locator_logical));
  148. }
  149. };
  150. };
  151. //! predefined special streams
  152. struct Stream {
  153. static constexpr int COPY = -1, REMOTE_SEND = -2, LOOP_SWAP = -3;
  154. };
  155. CompNode() = default;
  156. /*!
  157. * \brief manually destroy all comp node resources
  158. */
  159. MGE_WIN_DECLSPEC_FUC static void finalize();
  160. /*!
  161. * \brief load a computing node from logical locator ID;
  162. * \see Locator::parse
  163. */
  164. static CompNode load(const std::string& id) { return load(Locator::parse(id)); }
  165. /*!
  166. * \brief create a CompNode object from **logical** locator
  167. */
  168. static CompNode load(const Locator& locator) {
  169. return load(locator.to_physical(), locator);
  170. }
  171. MGE_WIN_DECLSPEC_FUC static CompNode load(
  172. const Locator& locator_physical, const Locator& locator_logical);
  173. /* =================== memory management ======================== */
  174. /*!
  175. * \brief allocate memory on this computing node
  176. *
  177. * Note: allocation of device memory is synchronous with the host,
  178. * meaning that the memory can be used immediately; however deallocation
  179. * is asynchronous to ensure that the memory can be used by
  180. * already-launched kernels on the computing node.
  181. *
  182. * Exception should be raised if allocation fails.
  183. */
  184. MGE_WIN_DECLSPEC_FUC void* alloc_device(size_t size) const;
  185. //! deallocate device buffer; see alloc_device() for more details
  186. MGE_WIN_DECLSPEC_FUC void free_device(void* ptr) const;
  187. /*!
  188. * \brief allocate memory on host that is associated with the device,
  189. * which may accelerate I/O
  190. *
  191. * Both allocation and deallocation on host are synchronous.
  192. */
  193. MGE_WIN_DECLSPEC_FUC void* alloc_host(size_t size) const;
  194. MGE_WIN_DECLSPEC_FUC void free_host(void* ptr) const;
  195. //! copy from underlying device to host
  196. void copy_to_host(void* host_ptr, const void* device_ptr, size_t size) const {
  197. return m_impl->copy_to_host(host_ptr, device_ptr, size);
  198. }
  199. //! copy from host to underlying device
  200. void copy_to_device(void* device_ptr, const void* host_ptr, size_t size) const {
  201. return m_impl->copy_to_device(device_ptr, host_ptr, size);
  202. }
  203. /*!
  204. * \brief copy from this device to another device; would use the
  205. * computing resource on dest_node
  206. * \param src source memory that must be allocated on this device
  207. */
  208. void peer_copy_to(
  209. CompNode dest_node, void* dest, const void* src, size_t size) const {
  210. return m_impl->peer_copy_to(
  211. reinterpret_cast<Impl*>(dest_node.m_impl), dest, src, size);
  212. }
  213. //! get alignment requiement in bytes; guaranteed to be power of 2
  214. size_t get_mem_addr_alignment() const { return m_impl->get_mem_addr_alignment(); }
  215. /*!
  216. * \brief get the size of the paddings which must be reserved at the
  217. * end of memory chunk; guaranteed to be power of 2
  218. */
  219. size_t get_mem_padding() const {
  220. size_t padding = m_impl->get_mem_padding();
  221. mgb_assert(!(padding & (padding - 1)), "mem padding should be power of 2");
  222. return padding;
  223. }
  224. /*!
  225. * \brief release consecutive free chunks on all devices to defragment;
  226. * see DevMemAlloc::try_coalesce_free
  227. */
  228. MGE_WIN_DECLSPEC_FUC static void try_coalesce_all_free_memory();
  229. /*
  230. * \brief specifies how to pre-allocate from raw dev allocator
  231. *
  232. */
  233. MGE_WIN_DECLSPEC_FUC static void set_prealloc_config(
  234. size_t alignment, size_t min_req, size_t max_overhead, double growth_factor,
  235. DeviceType device_type);
  236. /*!
  237. * \brief get compute capability of the specified device
  238. */
  239. MGE_WIN_DECLSPEC_FUC static size_t get_compute_capability(
  240. int dev, DeviceType device_type);
  241. /* =================== synchronization ======================== */
  242. class Event;
  243. class EventPool;
  244. std::unique_ptr<Event> create_event(size_t flags = 0) const {
  245. return m_impl->create_event(flags);
  246. }
  247. //! wait for an event created on another CompNode
  248. inline void device_wait_event(Event& event) const;
  249. /*!
  250. * \brief block host thread to wait for all previous operations on this
  251. * computing node to finish
  252. */
  253. void sync() const { return m_impl->sync(); }
  254. /*!
  255. * \brief synchronize all computing nodes
  256. */
  257. MGE_WIN_DECLSPEC_FUC static void sync_all();
  258. /* =================== misc ======================== */
  259. /*!
  260. * \brief get id of underlying memory node; comp nodes that share the
  261. * same mem node can access memory allocated by each other.
  262. */
  263. MemNode mem_node() const { return m_impl->mem_node(); }
  264. bool operator==(const CompNode& rhs) const { return m_impl == rhs.m_impl; }
  265. bool operator!=(const CompNode& rhs) const { return !this->operator==(rhs); }
  266. bool valid() const { return m_impl; }
  267. //! get total and free memory on the computing device in bytes
  268. std::pair<size_t, size_t> get_mem_status_bytes() const {
  269. return m_impl->get_mem_status_bytes();
  270. }
  271. #if !MGB_BUILD_SLIM_SERVING
  272. std::pair<size_t, size_t> get_free_left_and_right(
  273. size_t begin_ptr, size_t end_ptr) {
  274. return m_impl->get_free_left_and_right(begin_ptr, end_ptr);
  275. }
  276. size_t get_used_memory() const { return m_impl->get_used_memory(); }
  277. size_t get_reserved_memory() const { return m_impl->get_reserved_memory(); }
  278. size_t get_max_reserved_memory() const { return m_impl->get_max_reserved_memory(); }
  279. size_t get_max_used_memory() const { return m_impl->get_max_used_memory(); }
  280. size_t get_max_block_size_available() const {
  281. return m_impl->get_max_block_size_available();
  282. }
  283. size_t get_free_mem() const { return m_impl->get_free_mem(); }
  284. void reset_max_reserved_memory() const {
  285. return m_impl->reset_max_reserved_memory();
  286. }
  287. void reset_max_used_memory() const { return m_impl->reset_max_used_memory(); }
  288. #endif
  289. //! change to another stream on the same memory node
  290. MGE_WIN_DECLSPEC_FUC CompNode change_stream(int dest_stream) const;
  291. //! get string representation
  292. std::string to_string() const {
  293. return m_impl ? mgb::ssprintf(
  294. "CompNode(\"%s\" from \"%s\")",
  295. to_string_physical().c_str(),
  296. to_string_logical().c_str())
  297. : "invalid";
  298. }
  299. //! get string representation of physical device
  300. std::string to_string_physical() const {
  301. return m_impl ? m_impl->locator().to_string() : "invalid";
  302. }
  303. //! get string representation of logical device
  304. std::string to_string_logical() const {
  305. return m_impl ? m_impl->locator_logical().to_string() : "invalid";
  306. }
  307. uint64_t get_uid() { return m_impl->get_uid(); }
  308. //! get the physical locator that created this comp node
  309. Locator locator() const { return m_impl->locator(); }
  310. //! get the logical locator that created this comp node
  311. Locator locator_logical() const { return m_impl->locator_logical(); }
  312. //! see CompNodeEnv::activate
  313. MGE_WIN_DECLSPEC_FUC void activate() const;
  314. //! get device type of this comp node
  315. MGE_WIN_DECLSPEC_FUC DeviceType device_type() const;
  316. /*!
  317. * \brief check for error on the asynchronous computing stream
  318. *
  319. * This is used for devices with limited error handling such as CUDA.
  320. *
  321. * It will return MegBrainError with error messages rather than
  322. * directly throw exception; return nullptr if no error.
  323. */
  324. MGB_WARN_UNUSED_RESULT
  325. MGE_WIN_DECLSPEC_FUC std::unique_ptr<MegBrainError> check_async_error() const;
  326. /*!
  327. * \brief create a CompNodeSeqRecorder associated with this computing
  328. * node
  329. *
  330. * Note: the implementation must be thread safe: simultaneous calls to
  331. * create_seq_recorder() must block until existing CompNodeSeqRecorder
  332. * objects are either destructed or stopped.
  333. *
  334. * \return the recorder object; nullptr is returned if recording is not
  335. * supported
  336. */
  337. std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(cg::ComputingGraph* cg) {
  338. return m_impl->create_seq_recorder(cg);
  339. }
  340. /*!
  341. * insert callback into current compute stream.
  342. * The callack is to be called after all currently enqueued
  343. * iterms in the stream have completed. And the later tasks
  344. * in the stream must wait for the callback to finish.
  345. */
  346. void add_callback(megdnn::thin_function<void()>&& cb) {
  347. return m_impl->add_callback(std::move(cb));
  348. }
  349. enum class Flag : uint32_t {
  350. //! Whether computing recorder is supported on this comp node (i.e.
  351. //! whether non-zero comp_node_seq_record_level is allowed)
  352. SUPPORT_RECORDER = 1 << 0,
  353. //! Whether dynamic memory allocation is supported in seq recorder.
  354. //! If this flag is not setted, ComputingSequence::do_execute()
  355. //! would skip the warm up and allow seq recorder to start
  356. //! immediately
  357. RECORDER_SUPPORT_DYNAMIC_ALLOC = 1 << 1,
  358. //! Whether the capacity of the asynchronous execution queue on this
  359. //! comp node is limited.
  360. //! If this flag is set, tasks on multiple comp nodes would be
  361. //! dispatched from multiple cpu threads.
  362. //! \see ComputingGraph::Options::async_exec_level
  363. QUEUE_LIMITED = 1 << 2,
  364. //! Whether this comp node supports copy stream, so computation and
  365. //! I/O can be parallelized
  366. HAS_COPY_STREAM = 1 << 3,
  367. //! Destructing an event is unsafe if the comp node is not
  368. //! synchronized; setting this flag would cause computing sequence
  369. //! to sync the comp node in its dtor.
  370. EVENT_DTOR_UNSAFE = 1 << 4,
  371. //! CompNode is available even there is no thread support, i.e.
  372. //! MGB_HAVE_THREAD=0. Usually this means that execution on the
  373. //! CompNode is synchronous, i.e. behaves like cpu:default
  374. SUPPORT_NO_THREAD = 1 << 5,
  375. //! Whether this comp node supports unified address. i.e. CPU and
  376. //! CUDA supports unified address.
  377. SUPPORT_UNIFIED_ADDRESS = 1 << 6,
  378. };
  379. bool contain_flag(Flag flag) { return contain_flag(device_type(), flag); }
  380. MGE_WIN_DECLSPEC_FUC static bool contain_flag(DeviceType device_type, Flag flag);
  381. using UnorderedSet = ThinHashSet<CompNode>;
  382. template <typename T>
  383. using UnorderedMap = ThinHashMap<CompNode, T>;
  384. //! apply function to each initialized comp node
  385. MGE_WIN_DECLSPEC_FUC static void foreach (thin_function<void(CompNode)> callback);
  386. //! get total number of specific devices on this system
  387. MGE_WIN_DECLSPEC_FUC static size_t get_device_count(
  388. DeviceType type, bool warn = true);
  389. /* =================== specialized ======================== */
  390. //! get default CPU comp node
  391. // implemented in comp_node/cpu/comp_node.cpp
  392. MGE_WIN_DECLSPEC_FUC static CompNode default_cpu();
  393. /*!
  394. * \brief set whether to enable affinity setting for CPU comp nodes
  395. *
  396. * If enabled, computation on cpux would be bound to the x'th CPU.
  397. *
  398. * This is disabled by default.
  399. *
  400. * (implemented in comp_node/cpu/comp_node.cpp)
  401. *
  402. * \return original setting
  403. */
  404. MGE_WIN_DECLSPEC_FUC static bool enable_affinity_for_cpu(bool flag);
  405. protected:
  406. //! ImplBase with env(); defined in CompNodeEnv
  407. class Impl;
  408. class ImplBase : public NonCopyableObj, public DynTypeObj {
  409. public:
  410. typedef void (*free_func_t)(ImplBase* self, void* ptr);
  411. //! memory free might be called after finalize(); so we should
  412. //! not rely on virtual function for this
  413. const free_func_t free_device;
  414. const free_func_t free_host;
  415. virtual void* alloc_device(size_t size) = 0;
  416. virtual void* alloc_host(size_t size) = 0;
  417. virtual void copy_to_host(
  418. void* host_ptr, const void* device_ptr, size_t size) = 0;
  419. virtual void copy_to_device(
  420. void* device_ptr, const void* host_ptr, size_t size) = 0;
  421. virtual void peer_copy_to(
  422. Impl* dest_impl, void* dest, const void* src, size_t size) = 0;
  423. virtual size_t get_mem_addr_alignment() = 0;
  424. virtual size_t get_mem_padding();
  425. virtual std::unique_ptr<Event> create_event(size_t flags) = 0;
  426. virtual void sync() = 0;
  427. virtual MemNode mem_node() = 0;
  428. virtual std::pair<size_t, size_t> get_mem_status_bytes() = 0;
  429. #if !MGB_BUILD_SLIM_SERVING
  430. virtual std::pair<size_t, size_t> get_free_left_and_right(size_t x, size_t y) {
  431. return {x - x, y - y};
  432. }
  433. virtual size_t get_used_memory() { return 0; }
  434. virtual size_t get_reserved_memory() { return 0; }
  435. virtual size_t get_max_reserved_memory() { return 0; }
  436. virtual size_t get_max_used_memory() { return 0; }
  437. virtual size_t get_max_block_size_available() { return 0; }
  438. virtual size_t get_free_mem() { return 0; }
  439. virtual void reset_max_reserved_memory() {}
  440. virtual void reset_max_used_memory() {}
  441. #endif
  442. virtual Locator locator() = 0;
  443. virtual Locator locator_logical() = 0;
  444. virtual std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
  445. cg::ComputingGraph* cg);
  446. virtual void add_callback(megdnn::thin_function<void()>&&);
  447. virtual uint64_t get_uid() {
  448. mgb_throw(MegBrainError, "get_uid is not impl yet");
  449. };
  450. protected:
  451. ImplBase(free_func_t fd, free_func_t fh) : free_device{fd}, free_host{fh} {}
  452. ~ImplBase() = default;
  453. };
  454. //! implementations are allocated statically, so no memory management
  455. //! is needed
  456. ImplBase* m_impl = nullptr;
  457. friend class CompNodeEnv;
  458. friend struct HashTrait<CompNode>;
  459. friend struct HashTrait<CompNode::Locator>;
  460. friend class CompNodeImplHelper;
  461. public:
  462. CompNode(ImplBase* impl) : m_impl{impl} {}
  463. };
  464. MGB_DEF_ENUM_CLASS_BIT_OPR(CompNode::Flag)
  465. /*!
  466. * \brief record computation operations on a computing node
  467. *
  468. * This is used for fast execution of an identical computation sequence where
  469. * only input/output data differ.
  470. *
  471. * When this object is created from a comp node, recording starts immediately.
  472. * Call stop() when computation finishes, and call replay() when it needs to be
  473. * re-executed.
  474. *
  475. * Implementations should consider thread safe in comp_node, in order to support
  476. * multi threads reording in the same comp_node simultaneously, using thread
  477. * local recorder in comp_node.
  478. *
  479. * Note. When recording is over, the recorder is independent with comp_node, so
  480. * the task dispatched into recorder should not related to the comp_node
  481. * methord, and the thread of recorder replay is the user thread.
  482. */
  483. class CompNodeSeqRecorder {
  484. public:
  485. virtual ~CompNodeSeqRecorder() noexcept = default;
  486. /*!
  487. * \brief Enter fake-exec mode
  488. *
  489. * Memory allocation/free is only allowed in fake-exec mode, and kernels
  490. * should not be actually recorded in this mode.
  491. *
  492. * This should be paired with exit_fake_exec()
  493. */
  494. virtual void enter_fake_exec(const CompNode& comp_node) = 0;
  495. //! Exit fake-exec mode
  496. virtual void exit_fake_exec(const CompNode& comp_node) = 0;
  497. virtual void stop(const CompNode& comp_node) = 0;
  498. virtual void replay() = 0;
  499. };
  500. /*!
  501. * \brief event associated with a CompNode node, used for cross-device
  502. * synchronization
  503. */
  504. class CompNode::Event : public NonCopyableObj {
  505. protected:
  506. static int sm_cpu_sync_level;
  507. //! flags when this event is created
  508. size_t const m_create_flags;
  509. Event(size_t create_flags) : m_create_flags{create_flags} {}
  510. public:
  511. enum Flags { NEED_TIMER = 1 };
  512. virtual ~Event() = default;
  513. /*!
  514. * \brief record this event on the comp node that creates it
  515. *
  516. * Note that if a comp node is recorded multiple times, then subsequent
  517. * calls would overwrite its internal state and other methods that
  518. * examine the status would only examine the completion of the most
  519. * recent call to record().
  520. */
  521. virtual void record() = 0;
  522. //! whether this event has finished; it must has been recorded
  523. virtual bool finished() = 0;
  524. //! block the host thread (caller thread) to wait for this event
  525. virtual void host_wait() = 0;
  526. //! get elapsed time in seconds from this to another event; the events
  527. //! must be finished
  528. virtual double elapsed_time_until(Event& end) = 0;
  529. //! record an action on another comp node so it would wait for this
  530. //! event
  531. virtual void device_wait_by(CompNode cn) = 0;
  532. //! get the comp node to which this event is associated
  533. virtual CompNode comp_node() const = 0;
  534. //! flags when this event is created
  535. size_t create_flags() const { return m_create_flags; }
  536. /*!
  537. * \brief set CPU resource usage level when performing synchronization
  538. * \param level CPU waiting level:
  539. * 0. condition var (the default)
  540. * 1. busy wait with yield
  541. * 2. busy wait
  542. */
  543. static void set_cpu_sync_level(int level) { sm_cpu_sync_level = level; }
  544. };
  545. /*!
  546. * \brief pool of events that can be reused
  547. */
  548. class CompNode::EventPool {
  549. CompNode m_cn;
  550. std::vector<std::unique_ptr<CompNode::Event>> m_allocated;
  551. std::vector<CompNode::Event*> m_free;
  552. Spinlock m_lock;
  553. size_t m_flags;
  554. public:
  555. MGE_WIN_DECLSPEC_FUC explicit EventPool(CompNode cn, size_t flags = 0);
  556. MGE_WIN_DECLSPEC_FUC ~EventPool();
  557. MGE_WIN_DECLSPEC_FUC CompNode::Event* alloc();
  558. MGE_WIN_DECLSPEC_FUC void free(CompNode::Event* ev);
  559. //! assert that all allocated events have been freed
  560. MGE_WIN_DECLSPEC_FUC void assert_all_freed();
  561. };
  562. void CompNode::device_wait_event(Event& event) const {
  563. event.device_wait_by(*this);
  564. }
  565. template <>
  566. struct HashTrait<CompNode> {
  567. static size_t eval(const CompNode& val) {
  568. static_assert(sizeof(size_t) == sizeof(void*), "bad hash type");
  569. return reinterpret_cast<size_t>(static_cast<void*>(val.m_impl));
  570. }
  571. };
  572. template <>
  573. struct HashTrait<CompNode::Locator> {
  574. static size_t eval(const CompNode::Locator& val) {
  575. return static_cast<size_t>(val.device) + (static_cast<size_t>(val.type) << 4) +
  576. (static_cast<size_t>(val.stream) << 8);
  577. }
  578. };
  579. namespace comp_node_detail {
  580. /*!
  581. * \brief an inplace doubly linked list for efficient inserting/deleting
  582. *
  583. * Note: do not use this directly; it is only for CompNodeDepedentObject
  584. */
  585. class DepedentObjList {
  586. class Sentinel;
  587. struct StaticInfo;
  588. static StaticInfo sm_info;
  589. DepedentObjList *m_prev = nullptr, *m_next = nullptr;
  590. static void link(DepedentObjList* a, DepedentObjList* b) {
  591. a->m_next = b;
  592. b->m_prev = a;
  593. }
  594. protected:
  595. MGE_WIN_DECLSPEC_FUC virtual std::shared_ptr<void> callback() = 0;
  596. ~DepedentObjList() = default;
  597. MGE_WIN_DECLSPEC_FUC static void add(DepedentObjList* ptr);
  598. MGE_WIN_DECLSPEC_FUC static void remove(DepedentObjList* ptr);
  599. public:
  600. MGE_WIN_DECLSPEC_FUC static void invoke_callback_and_clean();
  601. };
  602. } // namespace comp_node_detail
  603. /*!
  604. * \brief base class for objects that depend on CompNode
  605. *
  606. * There is a CompNode::finalize() method that destorys all global comp nodes.
  607. * Therefore objects that depend on CompNode should all be marked as invalid at
  608. * that time.
  609. *
  610. * CompNode::finalize() is called in atexit() because some external libraries
  611. * that CompNode depends on seems to be registering exit handlers. It is also
  612. * impractical to require a correct destruction order because, for example, in
  613. * python atexit() handlers are invoked before global python objects get
  614. * reclaimed.
  615. *
  616. * As a result we give up enforcing a correct destruction order, but rather
  617. * require all CompNode-dependent objects to derive from this class so they can
  618. * get notified possibly do most of the cleanup when CompNode is finalized.
  619. */
  620. class CompNodeDepedentObject : private comp_node_detail::DepedentObjList {
  621. //! 1: in on_comp_node_finalize(); 2: after on_comp_node_finalize()
  622. int m_state = 0;
  623. MGE_WIN_DECLSPEC_FUC std::shared_ptr<void> callback() override final;
  624. protected:
  625. CompNodeDepedentObject() { add(this); }
  626. ~CompNodeDepedentObject() { remove(this); }
  627. /*!
  628. * \brief overwritten by subclasses to perform clean up jobs
  629. *
  630. * Note: in case the object has nested objects which hold a reference to the
  631. * object itself, a reference to this object must be kept so it would not be
  632. * released during the call of on_comp_node_finalize().
  633. */
  634. virtual std::shared_ptr<void> on_comp_node_finalize() = 0;
  635. //! exception would thrown if on_comp_node_finalize() has been called (do
  636. //! not raise if invoked from on_comp_node_finalize())
  637. void check_not_finalized() const;
  638. //! whether on_comp_node_finalize() has been called (true when invoked
  639. //! from on_comp_node_finalize())
  640. bool is_finalized() const { return m_state; }
  641. };
  642. } // namespace mgb
  643. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。