/**
 * \file src/custom/impl/manager.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

// NOTE(review): this file was recovered from a copy in which every `<...>`
// sequence had been stripped (include targets and template argument lists).
// The includes below are derived from the names this file uses; the template
// arguments (in particular the op type `CustomOp` and its const-ness) are
// reconstructed and must be confirmed against megbrain/custom/manager.h.

#include "megbrain/common.h"

#if MGB_CUSTOM_OP

#include "megbrain/custom/manager.h"

#include <unordered_set>

#ifndef _WIN32
#include <dlfcn.h>
#endif

using namespace mgb;

namespace custom {

/**
 * \brief access the process-wide op manager
 * \return pointer to a function-local static instance
 */
CustomOpManager *CustomOpManager::inst(void) {
    static CustomOpManager op_manager;
    return &op_manager;
}

/**
 * \brief check that both lookup tables have the same size, then drop every
 *      library held by LibManager
 */
CustomOpManager::~CustomOpManager() {
    mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
    LibManager::inst()->m_custom_libs.clear();
}

/**
 * \brief register a new op under \p name
 * \param name op type name used as the lookup key
 * \param version version forwarded to the op constructor
 * \return the newly created op; if \p name is already registered a warning
 *      is logged and the existing op is returned instead
 */
std::shared_ptr<CustomOp> CustomOpManager::insert(const std::string &name, uint32_t version) {
    MGB_LOCK_GUARD(m_mtx);
    auto iter = m_name2op.find(name);
    if (iter != m_name2op.end()) {
        mgb_log_warn("Register Custom Op Failed! Op %s has been registered", name.c_str());
        return std::const_pointer_cast<CustomOp>(iter->second);
    }
    std::shared_ptr<const CustomOp> op = std::make_shared<const CustomOp>(name, version);
    // both tables point at the same op object
    m_name2op[op->op_type()] = op;
    m_id2op[op->runtime_id()] = op;
    return std::const_pointer_cast<CustomOp>(op);
}

/**
 * \brief remove the op registered under \p name from both tables
 * \return false (after logging a warning) if \p name is unknown, else true
 */
bool CustomOpManager::erase(const std::string &name) {
    MGB_LOCK_GUARD(m_mtx);
    auto iter = m_name2op.find(name);
    if (iter == m_name2op.end()) {
        mgb_log_warn("Erase Custom Op Failed! %s has not been registered", name.c_str());
        return false;
    }
    // hold our own reference so the op is still usable for the second erase
    std::shared_ptr<const CustomOp> op = iter->second;
    m_id2op.erase(op->runtime_id());
    m_name2op.erase(op->op_type());
    return true;
}

/**
 * \brief remove the op registered under runtime id \p id from both tables
 * \return false (after logging a warning) if \p id is unknown, else true
 */
bool CustomOpManager::erase(const RunTimeId &id) {
    MGB_LOCK_GUARD(m_mtx);
    auto iter = m_id2op.find(id);
    if (iter == m_id2op.end()) {
        mgb_log_warn("Erase Custom Op Failed! The Op has not been registered");
        return false;
    }
    // hold our own reference so the op is still usable for the second erase
    std::shared_ptr<const CustomOp> op = iter->second;
    m_id2op.erase(op->runtime_id());
    m_name2op.erase(op->op_type());
    return true;
}

/**
 * \brief return the op registered under \p name, registering it first when
 *      it does not exist yet
 *
 * NOTE(review): the initial lookup is done without taking m_mtx, unlike
 * insert()/erase(); confirm whether callers may race with registration.
 */
std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(const std::string &name, uint32_t version) {
    auto iter = m_name2op.find(name);
    if (iter == m_name2op.end()) {
        return insert(name, version);
    }
    return std::const_pointer_cast<CustomOp>(iter->second);
}

/**
 * \brief translate an op name into its runtime id
 * \note asserts (via find) when \p name is not registered
 */
RunTimeId CustomOpManager::to_id(const std::string &name) const {
    std::shared_ptr<const CustomOp> op = find(name);
    return op->runtime_id();
}

/**
 * \brief translate a runtime id into the op type name
 * \note asserts (via find) when \p id is not registered
 */
std::string CustomOpManager::to_name(const RunTimeId &id) const {
    std::shared_ptr<const CustomOp> op = find(id);
    return op->op_type();
}

/**
 * \brief look up an op by name
 * \note asserts when \p name is not registered; does not take m_mtx
 */
std::shared_ptr<const CustomOp> CustomOpManager::find(const std::string &name) const {
    auto ret = m_name2op.find(name);
    mgb_assert(ret != m_name2op.end(),
        "Find Custom Op Failed! Op %s has not been registered", name.c_str()
    );
    return ret->second;
}

