You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

async_releaser.h 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. /**
  2. * \file imperative/src/impl/async_releaser.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include <chrono>
#include <thread>

#include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/system.h"

#include "./event_pool.h"
  16. namespace mgb {
  17. namespace imperative {
  18. class AsyncReleaser : public CompNodeDepedentObject {
  19. struct WaiterParam {
  20. CompNode cn;
  21. CompNode::Event* event;
  22. BlobPtr blob;
  23. HostTensorStorage::RawStorage storage;
  24. };
  25. class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
  26. AsyncReleaser* m_par_releaser;
  27. public:
  28. // disable busy wait by set max_spin=0 to save CPU cycle
  29. Waiter(AsyncReleaser* releaser)
  30. : AsyncQueueSC<WaiterParam, Waiter>(0),
  31. m_par_releaser(releaser) {}
  32. void process_one_task(WaiterParam& param) {
  33. if (param.event->finished()) {
  34. param.blob.reset();
  35. param.storage.reset();
  36. EventPool::without_timer().free(param.event);
  37. return;
  38. }
  39. using namespace std::literals;
  40. std::this_thread::sleep_for(1us);
  41. add_task(std::move(param));
  42. }
  43. void on_async_queue_worker_thread_start() override {
  44. sys::set_thread_name("releaser");
  45. }
  46. };
  47. Waiter m_waiter{this};
  48. protected:
  49. std::shared_ptr<void> on_comp_node_finalize() override {
  50. m_waiter.wait_task_queue_empty();
  51. return {};
  52. }
  53. public:
  54. static AsyncReleaser* inst() {
  55. static AsyncReleaser releaser;
  56. return &releaser;
  57. }
  58. ~AsyncReleaser() {
  59. m_waiter.wait_task_queue_empty();
  60. }
  61. void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }
  62. void add(const HostTensorND& hv) {
  63. add(hv.comp_node(), {}, hv.storage().raw_storage());
  64. }
  65. void add(CompNode cn, BlobPtr blob,
  66. HostTensorStorage::RawStorage storage = {}) {
  67. auto event = EventPool::without_timer().alloc(cn);
  68. event->record();
  69. m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
  70. }
  71. };
  72. }
  73. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。