You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

management.cpp 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #include "megbrain/rdnn/management.h"
  2. #include "megbrain/comp_node_env.h"
  3. #include "megbrain/tensor.h"
  4. #include "megbrain/utils/metahelper.h"
  5. #include "megdnn/handle.h"
  6. #include "megdnn/oprs.h"
  7. /* ================== global functions ================== */
  8. using namespace mgb;
  9. using namespace mgb::opr;
  10. namespace {
  11. template <class Opr>
  12. class MegDNNGlobalOprContainer final : public UserDataContainer::UserData {
  13. MGB_TYPEINFO_OBJ_DECL;
  14. std::shared_ptr<megdnn::Handle> m_megdnn_handle;
  15. std::unique_ptr<Opr> m_opr;
  16. public:
  17. MegDNNGlobalOprContainer(CompNode cn)
  18. : m_megdnn_handle{intl::get_megdnn_handle_shared(cn)},
  19. m_opr{m_megdnn_handle->create_operator<Opr>()} {
  20. mgb_assert(m_opr->is_thread_safe());
  21. }
  22. Opr* get() const { return m_opr.get(); }
  23. };
  24. template <class Opr>
  25. MGB_TYPEINFO_OBJ_IMPL(MegDNNGlobalOprContainer<Opr>);
  26. } // anonymous namespace
  27. std::shared_ptr<megdnn::Handle> intl::get_megdnn_handle_shared(CompNode comp_node) {
  28. auto& handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node));
  29. return {handle.shared_from_this(), handle.handle()};
  30. }
  31. megdnn::Handle* intl::get_megdnn_handle(CompNode comp_node) {
  32. return MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
  33. }
  34. template <typename Opr>
  35. Opr* intl::get_megdnn_global_opr(CompNode comp_node) {
  36. using T = MegDNNGlobalOprContainer<Opr>;
  37. auto maker = [comp_node]() { return std::make_shared<T>(comp_node); };
  38. return CompNodeEnv::from_comp_node(comp_node).get_user_data<T>(maker).get();
  39. }
  40. namespace mgb {
  41. namespace opr {
  42. namespace intl {
  43. #define INST(o) template o* get_megdnn_global_opr<o>(CompNode)
  44. INST(megdnn::AddUpdate);
  45. INST(megdnn::Relayout);
  46. INST(megdnn::Checksum);
  47. #undef INST
  48. } // namespace intl
  49. } // namespace opr
  50. } // namespace mgb
  51. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}