/**
 * \brief look up an op by runtime id
 * \note asserts when \p id is not registered; does not take m_mtx
 */
std::shared_ptr<const CustomOp> CustomOpManager::find(const RunTimeId &id) const {
    auto ret = m_id2op.find(id);
    mgb_assert(ret != m_id2op.end(), "Find Custom Op Failed! Op has not been registered");
    return ret->second;
}

/**
 * \brief collect the names of all registered ops
 * \return a copy of every key of the name table, in the table's iteration order
 */
std::vector<std::string> CustomOpManager::op_name_list(void) {
    std::vector<std::string> ret;
    ret.reserve(m_name2op.size());
    // iterate by reference: copying each entry would also copy its shared_ptr
    for (const auto &kv: m_name2op) {
        ret.emplace_back(kv.first);
    }
    return ret;
}

/**
 * \brief collect the runtime ids of all registered ops
 * \return a copy of every key of the id table, in the table's iteration order
 */
std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
    std::vector<RunTimeId> ret;
    ret.reserve(m_id2op.size());
    // iterate by reference: copying each entry would also copy its shared_ptr
    for (const auto &kv: m_id2op) {
        ret.emplace_back(kv.first);
    }
    return ret;
}

#ifndef _WIN32
/**
 * \brief dlopen the library at \p path and record which ops appeared in the
 *      manager as a result
 * \param path path handed to dlopen
 * \param mode dlopen flags
 *
 * The op name list is snapshotted before and after dlopen; names present
 * only in the second snapshot are stored as the ops of this library.
 * Asserts when dlopen fails.
 */
CustomLib::CustomLib(const std::string &path, int mode = RTLD_LAZY)
         : m_handle(nullptr, [](void* handle) {dlclose(handle);}) {
    auto op_list_before_load = CustomOpManager::inst()->op_name_list();
    std::unordered_set<std::string> op_set_before_load(
        op_list_before_load.begin(), op_list_before_load.end());

    m_handle.reset(dlopen(path.c_str(), mode));
    mgb_assert(m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror());

    auto op_list_after_load = CustomOpManager::inst()->op_name_list();
    for (const auto &op: op_list_after_load) {
        if (op_set_before_load.find(op) == op_set_before_load.end()) {
            m_ops.emplace_back(op);
        }
    }
}
#else
/**
 * \brief Windows build: loading is not implemented, construction always asserts
 */
CustomLib::CustomLib(const std::string &path, int mode = 0)
         : m_handle(nullptr, [](void* handle) {}) {
    mgb_assert(false, "custom op is only supported on Linux now");
}
#endif

/**
 * \brief names of the ops recorded for this library at construction time
 */
const std::vector<std::string> &CustomLib::ops_in_lib(void) const {
    return m_ops;
}

/**
 * \brief unregister every op recorded for this library
 */
CustomLib::~CustomLib() {
    for (const auto &op: m_ops) {
        CustomOpManager::inst()->erase(op);
    }
}

/**
 * \brief whether a library handle is currently held
 */
bool CustomLib::valid() const {
    return m_handle != nullptr;
}

/**
 * \brief access the process-wide library manager
 * \return pointer to a function-local static instance
 */
LibManager *LibManager::inst(void) {
    static LibManager custom_libs;
    return &custom_libs;
}

/**
 * \brief load the library at \p path and store it under \p name
 * \return the op names of the library stored under \p name
 *
 * If \p name is already in use the stored entry is kept, a warning is
 * logged, and the ops of the stored library are returned; the freshly
 * loaded handle is released when this function returns.
 */
const std::vector<std::string> &LibManager::install(const std::string &name, const std::string &path) {
    MGB_LOCK_GUARD(m_mtx);
    LibHandle handle = std::make_shared<CustomLib>(path);
    // reuse the iterator from insert instead of a second operator[] lookup
    const auto inserted = m_custom_libs.insert({name, handle});
    if (!inserted.second) {
        mgb_log_warn("Install Custom Lib Failed! Lib %s has been installed", name.c_str());
    }
    return inserted.first->second->ops_in_lib();
}

/**
 * \brief drop the library stored under \p name
 * \return true; asserts when \p name was not installed
 */
bool LibManager::uninstall(const std::string &name) {
    MGB_LOCK_GUARD(m_mtx);
    // erase outside the assertion so the mutation never depends on how the
    // assertion macro treats its condition
    const auto erased = m_custom_libs.erase(name);
    mgb_assert(erased == 1, "uninstall error");
    return true;
}

/**
 * \brief free-function shortcut for CustomOpManager::inst()->insert()
 */
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) {
    return CustomOpManager::inst()->insert(opname, version);
}

}

#endif