|
|
@@ -458,13 +458,16 @@ void CollectiveComm::opr_register() { |
|
|
|
auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank, |
|
|
|
reinterpret_cast<uintptr_t>(cuda_env.stream)); |
|
|
|
|
|
|
|
auto megray_comm_builder = |
|
|
|
owner_graph() |
|
|
|
->options() |
|
|
|
.user_data |
|
|
|
.get_user_data_or_create<MegRayCommunicatorBuilder>(); |
|
|
|
MegRayCommunicatorBuilder* builder; |
|
|
|
|
|
|
|
m_megray_comm = megray_comm_builder->get_megray_comm( |
|
|
|
{ |
|
|
|
static std::mutex user_data_mtx; |
|
|
|
std::unique_lock<std::mutex> lk(user_data_mtx); |
|
|
|
builder = owner_graph()->options().user_data |
|
|
|
.get_user_data_or_create<MegRayCommunicatorBuilder>(); |
|
|
|
} |
|
|
|
|
|
|
|
m_megray_comm = builder->get_megray_comm( |
|
|
|
hash, m_key, m_nr_devices, m_rank, |
|
|
|
get_megray_backend(m_backend), m_group_client); |
|
|
|
|
|
|
|