You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

comp_node.h 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. /**
  2. * \file src/core/include/megbrain/comp_node.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megbrain/utils/hash.h"
  13. #include "megbrain/utils/metahelper.h"
  14. #include "megbrain/utils/thin/function.h"
  15. #include "megbrain/utils/thin/hash_table.h"
  16. #include "megbrain/utils/thread.h"
  17. #include "megdnn/thin/function.h"
  18. #include <cstddef>
  19. #include <memory>
  20. #include <string>
  21. namespace mgb {
  22. // forward declaration; defined in comp_node_env.h
  23. class CompNodeEnv;
  24. namespace cg {
  25. class ComputingGraph;
  26. }
  27. class CompNodeSeqRecorder;
  28. /*!
  29. * \brief identifier for a memory node
  30. *
  31. * MemNode is comparable. CompNodes with the same MemNode can access memory of
  32. * each other directly
  33. */
  34. class MemNode {
  35. const void* m_id = nullptr;
  36. public:
  37. MemNode() = default;
  38. explicit MemNode(const void* id) : m_id{id} {}
  39. bool operator==(const MemNode& rhs) const { return m_id == rhs.m_id; }
  40. bool operator!=(const MemNode& rhs) const { return m_id != rhs.m_id; }
  41. operator bool() const { return m_id != nullptr; }
  42. };
  43. /*!
  44. * \brief abstraction of a streaming computing resource on localhost (a
  45. * thread on CPU, a cuda stream, etc.)
  46. *
  47. * Note that most of the operations are asynchronous with respect to the caller
  48. * thread
  49. */
  50. class CompNode {
  51. public:
  52. //! computing device type
  53. enum class DeviceType {
  54. //! for "xpu" comp node that would mapped to available cn on
  55. //! current system
  56. UNSPEC = 0,
  57. CUDA = 1,
  58. CPU = 2,
  59. CAMBRICON = 3,
  60. ROCM = 8,
  61. ATLAS = 9,
  62. MULTITHREAD = 11,
  63. MAX_DEVICE_ID,
  64. };
  65. static constexpr size_t NR_DEVICE_TYPE =
  66. static_cast<size_t>(DeviceType::MAX_DEVICE_ID);
  67. struct DeviceProperties {
  68. DeviceProperties() {
  69. name = "unspec";
  70. total_memory = major = minor = 0;
  71. }
  72. std::string name;
  73. size_t total_memory;
  74. //! for cuda
  75. int major;
  76. int minor;
  77. };
  78. /*!
  79. * \brief an identifier to specify a computing node
  80. *
  81. * Note: logical locator is directly parsed from a string identifier
  82. * given by user; it should be translated to physical locator by calling
  83. * to_physical() before actual use.
  84. *
  85. * Unless explicitly specified otherwise, all locators are physical
  86. * locators.
  87. */
  88. struct Locator {
  89. /*!
  90. * \brief special device number for the "cpu default" comp node,
  91. * which dispatches all tasks in the caller thread
  92. */
  93. static constexpr int DEVICE_CPU_DEFAULT = -1024;
  94. /*!
  95. * \brief special device number for the "multithread_default"
  96. * comp node, which dispatches all tasks to thread pool and the
  97. * caller thread is the main thread of thread pool
  98. */
  99. static constexpr int DEVICE_MULTITHREAD_DEFAULT = -1025;
  100. DeviceType type = DeviceType::UNSPEC;
  101. /*!
  102. * corresponding to a physical computing device; memories between
  103. * different devices are not shared.
  104. *
  105. * device == -1 means logical default device (maps to 0 by default,
  106. * and can be changed by set_device_map)
  107. *
  108. */
  109. int device = -1;
  110. //! multiple streams can execute on one computing device and share
  111. //! memory, when compnode type is multithread the field also stand
  112. //! for nr_threads
  113. union {
  114. int stream = 0;
  115. int nr_threads;
  116. };
  117. /*!
  118. * \brief parse a string identifier
  119. *
  120. * currently supported ID format: (gpu|cpu)<n>[:m] where n is the
  121. * device number, possibly with m as the stream id.
  122. */
  123. MGE_WIN_DECLSPEC_FUC static Locator parse(const std::string& id);
  124. /*!
  125. * \brief set mapping between device numbers of a device type
  126. */
  127. MGE_WIN_DECLSPEC_FUC static void set_device_map(
  128. DeviceType type, int from, int to);
  129. /*!
  130. * \brief set the actual device type to be used for
  131. * DeviceType::UNSPEC
  132. */
  133. MGE_WIN_DECLSPEC_FUC static void set_unspec_device_type(DeviceType type);
  134. /*!
  135. * \brief get corresponding physical Locator
  136. *
  137. * DeviceType::UNSPEC would be resolved, and device map would be
  138. * applied on device number
  139. */
  140. MGE_WIN_DECLSPEC_FUC Locator to_physical() const;
  141. /*!
  142. * \brief get string description of this locator that can be parsed
  143. * again
  144. */
  145. MGE_WIN_DECLSPEC_FUC std::string to_string() const;
  146. bool operator==(const Locator& rhs) const {
  147. return type == rhs.type && device == rhs.device && stream == rhs.stream;
  148. }
  149. };
  150. struct LocatorPairHashKey {
  151. Locator locator, locator_logical;
  152. bool operator==(const LocatorPairHashKey& rhs) const {
  153. return locator == rhs.locator && locator_logical == rhs.locator_logical;
  154. }
  155. struct Hash {
  156. size_t operator()(const LocatorPairHashKey& k) const {
  157. return hash_pair_combine(
  158. mgb::hash(k.locator), mgb::hash(k.locator_logical));
  159. }
  160. };
  161. };
  162. //! predefined special streams
  163. struct Stream {
  164. static constexpr int COPY = -1, REMOTE_SEND = -2, LOOP_SWAP = -3;
  165. };
  166. CompNode() = default;
  167. /*!
  168. * \brief manually destroy all comp node resources
  169. */
  170. MGE_WIN_DECLSPEC_FUC static void finalize();
  171. /*!
  172. * \brief load a computing node from logical locator ID;
  173. * \see Locator::parse
  174. */
  175. static CompNode load(const std::string& id) { return load(Locator::parse(id)); }
  176. /*!
  177. * \brief create a CompNode object from **logical** locator
  178. */
  179. static CompNode load(const Locator& locator) {
  180. return load(locator.to_physical(), locator);
  181. }
  182. MGE_WIN_DECLSPEC_FUC static CompNode load(
  183. const Locator& locator_physical, const Locator& locator_logical);
  184. /* =================== memory management ======================== */
  185. /*!
  186. * \brief allocate memory on this computing node
  187. *
  188. * Note: allocation of device memory is synchronous with the host,
  189. * meaning that the memory can be used immediately; however deallocation
  190. * is asynchronous to ensure that the memory can be used by
  191. * already-launched kernels on the computing node.
  192. *
  193. * Exception should be raised if allocation fails.
  194. */
  195. MGE_WIN_DECLSPEC_FUC void* alloc_device(size_t size) const;
  196. //! deallocate device buffer; see alloc_device() for more details
  197. MGE_WIN_DECLSPEC_FUC void free_device(void* ptr) const;
  198. /*!
  199. * \brief allocate memory on host that is associated with the device,
  200. * which may accelerate I/O
  201. *
  202. * Both allocation and deallocation on host are synchronous.
  203. */
  204. MGE_WIN_DECLSPEC_FUC void* alloc_host(size_t size) const;
  205. MGE_WIN_DECLSPEC_FUC void free_host(void* ptr) const;
  206. //! copy from underlying device to host
  207. void copy_to_host(void* host_ptr, const void* device_ptr, size_t size) const {
  208. return m_impl->copy_to_host(host_ptr, device_ptr, size);
  209. }
  210. //! copy from host to underlying device
  211. void copy_to_device(void* device_ptr, const void* host_ptr, size_t size) const {
  212. return m_impl->copy_to_device(device_ptr, host_ptr, size);
  213. }
  214. //! copy from underlying device to host
  215. void copy_to_host_ref(
  216. megdnn::RefPtr& host_ref_ptr, megdnn::RefPtr& device_ref_ptr,
  217. size_t size) const {
  218. return m_impl->copy_to_host_ref(host_ref_ptr, device_ref_ptr, size);
  219. }
  220. //! copy from host to underlying device
  221. void copy_to_device_ref(
  222. megdnn::RefPtr& device_ref_ptr, megdnn::RefPtr& host_ref_ptr,
  223. size_t size) const {
  224. return m_impl->copy_to_device_ref(device_ref_ptr, host_ref_ptr, size);
  225. }
  226. /*!
  227. * \brief copy from this device to another device; would use the
  228. * computing resource on dest_node
  229. * \param src source memory that must be allocated on this device
  230. */
  231. void peer_copy_to(
  232. CompNode dest_node, void* dest, const void* src, size_t size) const {
  233. return m_impl->peer_copy_to(
  234. reinterpret_cast<Impl*>(dest_node.m_impl), dest, src, size);
  235. }
  236. void peer_copy_to_ref(
  237. CompNode dest_node, megdnn::RefPtr& dst_ref_ptr,
  238. megdnn::RefPtr& src_ref_ptr, size_t size) const {
  239. return m_impl->peer_copy_to_ref(
  240. reinterpret_cast<Impl*>(dest_node.m_impl), dst_ref_ptr, src_ref_ptr,
  241. size);
  242. }
  243. //! get alignment requiement in bytes; guaranteed to be power of 2
  244. size_t get_mem_addr_alignment() const { return m_impl->get_mem_addr_alignment(); }
  245. /*!
  246. * \brief get the size of the paddings which must be reserved at the
  247. * end of memory chunk; guaranteed to be power of 2
  248. */
  249. size_t get_mem_padding() const {
  250. size_t padding = m_impl->get_mem_padding();
  251. mgb_assert(!(padding & (padding - 1)), "mem padding should be power of 2");
  252. return padding;
  253. }
  254. /*!
  255. * \brief release consecutive free chunks on all devices to defragment;
  256. * see DevMemAlloc::try_coalesce_free
  257. */
  258. MGE_WIN_DECLSPEC_FUC static void try_coalesce_all_free_memory();
  259. /*
  260. * \brief specifies how to pre-allocate from raw dev allocator
  261. *
  262. */
  263. MGE_WIN_DECLSPEC_FUC static void set_prealloc_config(
  264. size_t alignment, size_t min_req, size_t max_overhead, double growth_factor,
  265. DeviceType device_type);
  266. /*!
  267. * \brief get device property of the specified device
  268. */
  269. MGE_WIN_DECLSPEC_FUC static DeviceProperties get_device_prop(
  270. int dev, DeviceType device_type);
  271. /*!
  272. * \brief get control of host ptr to user
  273. */
  274. MGE_WIN_DECLSPEC_FUC void map_to_cpu(void* ptr, size_t size, bool blocking = false);
  275. /*!
  276. * \brief release control of host ptr to system
  277. */
  278. MGE_WIN_DECLSPEC_FUC void unmap_to_gpu(void* ptr, size_t size);
  279. /*!
  280. * \brief get logical address by host ptr
  281. */
  282. MGE_WIN_DECLSPEC_FUC void* get_logical_addr_by_host_ptr(void* ptr, size_t size);
  283. /* =================== synchronization ======================== */
  284. class Event;
  285. class EventPool;
  286. std::unique_ptr<Event> create_event(size_t flags = 0) const {
  287. return m_impl->create_event(flags);
  288. }
  289. //! wait for an event created on another CompNode
  290. inline void device_wait_event(Event& event) const;
  291. /*!
  292. * \brief block host thread to wait for all previous operations on this
  293. * computing node to finish
  294. */
  295. void sync() const { return m_impl->sync(); }
  296. /*!
  297. * \brief synchronize all computing nodes
  298. */
  299. MGE_WIN_DECLSPEC_FUC static void sync_all();
  300. /* =================== misc ======================== */
  301. /*!
  302. * \brief get id of underlying memory node; comp nodes that share the
  303. * same mem node can access memory allocated by each other.
  304. */
  305. MemNode mem_node() const { return m_impl->mem_node(); }
  306. bool operator==(const CompNode& rhs) const { return m_impl == rhs.m_impl; }
  307. bool operator!=(const CompNode& rhs) const { return !this->operator==(rhs); }
  308. bool valid() const { return m_impl; }
  309. //! get total and free memory on the computing device in bytes
  310. std::pair<size_t, size_t> get_mem_status_bytes() const {
  311. return m_impl->get_mem_status_bytes();
  312. }
  313. #if !MGB_BUILD_SLIM_SERVING
  314. std::pair<size_t, size_t> get_free_left_and_right(
  315. size_t begin_ptr, size_t end_ptr) {
  316. return m_impl->get_free_left_and_right(begin_ptr, end_ptr);
  317. }
  318. size_t get_used_memory() const { return m_impl->get_used_memory(); }
  319. size_t get_reserved_memory() const { return m_impl->get_reserved_memory(); }
  320. size_t get_max_reserved_memory() const { return m_impl->get_max_reserved_memory(); }
  321. size_t get_max_used_memory() const { return m_impl->get_max_used_memory(); }
  322. size_t get_max_block_size_available() const {
  323. return m_impl->get_max_block_size_available();
  324. }
  325. size_t get_free_mem() const { return m_impl->get_free_mem(); }
  326. void reset_max_reserved_memory() const {
  327. return m_impl->reset_max_reserved_memory();
  328. }
  329. void reset_max_used_memory() const { return m_impl->reset_max_used_memory(); }
  330. #endif
  331. //! change to another stream on the same memory node
  332. MGE_WIN_DECLSPEC_FUC CompNode change_stream(int dest_stream) const;
  333. //! get string representation
  334. std::string to_string() const {
  335. return m_impl ? mgb::ssprintf(
  336. "CompNode(\"%s\" from \"%s\")",
  337. to_string_physical().c_str(),
  338. to_string_logical().c_str())
  339. : "invalid";
  340. }
  341. //! get string representation of physical device
  342. std::string to_string_physical() const {
  343. return m_impl ? m_impl->locator().to_string() : "invalid";
  344. }
  345. //! get string representation of logical device
  346. std::string to_string_logical() const {
  347. return m_impl ? m_impl->locator_logical().to_string() : "invalid";
  348. }
  349. uint64_t get_uid() { return m_impl->get_uid(); }
  350. //! get the physical locator that created this comp node
  351. Locator locator() const { return m_impl->locator(); }
  352. //! get the logical locator that created this comp node
  353. Locator locator_logical() const { return m_impl->locator_logical(); }
  354. //! see CompNodeEnv::activate
  355. MGE_WIN_DECLSPEC_FUC void activate() const;
  356. //! get device type of this comp node
  357. MGE_WIN_DECLSPEC_FUC DeviceType device_type() const;
  358. /*!
  359. * \brief check for error on the asynchronous computing stream
  360. *
  361. * This is used for devices with limited error handling such as CUDA.
  362. *
  363. * It will return MegBrainError with error messages rather than
  364. * directly throw exception; return nullptr if no error.
  365. */
  366. MGB_WARN_UNUSED_RESULT
  367. MGE_WIN_DECLSPEC_FUC std::unique_ptr<MegBrainError> check_async_error() const;
  368. /*!
  369. * \brief create a CompNodeSeqRecorder associated with this computing
  370. * node
  371. *
  372. * Note: the implementation must be thread safe: simultaneous calls to
  373. * create_seq_recorder() must block until existing CompNodeSeqRecorder
  374. * objects are either destructed or stopped.
  375. *
  376. * \return the recorder object; nullptr is returned if recording is not
  377. * supported
  378. */
  379. std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(cg::ComputingGraph* cg) {
  380. return m_impl->create_seq_recorder(cg);
  381. }
  382. /*!
  383. * insert callback into current compute stream.
  384. * The callack is to be called after all currently enqueued
  385. * iterms in the stream have completed. And the later tasks
  386. * in the stream must wait for the callback to finish.
  387. */
  388. void add_callback(megdnn::thin_function<void()>&& cb) {
  389. return m_impl->add_callback(std::move(cb));
  390. }
  391. enum class Flag : uint32_t {
  392. //! Whether computing recorder is supported on this comp node (i.e.
  393. //! whether non-zero comp_node_seq_record_level is allowed)
  394. SUPPORT_RECORDER = 1 << 0,
  395. //! Whether dynamic memory allocation is supported in seq recorder.
  396. //! If this flag is not setted, ComputingSequence::do_execute()
  397. //! would skip the warm up and allow seq recorder to start
  398. //! immediately
  399. RECORDER_SUPPORT_DYNAMIC_ALLOC = 1 << 1,
  400. //! Whether the capacity of the asynchronous execution queue on this
  401. //! comp node is limited.
  402. //! If this flag is set, tasks on multiple comp nodes would be
  403. //! dispatched from multiple cpu threads.
  404. //! \see ComputingGraph::Options::async_exec_level
  405. QUEUE_LIMITED = 1 << 2,
  406. //! Whether this comp node supports copy stream, so computation and
  407. //! I/O can be parallelized
  408. HAS_COPY_STREAM = 1 << 3,
  409. //! Destructing an event is unsafe if the comp node is not
  410. //! synchronized; setting this flag would cause computing sequence
  411. //! to sync the comp node in its dtor.
  412. EVENT_DTOR_UNSAFE = 1 << 4,
  413. //! CompNode is available even there is no thread support, i.e.
  414. //! MGB_HAVE_THREAD=0. Usually this means that execution on the
  415. //! CompNode is synchronous, i.e. behaves like cpu:default
  416. SUPPORT_NO_THREAD = 1 << 5,
  417. //! Whether this comp node supports unified address. i.e. CPU and
  418. //! CUDA supports unified address.
  419. SUPPORT_UNIFIED_ADDRESS = 1 << 6,
  420. };
  421. bool contain_flag(Flag flag) { return contain_flag(device_type(), flag); }
  422. MGE_WIN_DECLSPEC_FUC static bool contain_flag(DeviceType device_type, Flag flag);
  423. using UnorderedSet = ThinHashSet<CompNode>;
  424. template <typename T>
  425. using UnorderedMap = ThinHashMap<CompNode, T>;
  426. //! apply function to each initialized comp node
  427. MGE_WIN_DECLSPEC_FUC static void foreach (thin_function<void(CompNode)> callback);
  428. //! get total number of specific devices on this system
  429. MGE_WIN_DECLSPEC_FUC static size_t get_device_count(
  430. DeviceType type, bool warn = true);
  431. /* =================== specialized ======================== */
  432. //! get default CPU comp node
  433. // implemented in comp_node/cpu/comp_node.cpp
  434. MGE_WIN_DECLSPEC_FUC static CompNode default_cpu();
  435. /*!
  436. * \brief set whether to enable affinity setting for CPU comp nodes
  437. *
  438. * If enabled, computation on cpux would be bound to the x'th CPU.
  439. *
  440. * This is disabled by default.
  441. *
  442. * (implemented in comp_node/cpu/comp_node.cpp)
  443. *
  444. * \return original setting
  445. */
  446. MGE_WIN_DECLSPEC_FUC static bool enable_affinity_for_cpu(bool flag);
  447. protected:
  448. //! ImplBase with env(); defined in CompNodeEnv
  449. class Impl;
  450. class ImplBase : public NonCopyableObj, public DynTypeObj {
  451. public:
  452. typedef void (*free_func_t)(ImplBase* self, void* ptr);
  453. //! memory free might be called after finalize(); so we should
  454. //! not rely on virtual function for this
  455. const free_func_t free_device;
  456. const free_func_t free_host;
  457. virtual void* alloc_device(size_t size) = 0;
  458. virtual void* alloc_host(size_t size) = 0;
  459. virtual void copy_to_host(
  460. void* host_ptr, const void* device_ptr, size_t size) = 0;
  461. virtual void copy_to_device(
  462. void* device_ptr, const void* host_ptr, size_t size) = 0;
  463. virtual void copy_to_host_ref(
  464. megdnn::RefPtr& host_ref_ptr, megdnn::RefPtr& device_ref_ptr,
  465. size_t size) {
  466. copy_to_host(host_ref_ptr.get_ptr(), device_ref_ptr.get_ptr(), size);
  467. }
  468. virtual void copy_to_device_ref(
  469. megdnn::RefPtr& device_ref_ptr, megdnn::RefPtr& host_ref_ptr,
  470. size_t size) {
  471. copy_to_device(device_ref_ptr.get_ptr(), host_ref_ptr.get_ptr(), size);
  472. }
  473. virtual void peer_copy_to(
  474. Impl* dest_impl, void* dest, const void* src, size_t size) = 0;
  475. virtual void peer_copy_to_ref(
  476. Impl* dest_impl, megdnn::RefPtr& dest, megdnn::RefPtr& src,
  477. size_t size) {
  478. peer_copy_to(dest_impl, dest.get_ptr(), src.get_ptr(), size);
  479. }
  480. virtual void map_to_cpu(void* ptr, size_t size, bool blocking = false);
  481. virtual void unmap_to_gpu(void* ptr, size_t size);
  482. virtual void* get_logical_addr_by_host_ptr(void* ptr, size_t size);
  483. virtual size_t get_mem_addr_alignment() = 0;
  484. virtual size_t get_mem_padding();
  485. virtual std::unique_ptr<Event> create_event(size_t flags) = 0;
  486. virtual void sync() = 0;
  487. virtual MemNode mem_node() = 0;
  488. virtual std::pair<size_t, size_t> get_mem_status_bytes() = 0;
  489. #if !MGB_BUILD_SLIM_SERVING
  490. virtual std::pair<size_t, size_t> get_free_left_and_right(size_t x, size_t y) {
  491. return {x - x, y - y};
  492. }
  493. virtual size_t get_used_memory() { return 0; }
  494. virtual size_t get_reserved_memory() { return 0; }
  495. virtual size_t get_max_reserved_memory() { return 0; }
  496. virtual size_t get_max_used_memory() { return 0; }
  497. virtual size_t get_max_block_size_available() { return 0; }
  498. virtual size_t get_free_mem() { return get_mem_status_bytes().second; }
  499. virtual void reset_max_reserved_memory() {}
  500. virtual void reset_max_used_memory() {}
  501. #endif
  502. virtual Locator locator() = 0;
  503. virtual Locator locator_logical() = 0;
  504. virtual std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
  505. cg::ComputingGraph* cg);
  506. virtual void add_callback(megdnn::thin_function<void()>&&);
  507. virtual uint64_t get_uid() {
  508. mgb_throw(MegBrainError, "get_uid is not impl yet");
  509. };
  510. protected:
  511. ImplBase(free_func_t fd, free_func_t fh) : free_device{fd}, free_host{fh} {}
  512. ~ImplBase() = default;
  513. };
  514. //! implementations are allocated statically, so no memory management
  515. //! is needed
  516. ImplBase* m_impl = nullptr;
  517. friend class CompNodeEnv;
  518. friend struct HashTrait<CompNode>;
  519. friend struct HashTrait<CompNode::Locator>;
  520. friend class CompNodeImplHelper;
  521. public:
  522. CompNode(ImplBase* impl) : m_impl{impl} {}
  523. };
  524. MGB_DEF_ENUM_CLASS_BIT_OPR(CompNode::Flag)
  525. /*!
  526. * \brief record computation operations on a computing node
  527. *
  528. * This is used for fast execution of an identical computation sequence where
  529. * only input/output data differ.
  530. *
  531. * When this object is created from a comp node, recording starts immediately.
  532. * Call stop() when computation finishes, and call replay() when it needs to be
  533. * re-executed.
  534. *
  535. * Implementations should consider thread safe in comp_node, in order to support
  536. * multi threads reording in the same comp_node simultaneously, using thread
  537. * local recorder in comp_node.
  538. *
  539. * Note. When recording is over, the recorder is independent with comp_node, so
  540. * the task dispatched into recorder should not related to the comp_node
  541. * methord, and the thread of recorder replay is the user thread.
  542. */
  543. class CompNodeSeqRecorder {
  544. public:
  545. virtual ~CompNodeSeqRecorder() noexcept = default;
  546. /*!
  547. * \brief Enter fake-exec mode
  548. *
  549. * Memory allocation/free is only allowed in fake-exec mode, and kernels
  550. * should not be actually recorded in this mode.
  551. *
  552. * This should be paired with exit_fake_exec()
  553. */
  554. virtual void enter_fake_exec(const CompNode& comp_node) = 0;
  555. //! Exit fake-exec mode
  556. virtual void exit_fake_exec(const CompNode& comp_node) = 0;
  557. virtual void stop(const CompNode& comp_node) = 0;
  558. virtual void replay() = 0;
  559. };
  560. /*!
  561. * \brief event associated with a CompNode node, used for cross-device
  562. * synchronization
  563. */
  564. class CompNode::Event : public NonCopyableObj {
  565. protected:
  566. static int sm_cpu_sync_level;
  567. //! flags when this event is created
  568. size_t const m_create_flags;
  569. Event(size_t create_flags) : m_create_flags{create_flags} {}
  570. public:
  571. enum Flags { NEED_TIMER = 1 };
  572. virtual ~Event() = default;
  573. /*!
  574. * \brief record this event on the comp node that creates it
  575. *
  576. * Note that if a comp node is recorded multiple times, then subsequent
  577. * calls would overwrite its internal state and other methods that
  578. * examine the status would only examine the completion of the most
  579. * recent call to record().
  580. */
  581. virtual void record() = 0;
  582. //! whether this event has finished; it must has been recorded
  583. virtual bool finished() = 0;
  584. //! block the host thread (caller thread) to wait for this event
  585. virtual void host_wait() = 0;
  586. //! get elapsed time in seconds from this to another event; the events
  587. //! must be finished
  588. virtual double elapsed_time_until(Event& end) = 0;
  589. //! record an action on another comp node so it would wait for this
  590. //! event
  591. virtual void device_wait_by(CompNode cn) = 0;
  592. //! get the comp node to which this event is associated
  593. virtual CompNode comp_node() const = 0;
  594. //! flags when this event is created
  595. size_t create_flags() const { return m_create_flags; }
  596. /*!
  597. * \brief set CPU resource usage level when performing synchronization
  598. * \param level CPU waiting level:
  599. * 0. condition var (the default)
  600. * 1. busy wait with yield
  601. * 2. busy wait
  602. */
  603. static void set_cpu_sync_level(int level) { sm_cpu_sync_level = level; }
  604. };
  605. /*!
  606. * \brief pool of events that can be reused
  607. */
  608. class CompNode::EventPool {
  609. CompNode m_cn;
  610. std::vector<std::unique_ptr<CompNode::Event>> m_allocated;
  611. std::vector<CompNode::Event*> m_free;
  612. Spinlock m_lock;
  613. size_t m_flags;
  614. public:
  615. MGE_WIN_DECLSPEC_FUC explicit EventPool(CompNode cn, size_t flags = 0);
  616. MGE_WIN_DECLSPEC_FUC ~EventPool();
  617. MGE_WIN_DECLSPEC_FUC CompNode::Event* alloc();
  618. MGE_WIN_DECLSPEC_FUC void free(CompNode::Event* ev);
  619. //! assert that all allocated events have been freed
  620. MGE_WIN_DECLSPEC_FUC void assert_all_freed();
  621. };
  622. void CompNode::device_wait_event(Event& event) const {
  623. event.device_wait_by(*this);
  624. }
  625. template <>
  626. struct HashTrait<CompNode> {
  627. static size_t eval(const CompNode& val) {
  628. static_assert(sizeof(size_t) == sizeof(void*), "bad hash type");
  629. return reinterpret_cast<size_t>(static_cast<void*>(val.m_impl));
  630. }
  631. };
  632. template <>
  633. struct HashTrait<CompNode::Locator> {
  634. static size_t eval(const CompNode::Locator& val) {
  635. return static_cast<size_t>(val.device) + (static_cast<size_t>(val.type) << 4) +
  636. (static_cast<size_t>(val.stream) << 8);
  637. }
  638. };
  639. namespace comp_node_detail {
  640. /*!
  641. * \brief an inplace doubly linked list for efficient inserting/deleting
  642. *
  643. * Note: do not use this directly; it is only for CompNodeDepedentObject
  644. */
  645. class DepedentObjList {
  646. class Sentinel;
  647. struct StaticInfo;
  648. static StaticInfo sm_info;
  649. DepedentObjList *m_prev = nullptr, *m_next = nullptr;
  650. static void link(DepedentObjList* a, DepedentObjList* b) {
  651. a->m_next = b;
  652. b->m_prev = a;
  653. }
  654. protected:
  655. MGE_WIN_DECLSPEC_FUC virtual std::shared_ptr<void> callback() = 0;
  656. ~DepedentObjList() = default;
  657. MGE_WIN_DECLSPEC_FUC static void add(DepedentObjList* ptr);
  658. MGE_WIN_DECLSPEC_FUC static void remove(DepedentObjList* ptr);
  659. public:
  660. MGE_WIN_DECLSPEC_FUC static void invoke_callback_and_clean();
  661. };
  662. } // namespace comp_node_detail
  663. /*!
  664. * \brief base class for objects that depend on CompNode
  665. *
  666. * There is a CompNode::finalize() method that destorys all global comp nodes.
  667. * Therefore objects that depend on CompNode should all be marked as invalid at
  668. * that time.
  669. *
  670. * CompNode::finalize() is called in atexit() because some external libraries
  671. * that CompNode depends on seems to be registering exit handlers. It is also
  672. * impractical to require a correct destruction order because, for example, in
  673. * python atexit() handlers are invoked before global python objects get
  674. * reclaimed.
  675. *
  676. * As a result we give up enforcing a correct destruction order, but rather
  677. * require all CompNode-dependent objects to derive from this class so they can
  678. * get notified possibly do most of the cleanup when CompNode is finalized.
  679. */
  680. class CompNodeDepedentObject : private comp_node_detail::DepedentObjList {
  681. //! 1: in on_comp_node_finalize(); 2: after on_comp_node_finalize()
  682. int m_state = 0;
  683. MGE_WIN_DECLSPEC_FUC std::shared_ptr<void> callback() override final;
  684. protected:
  685. CompNodeDepedentObject() { add(this); }
  686. ~CompNodeDepedentObject() { remove(this); }
  687. /*!
  688. * \brief overwritten by subclasses to perform clean up jobs
  689. *
  690. * Note: in case the object has nested objects which hold a reference to the
  691. * object itself, a reference to this object must be kept so it would not be
  692. * released during the call of on_comp_node_finalize().
  693. */
  694. virtual std::shared_ptr<void> on_comp_node_finalize() = 0;
  695. //! exception would thrown if on_comp_node_finalize() has been called (do
  696. //! not raise if invoked from on_comp_node_finalize())
  697. void check_not_finalized() const;
  698. //! whether on_comp_node_finalize() has been called (true when invoked
  699. //! from on_comp_node_finalize())
  700. bool is_finalized() const { return m_state; }
  701. };
  702. } // namespace mgb
  703. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}