You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

megray_helper.cpp 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /**
  2. * \file src/opr-mm/impl/megray_helper.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/megray_helper.h"
  12. #include "megbrain/comp_node_env.h"
  13. using namespace mgb;
  14. using namespace opr;
  15. MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
  16. switch(dtype.enumv()) {
  17. case DTypeEnum::Int8:
  18. return MegRay::DType::MEGRAY_INT8;
  19. case DTypeEnum::Int32:
  20. return MegRay::DType::MEGRAY_INT32;
  21. case DTypeEnum::Float32:
  22. return MegRay::DType::MEGRAY_FLOAT32;
  23. #ifndef MEGDNN_DISABLE_FLOAT16
  24. case DTypeEnum::Float16:
  25. return MegRay::DType::MEGRAY_FLOAT16;
  26. #endif
  27. default:
  28. mgb_throw(MegBrainError, "bad CollectiveComm dtype");
  29. }
  30. }
  31. MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
  32. if (backend == "nccl") {
  33. return MegRay::MEGRAY_NCCL;
  34. } else if (backend == "rccl") {
  35. return MegRay::MEGRAY_RCCL;
  36. } else if (backend == "ucx") {
  37. return MegRay::MEGRAY_UCX;
  38. } else {
  39. mgb_throw(MegBrainError, "back CollectiveComm backend");
  40. }
  41. }
  42. std::shared_ptr<MegRay::Context> mgb::opr::get_megray_context(CompNode comp_node){
  43. #if MGB_CUDA
  44. return MegRay::CudaContext::make(CompNodeEnv::from_comp_node(comp_node).cuda_env().stream);
  45. #elif MGB_ROCM
  46. return MegRay::HipContext::make(CompNodeEnv::from_comp_node(comp_node).rocm_env().stream);
  47. #else
  48. #error "neither CUDA nor ROCm is enabled"
  49. #endif
  50. }
  51. bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
  52. std::unique_lock<std::mutex> lk(m_map_mtx);
  53. auto it = m_megray_comms.find(hash);
  54. if (it != m_megray_comms.end()) {
  55. comm = it->second;
  56. return true;
  57. }
  58. return false;
  59. }
  60. void MegRayCommBuilder::emplace(uint64_t hash,
  61. std::shared_ptr<MegRay::Communicator> comm) {
  62. std::unique_lock<std::mutex> lk(m_map_mtx);
  63. m_megray_comms.emplace(hash, comm);
  64. }
  65. std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
  66. uint64_t hash, std::string key, uint32_t size, uint32_t rank,
  67. MegRay::Backend backend,
  68. std::shared_ptr<mgb::opr::GroupClient> group_client) {
  69. {
  70. // singleton pattern
  71. std::unique_lock<std::mutex> lk(sm_instance_mtx);
  72. if (sm_instance == nullptr) {
  73. sm_instance = new MegRayCommBuilder();
  74. }
  75. }
  76. std::shared_ptr<MegRay::Communicator> comm;
  77. if (!sm_instance->find(hash, comm)) {
  78. uint32_t root = 0;
  79. std::string master_ip;
  80. int port = 0;
  81. if (rank == root) {
  82. char* c = MegRay::get_host_ip();
  83. master_ip = std::string(c);
  84. delete c;
  85. port = MegRay::get_free_port();
  86. auto ret = MegRay::create_server(size, port);
  87. mgb_assert(ret == MegRay::Status::MEGRAY_OK);
  88. }
  89. group_client->bcast_addr(master_ip, port, key, size, rank, root);
  90. comm = MegRay::get_communicator(size, rank, backend);
  91. auto ret = comm->init(master_ip.c_str(), port);
  92. mgb_assert(ret == MegRay::Status::MEGRAY_OK);
  93. sm_instance->emplace(hash, comm);
  94. }
  95. return comm;
  96. }
  97. MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr;
  98. std::mutex MegRayCommBuilder::sm_instance_mtx;
  99. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。