You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

manager.cpp 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. /**
  2. * \file src/custom/impl/manager.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/common.h"
  12. #if MGB_CUSTOM_OP
  13. #include "megbrain/custom/manager.h"
  14. #include <unordered_set>
  15. #ifndef _WIN32
  16. #include <dlfcn.h>
  17. #endif
  18. using namespace mgb;
  19. namespace custom {
  20. CustomOpManager *CustomOpManager::inst(void) {
  21. static CustomOpManager op_manager;
  22. return &op_manager;
  23. }
  24. CustomOpManager::~CustomOpManager() {
  25. mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
  26. LibManager::inst()->m_custom_libs.clear();
  27. }
  28. std::shared_ptr<CustomOp> CustomOpManager::insert(const std::string &name, uint32_t version) {
  29. MGB_LOCK_GUARD(m_mtx);
  30. auto iter = m_name2op.find(name);
  31. if (iter != m_name2op.end()) {
  32. mgb_log_warn("Register Custom Op Failed! Op %s has been registered", name.c_str());
  33. return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
  34. }
  35. std::shared_ptr<const CustomOp> op = std::make_shared<const CustomOp>(name, version);
  36. m_name2op[op->op_type()] = op;
  37. m_id2op[op->runtime_id()] = op;
  38. return std::const_pointer_cast<CustomOp, const CustomOp>(op);
  39. }
  40. bool CustomOpManager::erase(const std::string &name) {
  41. MGB_LOCK_GUARD(m_mtx);
  42. auto iter = m_name2op.find(name);
  43. if (iter == m_name2op.end()) {
  44. mgb_log_warn("Erase Custom Op Failed! %s has not been registered", name.c_str());
  45. return false;
  46. }
  47. std::shared_ptr<const CustomOp> op = iter->second;
  48. m_id2op.erase(op->runtime_id());
  49. m_name2op.erase(op->op_type());
  50. return true;
  51. }
  52. bool CustomOpManager::erase(const RunTimeId &id) {
  53. MGB_LOCK_GUARD(m_mtx);
  54. auto iter = m_id2op.find(id);
  55. if (iter == m_id2op.end()) {
  56. mgb_log_warn("Erase Custom Op Failed! The Op has not been registered");
  57. return false;
  58. }
  59. std::shared_ptr<const CustomOp> op = iter->second;
  60. m_id2op.erase(op->runtime_id());
  61. m_name2op.erase(op->op_type());
  62. return true;
  63. }
  64. std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(const std::string &name, uint32_t version) {
  65. auto iter = m_name2op.find(name);
  66. if (iter == m_name2op.end()) {
  67. return insert(name, version);
  68. }
  69. return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
  70. }
  71. RunTimeId CustomOpManager::to_id(const std::string &name) const {
  72. std::shared_ptr<const CustomOp> op = find(name);
  73. return op->runtime_id();
  74. }
  75. std::string CustomOpManager::to_name(const RunTimeId &id) const {
  76. std::shared_ptr<const CustomOp> op = find(id);
  77. return op->op_type();
  78. }
  79. std::shared_ptr<const CustomOp> CustomOpManager::find(const std::string &name) const {
  80. auto ret = m_name2op.find(name);
  81. mgb_assert(ret != m_name2op.end(),
  82. "Find Custom Op Failed! Op %s has not been registered", name.c_str()
  83. );
  84. return ret->second;
  85. }
  86. std::shared_ptr<const CustomOp> CustomOpManager::find(const RunTimeId &id) const {
  87. auto ret = m_id2op.find(id);
  88. mgb_assert(ret != m_id2op.end(), "Find Custom Op Failed! Op has not been registered");
  89. return ret->second;
  90. }
  91. std::vector<std::string> CustomOpManager::op_name_list(void) {
  92. std::vector<std::string> ret;
  93. for (auto kv: m_name2op) {
  94. ret.emplace_back(kv.first);
  95. }
  96. return ret;
  97. }
  98. std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
  99. std::vector<RunTimeId> ret;
  100. for (auto kv: m_id2op) {
  101. ret.emplace_back(kv.first);
  102. }
  103. return ret;
  104. }
  105. #ifndef _WIN32
  106. CustomLib::CustomLib(const std::string &path, int mode = RTLD_LAZY)
  107. : m_handle(nullptr, [](void* handle) {dlclose(handle);}) {
  108. auto op_list_before_load = CustomOpManager::inst()->op_name_list();
  109. std::unordered_set<std::string> op_set_before_load(
  110. op_list_before_load.begin(), op_list_before_load.end());
  111. m_handle.reset(dlopen(path.c_str(), mode));
  112. mgb_assert(m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror());
  113. auto op_list_after_load = CustomOpManager::inst()->op_name_list();
  114. for (auto &op: op_list_after_load) {
  115. if (op_set_before_load.find(op) == op_set_before_load.end()) {
  116. m_ops.emplace_back(op);
  117. }
  118. }
  119. }
  120. #else
  121. CustomLib::CustomLib(const std::string &path, int mode = 0)
  122. : m_handle(nullptr, [](void* handle) {}) {
  123. mgb_assert(false, "custom op is only supported on Linux now");
  124. }
  125. #endif
  126. const std::vector<std::string> &CustomLib::ops_in_lib(void) const {
  127. return m_ops;
  128. }
  129. CustomLib::~CustomLib() {
  130. for (auto &op: m_ops) {
  131. CustomOpManager::inst()->erase(op);
  132. }
  133. }
  134. bool CustomLib::valid() const {
  135. return m_handle != nullptr;
  136. }
  137. LibManager *LibManager::inst(void) {
  138. static LibManager custom_libs;
  139. return &custom_libs;
  140. }
  141. const std::vector<std::string> &LibManager::install(const std::string &name, const std::string &path) {
  142. MGB_LOCK_GUARD(m_mtx);;
  143. LibHandle handle = std::make_shared<CustomLib>(path);
  144. m_custom_libs.insert({name, handle});
  145. return m_custom_libs[name]->ops_in_lib();
  146. }
  147. bool LibManager::uninstall(const std::string &name) {
  148. MGB_LOCK_GUARD(m_mtx);;
  149. mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
  150. return true;
  151. }
  152. std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) {
  153. return CustomOpManager::inst()->insert(opname, version);
  154. }
  155. }
  156. #endif

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。