Browse Source

fix(mgb/opr-mm): fix megray_helper thread safety

GitOrigin-RevId: f7b7c1d97f
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
4e0054f7b2
2 changed files with 28 additions and 6 deletions
  1. +22
    -6
      src/opr-mm/impl/megray_helper.cpp
  2. +6
    -0
      src/opr-mm/include/megbrain/opr/megray_helper.h

+ 22
- 6
src/opr-mm/impl/megray_helper.cpp View File

@@ -14,19 +14,35 @@
using namespace mgb;
using namespace opr;

/*!
 * \brief Look up a cached communicator by hash.
 *
 * The lookup runs with m_mtx held.
 *
 * \param hash key into m_megray_comms
 * \param[out] comm receives the cached communicator on a hit; it is not
 *      modified on a miss
 * \return true if an entry for \p hash exists, false otherwise
 */
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
    std::lock_guard<std::mutex> guard(m_mtx);
    const auto entry = m_megray_comms.find(hash);
    if (entry == m_megray_comms.end()) {
        // miss: leave the caller's pointer as it was
        return false;
    }
    comm = entry->second;
    return true;
}

/*!
 * \brief Store a communicator in the cache under the given hash.
 *
 * The insertion runs with m_mtx held. The map's emplace() does not replace
 * an existing element, so if \p hash is already present the cached entry is
 * kept and \p comm is dropped.
 *
 * \param hash key into m_megray_comms
 * \param comm communicator to cache; taken by value and moved into the map
 */
void MegRayCommunicatorBuilder::emplace(uint64_t hash,
        std::shared_ptr<MegRay::Communicator> comm) {
    std::lock_guard<std::mutex> guard(m_mtx);
    // comm is this function's own copy: move it rather than paying for
    // another atomic reference-count increment/decrement
    m_megray_comms.emplace(hash, std::move(comm));
}

std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm(
uint64_t hash, std::string key, uint32_t size, uint32_t rank,
MegRay::Backend backend,
std::shared_ptr<mgb::opr::GroupClient> group_client) {
auto it = m_megray_comms.find(hash);
if (it == m_megray_comms.end()) {
auto comm = MegRay::get_communicator(size, rank, backend);
std::shared_ptr<MegRay::Communicator> comm;
if (!find(hash, comm)) {
comm = MegRay::get_communicator(size, rank, backend);
auto uid = comm->get_uid();
auto uids = group_client->gather_uid(uid, key, size, rank);
comm->init(uids);
m_megray_comms.emplace(hash, std::move(comm));
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK);
emplace(hash, comm);
}
return m_megray_comms[hash];
return comm;
}

// NOTE(review): presumably the out-of-line counterpart of the
// MGB_TYPEINFO_OBJ_DECL that appears in the class declaration; the macro is
// defined outside this listing -- confirm.
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder);


+ 6
- 0
src/opr-mm/include/megbrain/opr/megray_helper.h View File

@@ -11,6 +11,8 @@

#pragma once

#include <mutex>
#include <utility>

#include "megbrain/utils/metahelper.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"
@@ -25,7 +27,11 @@ class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData
MGB_TYPEINFO_OBJ_DECL;

private:
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);

std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms;
std::mutex m_mtx;

public:
std::shared_ptr<MegRay::Communicator> get_megray_comm(


Loading…
Cancel
Save