#include "megbrain/rdnn/management.h" #include "megbrain/comp_node_env.h" #include "megbrain/tensor.h" #include "megbrain/utils/metahelper.h" #include "megdnn/handle.h" #include "megdnn/oprs.h" /* ================== global functions ================== */ using namespace mgb; using namespace mgb::opr; namespace { template class MegDNNGlobalOprContainer final : public UserDataContainer::UserData { MGB_TYPEINFO_OBJ_DECL; std::shared_ptr m_megdnn_handle; std::unique_ptr m_opr; public: MegDNNGlobalOprContainer(CompNode cn) : m_megdnn_handle{intl::get_megdnn_handle_shared(cn)}, m_opr{m_megdnn_handle->create_operator()} { mgb_assert(m_opr->is_thread_safe()); } Opr* get() const { return m_opr.get(); } }; template MGB_TYPEINFO_OBJ_IMPL(MegDNNGlobalOprContainer); } // anonymous namespace std::shared_ptr intl::get_megdnn_handle_shared(CompNode comp_node) { auto& handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)); return {handle.shared_from_this(), handle.handle()}; } megdnn::Handle* intl::get_megdnn_handle(CompNode comp_node) { return MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle(); } template Opr* intl::get_megdnn_global_opr(CompNode comp_node) { using T = MegDNNGlobalOprContainer; auto maker = [comp_node]() { return std::make_shared(comp_node); }; return CompNodeEnv::from_comp_node(comp_node).get_user_data(maker).get(); } namespace mgb { namespace opr { namespace intl { #define INST(o) template o* get_megdnn_global_opr(CompNode) INST(megdnn::AddUpdate); INST(megdnn::Relayout); INST(megdnn::Checksum); #undef INST } // namespace intl } // namespace opr } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}