GitOrigin-RevId: 55c59879f2
release-0.6
@@ -139,27 +139,29 @@ GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key, | |||
return ret; | |||
} | |||
std::vector<std::string> GroupManager::gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) { | |||
std::unique_lock<std::mutex> lk{m_key2uids_mtx}; | |||
if (m_key2uids_size[key] == 0) | |||
m_key2uids[key].resize(size); | |||
m_key2uids[key][rank] = uid; | |||
m_key2uids_size[key]++; | |||
if (m_key2uids_size[key] == size) { | |||
m_key2uids_flag[key] = true; | |||
m_gather_uid_cv.notify_all(); | |||
void GroupManager::bcast_addr(std::string& master_ip, int& port, | |||
const std::string& key, uint32_t size, uint32_t rank, uint32_t root) { | |||
std::unique_lock<std::mutex> lk{m_key2addr_mtx}; | |||
if (rank == root) { | |||
m_key2master_ip[key] = master_ip; | |||
m_key2port[key] = port; | |||
} | |||
m_key2addr_size[key]++; | |||
if (m_key2addr_size[key] == size) { | |||
m_key2addr_flag[key] = true; | |||
m_bcast_cv.notify_all(); | |||
} else { | |||
m_gather_uid_cv.wait( | |||
lk, [&] { return m_key2uids_flag.count(key) > 0; }); | |||
m_bcast_cv.wait( | |||
lk, [&] { return m_key2addr_flag.count(key) > 0; }); | |||
} | |||
auto uids = m_key2uids[key]; | |||
m_key2uids_size[key]--; | |||
if (m_key2uids_size[key] == 0) { | |||
m_key2uids.erase(key); | |||
m_key2uids_flag.erase(key); | |||
master_ip = m_key2master_ip[key]; | |||
port = m_key2port[key]; | |||
m_key2addr_size[key]--; | |||
if (m_key2addr_size[key] == 0) { | |||
m_key2master_ip.erase(key); | |||
m_key2port.erase(key); | |||
m_key2addr_flag.erase(key); | |||
} | |||
return uids; | |||
} | |||
void GroupManager::set_output_shape(const std::string& key, | |||
@@ -44,10 +44,22 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||
std::shared_ptr<MegRay::Communicator> comm; | |||
if (!sm_instance->find(hash, comm)) { | |||
uint32_t root = 0; | |||
std::string master_ip; | |||
int port = 0; | |||
if (rank == root) { | |||
char* c = MegRay::get_host_ip(); | |||
master_ip = std::string(c); | |||
delete c; | |||
port = MegRay::get_free_port(); | |||
auto ret = MegRay::create_server(size, port); | |||
mgb_assert(ret == MegRay::Status::MEGRAY_OK); | |||
} | |||
group_client->bcast_addr(master_ip, port, key, size, rank, root); | |||
comm = MegRay::get_communicator(size, rank, backend); | |||
auto uid = comm->get_uid(); | |||
auto uids = group_client->gather_uid(uid, key, size, rank); | |||
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | |||
auto ret = comm->init(master_ip.c_str(), port); | |||
mgb_assert(ret == MegRay::Status::MEGRAY_OK); | |||
sm_instance->emplace(hash, comm); | |||
} | |||
return comm; | |||
@@ -41,7 +41,7 @@ public: | |||
RUNSERVER(opr_register); | |||
RUNSERVER(set_output_shape); | |||
RUNSERVER(get_output_shape); | |||
RUNSERVER(gather_uid); | |||
RUNSERVER(bcast_addr); | |||
RUNSERVER(group_barrier); | |||
mgb_assert(false, "invalid rpc request"); | |||
} | |||
@@ -49,7 +49,7 @@ private: | |||
void opr_register(void* input_ptr, size_t input_len, std::string *output); | |||
void set_output_shape(void* input_ptr, size_t input_len, std::string *output); | |||
void get_output_shape(void* input_ptr, size_t input_len, std::string *output); | |||
void gather_uid(void* input_ptr, size_t input_len, std::string *output); | |||
void bcast_addr(void* input_ptr, size_t input_len, std::string *output); | |||
void group_barrier(void* input_ptr, size_t input_len, std::string *output); | |||
private: | |||
@@ -101,15 +101,14 @@ void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len, | |||
rsp.SerializeToString(output); | |||
} | |||
void GroupServerProxy::gather_uid(void* input_ptr, size_t input_len, | |||
void GroupServerProxy::bcast_addr(void* input_ptr, size_t input_len, | |||
std::string *output) { | |||
INFO_INIT(mm_handler, GatherUid); | |||
auto uid = req.uid(); | |||
auto uids = m_mgr.gather_uid(uid, req.key(), req.size(), req.rank()); | |||
for (size_t i = 0;i < uids.size();i++) { | |||
rsp.add_uids(); | |||
rsp.set_uids(i, uids[i].data(), uids[i].size()); | |||
} | |||
INFO_INIT(mm_handler, BcastAddr); | |||
std::string master_ip = req.master_ip(); | |||
int port = req.port(); | |||
m_mgr.bcast_addr(master_ip, port, req.key(), req.size(), req.rank(), req.root()); | |||
rsp.set_master_ip(master_ip); | |||
rsp.set_port(port); | |||
rsp.SerializeToString(output); | |||
} | |||
@@ -184,19 +183,20 @@ TensorShape GroupClientProxy::get_output_shape(const std::string& key) { | |||
} | |||
return shape; | |||
} | |||
std::vector<std::string> GroupClientProxy::gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) { | |||
INFO_INIT(mm_handler, gather_uid, GatherUid); | |||
req.set_uid(uid.data(), uid.size()); | |||
void GroupClientProxy::bcast_addr(std::string& master_ip, | |||
int& port, const std::string& key, uint32_t size, | |||
uint32_t rank, uint32_t root) { | |||
INFO_INIT(mm_handler, bcast_addr, BcastAddr); | |||
req.set_master_ip(master_ip.data(), master_ip.size()); | |||
req.set_port(port); | |||
req.set_key(key.data(), key.size()); | |||
req.set_size(size); | |||
req.set_rank(rank); | |||
req.set_root(root); | |||
SOLVE_REQUEST(func_name, req, rsp); | |||
std::vector<std::string> rst; | |||
for (size_t i = 0;i < size;i++) { | |||
rst.push_back(rsp.uids(i)); | |||
} | |||
return rst; | |||
master_ip = rsp.master_ip(); | |||
port = rsp.port(); | |||
} | |||
uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | |||
@@ -82,9 +82,9 @@ class GroupManager { | |||
RegisterInfo opr_register(const std::string& key, size_t nr_devices, | |||
bool is_root, int rank, uint64_t comp_node_hash); | |||
//! gather uids from all ranks | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank); | |||
//! broadcast master_ip and port | |||
void bcast_addr(std::string& master_ip, int& port, | |||
const std::string& key, uint32_t size, uint32_t rank, uint32_t root); | |||
//! Set output shape of this key | |||
void set_output_shape(const std::string& key, const TensorShape& shape); | |||
@@ -102,12 +102,13 @@ class GroupManager { | |||
std::unordered_map<std::string, GroupInfo> m_key2group_info; | |||
std::mutex m_key2group_info_mtx; | |||
//! key -> uid | |||
std::unordered_map<std::string, std::vector<std::string>> m_key2uids; | |||
std::unordered_map<std::string, uint32_t> m_key2uids_size; | |||
std::unordered_map<std::string, bool> m_key2uids_flag; | |||
std::mutex m_key2uids_mtx; | |||
std::condition_variable m_gather_uid_cv; | |||
//! key -> addr | |||
std::unordered_map<std::string, std::string> m_key2master_ip; | |||
std::unordered_map<std::string, int> m_key2port; | |||
std::unordered_map<std::string, uint32_t> m_key2addr_size; | |||
std::unordered_map<std::string, bool> m_key2addr_flag; | |||
std::mutex m_key2addr_mtx; | |||
std::condition_variable m_bcast_cv; | |||
//! barrier | |||
uint32_t m_barrier_size; | |||
@@ -133,8 +134,8 @@ class GroupClient { | |||
bool is_root, int rank, | |||
uint64_t comp_node_hash) = 0; | |||
virtual std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) = 0; | |||
virtual void bcast_addr(std::string& master_ip, int& port, | |||
const std::string& key, uint32_t size, uint32_t rank, uint32_t root) = 0; | |||
virtual void set_output_shape(const std::string& key, | |||
const TensorShape& shape) = 0; | |||
@@ -37,8 +37,8 @@ public: | |||
int rank, | |||
uint64_t comp_node_hash) override; | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) override; | |||
void bcast_addr(std::string& master_ip, int& port, const std::string& key, | |||
uint32_t size, uint32_t rank, uint32_t root) override; | |||
void set_output_shape(const std::string& key, | |||
const TensorShape& shape) override; | |||
@@ -16,15 +16,18 @@ message OprRegisterResponse { | |||
int32 root_rank = 3; | |||
} | |||
message GatherUidRequest { | |||
bytes uid = 1; | |||
string key = 2; | |||
uint32 size = 3; | |||
uint32 rank = 4; | |||
} | |||
message GatherUidResponse { | |||
repeated bytes uids = 1; | |||
message BcastAddrRequest { | |||
string master_ip = 1; | |||
int32 port = 2; | |||
string key = 3; | |||
uint32 size = 4; | |||
uint32 rank = 5; | |||
uint32 root = 6; | |||
} | |||
message BcastAddrResponse { | |||
string master_ip = 1; | |||
int32 port = 2; | |||
} | |||
message SetOutputShapeRequest { | |||
@@ -29,13 +29,14 @@ class MockGroupClient final : public opr::GroupClient { | |||
} | |||
RegisterInfo opr_register(const std::string& key, size_t nr_devices, | |||
bool is_root, int rank, uint64_t comp_node_hash) { | |||
bool is_root, int rank, uint64_t comp_node_hash) override { | |||
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | |||
} | |||
std::vector<std::string> gather_uid(const std::string& uid, | |||
const std::string& key, uint32_t size, uint32_t rank) { | |||
return m_mgr.gather_uid(uid, key, size, rank); | |||
void bcast_addr(std::string& master_ip, int& port, | |||
const std::string& key, uint32_t size, | |||
uint32_t rank, uint32_t root) override { | |||
return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | |||
} | |||
void set_output_shape(const std::string& key, | |||
@@ -1 +1 @@ | |||
Subproject commit d06c215dc1425fa932e20ecfaab7b07c0343a5bc | |||
Subproject commit e14e4f84c1349598ba17c49923168db47a4e9642 |