You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

comp_node.h 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793
  1. /**
  2. * \file src/core/include/megbrain/comp_node.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megbrain/utils/hash.h"
  13. #include "megbrain/utils/metahelper.h"
  14. #include "megbrain/utils/thin/hash_table.h"
  15. #include "megbrain/utils/thread.h"
  16. #include "megbrain/utils/thin/function.h"
  17. #include "megdnn/thin/function.h"
  18. #include <cstddef>
  19. #include <string>
  20. #include <memory>
  21. namespace mgb {
  // Forward declarations: these types are only named in declarations in this
  // header (friend declaration, pointer parameter, unique_ptr return type).

  // forward declaration; defined in comp_node_env.h
  class CompNodeEnv;
  namespace cg {
  // used below only as the pointer parameter of create_seq_recorder()
  class ComputingGraph;
  }
  // full definition appears later in this header
  class CompNodeSeqRecorder;
  /*!
   * \brief identifier for a memory node
   *
   * A MemNode wraps an opaque id pointer and is compared by that pointer.
   * CompNodes with the same MemNode can access memory of each other directly.
   * A default-constructed MemNode holds a null id and converts to false.
   */
  class MemNode {
      //! opaque identity; nullptr means "no memory node"
      const void* m_id = nullptr;

  public:
      MemNode() = default;

      //! wrap an existing id pointer; the pointer is only compared, never
      //! dereferenced by this class
      explicit MemNode(const void* id) : m_id(id) {}

      //! two MemNodes are equal iff they carry the same id pointer
      bool operator==(const MemNode& rhs) const { return rhs.m_id == m_id; }

      bool operator!=(const MemNode& rhs) const { return !(*this == rhs); }

      //! true iff the id is non-null
      // NOTE(review): this conversion is implicit; making it explicit would be
      // safer but could break existing callers, so it is left as is.
      operator bool() const { return static_cast<bool>(m_id); }
  };
  51. /*!
  52. * \brief abstraction of a streaming computing resource on localhost (a
  53. * thread on CPU, a cuda stream, etc.)
  54. *
  55. * Note that most of the operations are asynchronous with respect to the caller
  56. * thread
  57. */
  58. class CompNode {
  59. public:
  60. //! computing device type
  61. enum class DeviceType {
  62. //! for "xpu" comp node that would mapped to available cn on
  63. //! current system
  64. UNSPEC = 0,
  65. CUDA = 1,
  66. CPU = 2,
  67. CAMBRICON = 3,
  68. ROCM = 8,
  69. ATLAS = 9,
  70. MULTITHREAD,
  71. MAX_DEVICE_ID,
  72. };
  73. static constexpr size_t NR_DEVICE_TYPE =
  74. static_cast<size_t>(DeviceType::MAX_DEVICE_ID);
  75. /*!
  76. * \brief an identifier to specify a computing node
  77. *
  78. * Note: logical locator is directly parsed from a string identifier
  79. * given by user; it should be translated to physical locator by calling
  80. * to_physical() before actual use.
  81. *
  82. * Unless explicitly specified otherwise, all locators are physical
  83. * locators.
  84. */
  85. struct Locator {
  86. /*!
  87. * \brief special device number for the "cpu default" comp node,
  88. * which dispatches all tasks in the caller thread
  89. */
  90. static constexpr int DEVICE_CPU_DEFAULT = -1024;
  91. /*!
  92. * \brief special device number for the "multithread_default"
  93. * comp node, which dispatches all tasks to thread pool and the
  94. * caller thread is the main thread of thread pool
  95. */
  96. static constexpr int DEVICE_MULTITHREAD_DEFAULT = -1025;
  97. DeviceType type = DeviceType::UNSPEC;
  98. /*!
  99. * corresponding to a physical computing device; memories between
  100. * different devices are not shared.
  101. *
  102. * device == -1 means logical default device (maps to 0 by default,
  103. * and can be changed by set_device_map)
  104. */
  105. int device = -1;
  106. //! multiple streams can execute on one computing device and share
  107. //! memory, when compnode type is multithread the field also stand
  108. //! for nr_threads
  109. union {
  110. int stream = 0;
  111. int nr_threads;
  112. };
  113. /*!
  114. * \brief parse a string identifier
  115. *
  116. * currently supported ID format: (gpu|cpu)<n>[:m] where n is the
  117. * device number, possibly with m as the stream id.
  118. */
  119. static Locator parse(const std::string& id);
  120. /*!
  121. * \brief set mapping between device numbers of a device type
  122. */
  123. static void set_device_map(DeviceType type, int from, int to);
  124. /*!
  125. * \brief set the actual device type to be used for
  126. * DeviceType::UNSPEC
  127. */
  128. static void set_unspec_device_type(DeviceType type);
  129. /*!
  130. * \brief get corresponding physical Locator
  131. *
  132. * DeviceType::UNSPEC would be resolved, and device map would be
  133. * applied on device number
  134. */
  135. Locator to_physical() const;
  136. /*!
  137. * \brief get string description of this locator that can be parsed
  138. * again
  139. */
  140. std::string to_string() const;
  141. bool operator == (const Locator &rhs) const {
  142. return type == rhs.type && device == rhs.device &&
  143. stream == rhs.stream;
  144. }
  145. };
  146. struct LocatorPairHashKey {
  147. Locator locator, locator_logical;
  148. bool operator==(const LocatorPairHashKey& rhs) const {
  149. return locator == rhs.locator && locator_logical == rhs.locator_logical;
  150. }
  151. struct Hash {
  152. size_t operator()(const LocatorPairHashKey& k) const {
  153. return hash_pair_combine(mgb::hash(k.locator),
  154. mgb::hash(k.locator_logical));
  155. }
  156. };
  157. };
  158. //! predefined special streams
  159. struct Stream {
  160. static constexpr int
  161. COPY = -1,
  162. REMOTE_SEND = -2,
  163. LOOP_SWAP = -3;
  164. };
  165. CompNode() = default;
  166. /*!
  167. * \brief manually destroy all comp node resources
  168. */
  169. static void finalize();
  170. /*!
  171. * \brief load a computing node from logical locator ID;
  172. * \see Locator::parse
  173. */
  174. static CompNode load(const std::string& id) {
  175. return load(Locator::parse(id));
  176. }
  177. /*!
  178. * \brief create a CompNode object from **logical** locator
  179. */
  180. static CompNode load(const Locator& locator) {
  181. return load(locator.to_physical(), locator);
  182. }
  183. static CompNode load(const Locator& locator_physical,
  184. const Locator& locator_logical);
  185. /* =================== memory management ======================== */
  186. /*!
  187. * \brief allocate memory on this computing node
  188. *
  189. * Note: allocation of device memory is synchronous with the host,
  190. * meaning that the memory can be used immediately; however deallocation
  191. * is asynchronous to ensure that the memory can be used by
  192. * already-launched kernels on the computing node.
  193. *
  194. * Exception should be raised if allocation fails.
  195. */
  196. void *alloc_device(size_t size) const;
  197. //! deallocate device buffer; see alloc_device() for more details
  198. void free_device(void *ptr) const;
  199. /*!
  200. * \brief allocate memory on host that is associated with the device,
  201. * which may accelerate I/O
  202. *
  203. * Both allocation and deallocation on host are synchronous.
  204. */
  205. void *alloc_host(size_t size) const;
  206. void free_host(void *ptr) const;
  207. //! copy from underlying device to host
  208. void copy_to_host(
  209. void *host_ptr, const void *device_ptr, size_t size) const {
  210. return m_impl->copy_to_host(host_ptr, device_ptr, size);
  211. }
  212. //! copy from host to underlying device
  213. void copy_to_device(
  214. void *device_ptr, const void *host_ptr, size_t size) const {
  215. return m_impl->copy_to_device(device_ptr, host_ptr, size);
  216. }
  217. /*!
  218. * \brief copy from this device to another device; would use the
  219. * computing resource on dest_node
  220. * \param src source memory that must be allocated on this device
  221. */
  222. void peer_copy_to(CompNode dest_node, void *dest,
  223. const void *src, size_t size) const {
  224. return m_impl->peer_copy_to(
  225. reinterpret_cast<Impl*>(dest_node.m_impl), dest, src, size);
  226. }
  227. //! get alignment requiement in bytes; guaranteed to be power of 2
  228. size_t get_mem_addr_alignment() const {
  229. return m_impl->get_mem_addr_alignment();
  230. }
  231. /*!
  232. * \brief get the size of the paddings which must be reserved at the
  233. * end of memory chunk; guaranteed to be power of 2
  234. */
  235. size_t get_mem_padding() const {
  236. size_t padding = m_impl->get_mem_padding();
  237. mgb_assert(!(padding & (padding - 1)),
  238. "mem padding should be power of 2");
  239. return padding;
  240. }
  241. /*!
  242. * \brief release consecutive free chunks on all devices to defragment;
  243. * see DevMemAlloc::try_coalesce_free
  244. */
  245. static void try_coalesce_all_free_memory();
  246. /*
  247. * \brief specifies how to pre-allocate from raw dev allocator
  248. *
  249. */
  250. static void set_prealloc_config(size_t alignment, size_t min_req,
  251. size_t max_overhead, double growth_factor,
  252. DeviceType device_type);
  253. /* =================== synchronization ======================== */
  254. class Event;
  255. class EventPool;
  256. std::unique_ptr<Event> create_event(size_t flags = 0) const {
  257. return m_impl->create_event(flags);
  258. }
  259. //! wait for an event created on another CompNode
  260. inline void device_wait_event(Event &event) const;
  261. /*!
  262. * \brief block host thread to wait for all previous operations on this
  263. * computing node to finish
  264. */
  265. void sync() const {
  266. return m_impl->sync();
  267. }
  268. /*!
  269. * \brief synchronize all computing nodes
  270. */
  271. static void sync_all();
  272. /* =================== misc ======================== */
  273. /*!
  274. * \brief get id of underlying memory node; comp nodes that share the
  275. * same mem node can access memory allocated by each other.
  276. */
  277. MemNode mem_node() const {
  278. return m_impl->mem_node();
  279. }
  280. bool operator == (const CompNode &rhs) const {
  281. return m_impl == rhs.m_impl;
  282. }
  283. bool operator != (const CompNode &rhs) const {
  284. return !this->operator==(rhs);
  285. }
  286. bool valid() const {
  287. return m_impl;
  288. }
  289. //! get total and free memory on the computing device in bytes
  290. std::pair<size_t, size_t> get_mem_status_bytes() const {
  291. return m_impl->get_mem_status_bytes();
  292. }
  293. //! change to another stream on the same memory node
  294. CompNode change_stream(int dest_stream) const;
  295. //! get string representation of physical device
  296. std::string to_string() const {
  297. return m_impl ? m_impl->locator().to_string() : "invalid";
  298. }
  299. //! get string representation of logical device
  300. std::string to_string_logical() const {
  301. return m_impl ? m_impl->locator_logical().to_string() : "invalid";
  302. }
  303. uint64_t get_uid() {
  304. return m_impl->get_uid();
  305. }
  306. //! get the physical locator that created this comp node
  307. Locator locator() const {
  308. return m_impl->locator();
  309. }
  310. //! get the logical locator that created this comp node
  311. Locator locator_logical() const {
  312. return m_impl->locator_logical();
  313. }
  314. //! see CompNodeEnv::activate
  315. void activate() const;
  316. //! get device type of this comp node
  317. DeviceType device_type() const;
  318. /*!
  319. * \brief check for error on the asynchronous computing stream
  320. *
  321. * This is used for devices with limited error handling such as CUDA.
  322. *
  323. * It will return MegBrainError with error messages rather than
  324. * directly throw exception; return nullptr if no error.
  325. */
  326. MGB_WARN_UNUSED_RESULT
  327. std::unique_ptr<MegBrainError> check_async_error() const;
  328. /*!
  329. * \brief create a CompNodeSeqRecorder associated with this computing
  330. * node
  331. *
  332. * Note: the implementation must be thread safe: simultaneous calls to
  333. * create_seq_recorder() must block until existing CompNodeSeqRecorder
  334. * objects are either destructed or stopped.
  335. *
  336. * \return the recorder object; nullptr is returned if recording is not
  337. * supported
  338. */
  339. std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
  340. cg::ComputingGraph* cg) {
  341. return m_impl->create_seq_recorder(cg);
  342. }
  343. /*!
  344. * insert callback into current compute stream.
  345. * The callack is to be called after all currently enqueued
  346. * iterms in the stream have completed. And the later tasks
  347. * in the stream must wait for the callback to finish.
  348. */
  349. void add_callback(megdnn::thin_function<void()>&& cb) {
  350. return m_impl->add_callback(std::move(cb));
  351. }
  352. enum class Flag : uint32_t {
  353. //! Whether computing recorder is supported on this comp node (i.e.
  354. //! whether non-zero comp_node_seq_record_level is allowed)
  355. SUPPORT_RECORDER = 1 << 0,
  356. //! Whether dynamic memory allocation is supported in seq recorder.
  357. //! If this flag is not setted, ComputingSequence::do_execute()
  358. //! would skip the warm up and allow seq recorder to start
  359. //! immediately
  360. RECORDER_SUPPORT_DYNAMIC_ALLOC = 1 << 1,
  361. //! Whether the capacity of the asynchronous execution queue on this
  362. //! comp node is limited.
  363. //! If this flag is set, tasks on multiple comp nodes would be
  364. //! dispatched from multiple cpu threads.
  365. //! \see ComputingGraph::Options::async_exec_level
  366. QUEUE_LIMITED = 1 << 2,
  367. //! Whether this comp node supports copy stream, so computation and
  368. //! I/O can be parallelized
  369. HAS_COPY_STREAM = 1 << 3,
  370. //! Destructing an event is unsafe if the comp node is not
  371. //! synchronized; setting this flag would cause computing sequence
  372. //! to sync the comp node in its dtor.
  373. EVENT_DTOR_UNSAFE = 1 << 4,
  374. //! CompNode is available even there is no thread support, i.e.
  375. //! MGB_HAVE_THREAD=0. Usually this means that execution on the
  376. //! CompNode is synchronous, i.e. behaves like cpu:default
  377. SUPPORT_NO_THREAD = 1 << 5,
  378. //! Whether this comp node supports unified address. i.e. CPU and
  379. //! CUDA supports unified address.
  380. SUPPORT_UNIFIED_ADDRESS = 1 << 6,
  381. };
  382. bool contain_flag(Flag flag) {
  383. return contain_flag(device_type(), flag);
  384. }
  385. static bool contain_flag(DeviceType device_type, Flag flag);
  386. using UnorderedSet = ThinHashSet<CompNode>;
  387. template<typename T>
  388. using UnorderedMap = ThinHashMap<CompNode, T>;
  389. //! apply function to each initialized comp node
  390. static void foreach(thin_function<void(CompNode)> callback);
  391. //! get total number of specific devices on this system
  392. static size_t get_device_count(DeviceType type, bool warn=true);
  393. /* =================== specialized ======================== */
  394. //! get default CPU comp node
  395. // implemented in comp_node/cpu/comp_node.cpp
  396. static CompNode default_cpu();
  397. /*!
  398. * \brief set whether to enable affinity setting for CPU comp nodes
  399. *
  400. * If enabled, computation on cpux would be bound to the x'th CPU.
  401. *
  402. * This is disabled by default.
  403. *
  404. * (implemented in comp_node/cpu/comp_node.cpp)
  405. *
  406. * \return original setting
  407. */
  408. static bool enable_affinity_for_cpu(bool flag);
  409. protected:
  410. //! ImplBase with env(); defined in CompNodeEnv
  411. class Impl;
  412. class ImplBase: public NonCopyableObj, public DynTypeObj {
  413. public:
  414. typedef void (*free_func_t)(ImplBase* self, void* ptr);
  415. //! memory free might be called after finalize(); so we should
  416. //! not rely on virtual function for this
  417. const free_func_t free_device;
  418. const free_func_t free_host;
  419. virtual void* alloc_device(size_t size) = 0;
  420. virtual void *alloc_host(size_t size) = 0;
  421. virtual void copy_to_host(void *host_ptr,
  422. const void *device_ptr, size_t size) = 0;
  423. virtual void copy_to_device(void *device_ptr,
  424. const void *host_ptr, size_t size) = 0;
  425. virtual void peer_copy_to(
  426. Impl *dest_impl, void *dest,
  427. const void *src, size_t size) = 0;
  428. virtual size_t get_mem_addr_alignment() = 0;
  429. virtual size_t get_mem_padding();
  430. virtual std::unique_ptr<Event> create_event(size_t flags) = 0;
  431. virtual void sync() = 0;
  432. virtual MemNode mem_node() = 0;
  433. virtual std::pair<size_t, size_t> get_mem_status_bytes() = 0;
  434. virtual Locator locator() = 0;
  435. virtual Locator locator_logical() = 0;
  436. virtual std::unique_ptr<CompNodeSeqRecorder>
  437. create_seq_recorder(cg::ComputingGraph* cg);
  438. virtual void add_callback(megdnn::thin_function<void()>&&);
  439. virtual uint64_t get_uid() {
  440. mgb_throw(MegBrainError, "get_uid is not impl yet");
  441. };
  442. protected:
  443. ImplBase(free_func_t fd, free_func_t fh)
  444. : free_device{fd}, free_host{fh} {}
  445. ~ImplBase() = default;
  446. };
  447. //! implementations are allocated statically, so no memory management
  448. //! is needed
  449. ImplBase *m_impl = nullptr;
  450. friend class CompNodeEnv;
  451. friend struct HashTrait<CompNode>;
  452. friend struct HashTrait<CompNode::Locator>;
  453. friend class CompNodeImplHelper;
  454. public:
  455. CompNode(ImplBase* impl) : m_impl{impl} {}
  456. };
  457. MGB_DEF_ENUM_CLASS_BIT_OPR(CompNode::Flag)
  /*!
   * \brief record computation operations on a computing node
   *
   * This is used for fast execution of an identical computation sequence where
   * only input/output data differ.
   *
   * When this object is created from a comp node, recording starts immediately.
   * Call stop() when computation finishes, and call replay() when it needs to be
   * re-executed.
   *
   * Implementations should consider thread safety in comp_node, in order to
   * support multiple threads recording in the same comp_node simultaneously,
   * using a thread local recorder in comp_node.
   *
   * Note: when recording is over, the recorder is independent of comp_node, so
   * the tasks dispatched into the recorder should not be related to comp_node
   * methods, and the thread of recorder replay is the user thread.
   */
  class CompNodeSeqRecorder {
  public:
      virtual ~CompNodeSeqRecorder() noexcept = default;

      /*!
       * \brief Enter fake-exec mode
       *
       * Memory allocation/free is only allowed in fake-exec mode, and kernels
       * should not be actually recorded in this mode.
       *
       * This should be paired with exit_fake_exec()
       */
      virtual void enter_fake_exec(const CompNode& comp_node) = 0;

      //! Exit fake-exec mode
      virtual void exit_fake_exec(const CompNode& comp_node) = 0;

      //! stop recording; to be called when computation finishes
      virtual void stop(const CompNode& comp_node) = 0;

      //! re-execute the recorded sequence
      virtual void replay() = 0;
  };
  493. /*!
  494. * \brief event associated with a CompNode node, used for cross-device
  495. * synchronization
  496. */
  497. class CompNode::Event: public NonCopyableObj {
  498. protected:
  499. static int sm_cpu_sync_level;
  500. //! flags when this event is created
  501. size_t const m_create_flags;
  502. Event(size_t create_flags):
  503. m_create_flags{create_flags}
  504. {
  505. }
  506. public:
  507. enum Flags {
  508. NEED_TIMER = 1
  509. };
  510. virtual ~Event() = default;
  511. /*!
  512. * \brief record this event on the comp node that creates it
  513. *
  514. * Note that if a comp node is recorded multiple times, then subsequent
  515. * calls would overwrite its internal state and other methods that
  516. * examine the status would only examine the completion of the most
  517. * recent call to record().
  518. */
  519. virtual void record() = 0;
  520. //! whether this event has finished; it must has been recorded
  521. virtual bool finished() = 0;
  522. //! block the host thread (caller thread) to wait for this event
  523. virtual void host_wait() = 0;
  524. //! get elapsed time in seconds from this to another event; the events
  525. //! must be finished
  526. virtual double elapsed_time_until(Event &end) = 0;
  527. //! record an action on another comp node so it would wait for this
  528. //! event
  529. virtual void device_wait_by(CompNode cn) = 0;
  530. //! get the comp node to which this event is associated
  531. virtual CompNode comp_node() const = 0;
  532. //! flags when this event is created
  533. size_t create_flags() const {
  534. return m_create_flags;
  535. }
  536. /*!
  537. * \brief set CPU resource usage level when performing synchronization
  538. * \param level CPU waiting level:
  539. * 0. condition var (the default)
  540. * 1. busy wait with yield
  541. * 2. busy wait
  542. */
  543. static void set_cpu_sync_level(int level) {
  544. sm_cpu_sync_level = level;
  545. }
  546. };
  547. /*!
  548. * \brief pool of events that can be reused
  549. */
  550. class CompNode::EventPool {
  551. CompNode m_cn;
  552. std::vector<std::unique_ptr<CompNode::Event>> m_allocated;
  553. std::vector<CompNode::Event*> m_free;
  554. Spinlock m_lock;
  555. size_t m_flags;
  556. public:
  557. explicit EventPool(CompNode cn, size_t flags = 0);
  558. ~EventPool();
  559. CompNode::Event* alloc();
  560. void free(CompNode::Event *ev);
  561. //! assert that all allocated events have been freed
  562. void assert_all_freed();
  563. };
  564. void CompNode::device_wait_event(Event &event) const {
  565. event.device_wait_by(*this);
  566. }
  567. template<>
  568. struct HashTrait<CompNode> {
  569. static size_t eval(const CompNode &val) {
  570. static_assert(sizeof(size_t) == sizeof(void*), "bad hash type");
  571. return reinterpret_cast<size_t>(static_cast<void*>(val.m_impl));
  572. }
  573. };
  574. template<>
  575. struct HashTrait<CompNode::Locator> {
  576. static size_t eval(const CompNode::Locator &val) {
  577. return static_cast<size_t>(val.device)
  578. + (static_cast<size_t>(val.type) << 4)
  579. + (static_cast<size_t>(val.stream) << 8);
  580. }
  581. };
  namespace comp_node_detail {
  /*!
   * \brief an in-place doubly linked list for efficient inserting/deleting
   *
   * Each object carries its own prev/next pointers, so being a list node needs
   * no separate allocation.
   *
   * Note: do not use this directly; it is only for CompNodeDepedentObject
   */
  class DepedentObjList {
      class Sentinel;
      struct StaticInfo;
      static StaticInfo sm_info;

      DepedentObjList* m_prev = nullptr;
      DepedentObjList* m_next = nullptr;

      //! make \p b the successor of \p a
      static void link(DepedentObjList* a, DepedentObjList* b) {
          b->m_prev = a;
          a->m_next = b;
      }

  protected:
      //! hook implemented by the derived class; see
      //! invoke_callback_and_clean()
      virtual std::shared_ptr<void> callback() = 0;

      //! protected and non-virtual: objects are not meant to be destroyed
      //! through a DepedentObjList pointer
      ~DepedentObjList() = default;

      static void add(DepedentObjList* ptr);
      static void remove(DepedentObjList* ptr);

  public:
      //! defined elsewhere; presumably runs callback() on the registered
      //! objects and empties the list -- TODO confirm
      static void invoke_callback_and_clean();
  };
  }  // namespace comp_node_detail
  606. /*!
  607. * \brief base class for objects that depend on CompNode
  608. *
  609. * There is a CompNode::finalize() method that destorys all global comp nodes.
  610. * Therefore objects that depend on CompNode should all be marked as invalid at
  611. * that time.
  612. *
  613. * CompNode::finalize() is called in atexit() because some external libraries
  614. * that CompNode depends on seems to be registering exit handlers. It is also
  615. * impractical to require a correct destruction order because, for example, in
  616. * python atexit() handlers are invoked before global python objects get
  617. * reclaimed.
  618. *
  619. * As a result we give up enforcing a correct destruction order, but rather
  620. * require all CompNode-dependent objects to derive from this class so they can
  621. * get notified possibly do most of the cleanup when CompNode is finalized.
  622. */
  623. class CompNodeDepedentObject : private comp_node_detail::DepedentObjList {
  624. //! 1: in on_comp_node_finalize(); 2: after on_comp_node_finalize()
  625. int m_state = 0;
  626. std::shared_ptr<void> callback() override final;
  627. protected:
  628. CompNodeDepedentObject() { add(this); }
  629. ~CompNodeDepedentObject() { remove(this); }
  630. /*!
  631. * \brief overwritten by subclasses to perform clean up jobs
  632. *
  633. * Note: in case the object has nested objects which hold a reference to the
  634. * object itself, a reference to this object must be kept so it would not be
  635. * released during the call of on_comp_node_finalize().
  636. */
  637. virtual std::shared_ptr<void> on_comp_node_finalize() = 0;
  638. //! exception would thrown if on_comp_node_finalize() has been called (do
  639. //! not raise if invoked from on_comp_node_finalize())
  640. void check_not_finalized() const;
  641. //! whether on_comp_node_finalize() has been called (true when invoked
  642. //! from on_comp_node_finalize())
  643. bool is_finalized() const { return m_state; }
  644. };
  645. } // namespace mgb
  646. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台