GitOrigin-RevId: 55c59879f2
release-0.6
@@ -139,27 +139,29 @@ GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key, | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<std::string> GroupManager::gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) { | |||||
std::unique_lock<std::mutex> lk{m_key2uids_mtx}; | |||||
if (m_key2uids_size[key] == 0) | |||||
m_key2uids[key].resize(size); | |||||
m_key2uids[key][rank] = uid; | |||||
m_key2uids_size[key]++; | |||||
if (m_key2uids_size[key] == size) { | |||||
m_key2uids_flag[key] = true; | |||||
m_gather_uid_cv.notify_all(); | |||||
void GroupManager::bcast_addr(std::string& master_ip, int& port, | |||||
const std::string& key, uint32_t size, uint32_t rank, uint32_t root) { | |||||
std::unique_lock<std::mutex> lk{m_key2addr_mtx}; | |||||
if (rank == root) { | |||||
m_key2master_ip[key] = master_ip; | |||||
m_key2port[key] = port; | |||||
} | |||||
m_key2addr_size[key]++; | |||||
if (m_key2addr_size[key] == size) { | |||||
m_key2addr_flag[key] = true; | |||||
m_bcast_cv.notify_all(); | |||||
} else { | } else { | ||||
m_gather_uid_cv.wait( | |||||
lk, [&] { return m_key2uids_flag.count(key) > 0; }); | |||||
m_bcast_cv.wait( | |||||
lk, [&] { return m_key2addr_flag.count(key) > 0; }); | |||||
} | } | ||||
auto uids = m_key2uids[key]; | |||||
m_key2uids_size[key]--; | |||||
if (m_key2uids_size[key] == 0) { | |||||
m_key2uids.erase(key); | |||||
m_key2uids_flag.erase(key); | |||||
master_ip = m_key2master_ip[key]; | |||||
port = m_key2port[key]; | |||||
m_key2addr_size[key]--; | |||||
if (m_key2addr_size[key] == 0) { | |||||
m_key2master_ip.erase(key); | |||||
m_key2port.erase(key); | |||||
m_key2addr_flag.erase(key); | |||||
} | } | ||||
return uids; | |||||
} | } | ||||
void GroupManager::set_output_shape(const std::string& key, | void GroupManager::set_output_shape(const std::string& key, | ||||
@@ -44,10 +44,22 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( | |||||
std::shared_ptr<MegRay::Communicator> comm; | std::shared_ptr<MegRay::Communicator> comm; | ||||
if (!sm_instance->find(hash, comm)) { | if (!sm_instance->find(hash, comm)) { | ||||
uint32_t root = 0; | |||||
std::string master_ip; | |||||
int port = 0; | |||||
if (rank == root) { | |||||
char* c = MegRay::get_host_ip(); | |||||
master_ip = std::string(c); | |||||
delete c; | |||||
port = MegRay::get_free_port(); | |||||
auto ret = MegRay::create_server(size, port); | |||||
mgb_assert(ret == MegRay::Status::MEGRAY_OK); | |||||
} | |||||
group_client->bcast_addr(master_ip, port, key, size, rank, root); | |||||
comm = MegRay::get_communicator(size, rank, backend); | comm = MegRay::get_communicator(size, rank, backend); | ||||
auto uid = comm->get_uid(); | |||||
auto uids = group_client->gather_uid(uid, key, size, rank); | |||||
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); | |||||
auto ret = comm->init(master_ip.c_str(), port); | |||||
mgb_assert(ret == MegRay::Status::MEGRAY_OK); | |||||
sm_instance->emplace(hash, comm); | sm_instance->emplace(hash, comm); | ||||
} | } | ||||
return comm; | return comm; | ||||
@@ -41,7 +41,7 @@ public: | |||||
RUNSERVER(opr_register); | RUNSERVER(opr_register); | ||||
RUNSERVER(set_output_shape); | RUNSERVER(set_output_shape); | ||||
RUNSERVER(get_output_shape); | RUNSERVER(get_output_shape); | ||||
RUNSERVER(gather_uid); | |||||
RUNSERVER(bcast_addr); | |||||
RUNSERVER(group_barrier); | RUNSERVER(group_barrier); | ||||
mgb_assert(false, "invalid rpc request"); | mgb_assert(false, "invalid rpc request"); | ||||
} | } | ||||
@@ -49,7 +49,7 @@ private: | |||||
void opr_register(void* input_ptr, size_t input_len, std::string *output); | void opr_register(void* input_ptr, size_t input_len, std::string *output); | ||||
void set_output_shape(void* input_ptr, size_t input_len, std::string *output); | void set_output_shape(void* input_ptr, size_t input_len, std::string *output); | ||||
void get_output_shape(void* input_ptr, size_t input_len, std::string *output); | void get_output_shape(void* input_ptr, size_t input_len, std::string *output); | ||||
void gather_uid(void* input_ptr, size_t input_len, std::string *output); | |||||
void bcast_addr(void* input_ptr, size_t input_len, std::string *output); | |||||
void group_barrier(void* input_ptr, size_t input_len, std::string *output); | void group_barrier(void* input_ptr, size_t input_len, std::string *output); | ||||
private: | private: | ||||
@@ -101,15 +101,14 @@ void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len, | |||||
rsp.SerializeToString(output); | rsp.SerializeToString(output); | ||||
} | } | ||||
void GroupServerProxy::gather_uid(void* input_ptr, size_t input_len, | |||||
void GroupServerProxy::bcast_addr(void* input_ptr, size_t input_len, | |||||
std::string *output) { | std::string *output) { | ||||
INFO_INIT(mm_handler, GatherUid); | |||||
auto uid = req.uid(); | |||||
auto uids = m_mgr.gather_uid(uid, req.key(), req.size(), req.rank()); | |||||
for (size_t i = 0;i < uids.size();i++) { | |||||
rsp.add_uids(); | |||||
rsp.set_uids(i, uids[i].data(), uids[i].size()); | |||||
} | |||||
INFO_INIT(mm_handler, BcastAddr); | |||||
std::string master_ip = req.master_ip(); | |||||
int port = req.port(); | |||||
m_mgr.bcast_addr(master_ip, port, req.key(), req.size(), req.rank(), req.root()); | |||||
rsp.set_master_ip(master_ip); | |||||
rsp.set_port(port); | |||||
rsp.SerializeToString(output); | rsp.SerializeToString(output); | ||||
} | } | ||||
@@ -184,19 +183,20 @@ TensorShape GroupClientProxy::get_output_shape(const std::string& key) { | |||||
} | } | ||||
return shape; | return shape; | ||||
} | } | ||||
std::vector<std::string> GroupClientProxy::gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) { | |||||
INFO_INIT(mm_handler, gather_uid, GatherUid); | |||||
req.set_uid(uid.data(), uid.size()); | |||||
void GroupClientProxy::bcast_addr(std::string& master_ip, | |||||
int& port, const std::string& key, uint32_t size, | |||||
uint32_t rank, uint32_t root) { | |||||
INFO_INIT(mm_handler, bcast_addr, BcastAddr); | |||||
req.set_master_ip(master_ip.data(), master_ip.size()); | |||||
req.set_port(port); | |||||
req.set_key(key.data(), key.size()); | req.set_key(key.data(), key.size()); | ||||
req.set_size(size); | req.set_size(size); | ||||
req.set_rank(rank); | req.set_rank(rank); | ||||
req.set_root(root); | |||||
SOLVE_REQUEST(func_name, req, rsp); | SOLVE_REQUEST(func_name, req, rsp); | ||||
std::vector<std::string> rst; | |||||
for (size_t i = 0;i < size;i++) { | |||||
rst.push_back(rsp.uids(i)); | |||||
} | |||||
return rst; | |||||
master_ip = rsp.master_ip(); | |||||
port = rsp.port(); | |||||
} | } | ||||
uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { | ||||
@@ -82,9 +82,9 @@ class GroupManager { | |||||
RegisterInfo opr_register(const std::string& key, size_t nr_devices, | RegisterInfo opr_register(const std::string& key, size_t nr_devices, | ||||
bool is_root, int rank, uint64_t comp_node_hash); | bool is_root, int rank, uint64_t comp_node_hash); | ||||
//! gather uids from all ranks | |||||
std::vector<std::string> gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank); | |||||
//! broadcast master_ip and port | |||||
void bcast_addr(std::string& master_ip, int& port, | |||||
const std::string& key, uint32_t size, uint32_t rank, uint32_t root); | |||||
//! Set output shape of this key | //! Set output shape of this key | ||||
void set_output_shape(const std::string& key, const TensorShape& shape); | void set_output_shape(const std::string& key, const TensorShape& shape); | ||||
@@ -102,12 +102,13 @@ class GroupManager { | |||||
std::unordered_map<std::string, GroupInfo> m_key2group_info; | std::unordered_map<std::string, GroupInfo> m_key2group_info; | ||||
std::mutex m_key2group_info_mtx; | std::mutex m_key2group_info_mtx; | ||||
//! key -> uid | |||||
std::unordered_map<std::string, std::vector<std::string>> m_key2uids; | |||||
std::unordered_map<std::string, uint32_t> m_key2uids_size; | |||||
std::unordered_map<std::string, bool> m_key2uids_flag; | |||||
std::mutex m_key2uids_mtx; | |||||
std::condition_variable m_gather_uid_cv; | |||||
//! key -> addr | |||||
std::unordered_map<std::string, std::string> m_key2master_ip; | |||||
std::unordered_map<std::string, int> m_key2port; | |||||
std::unordered_map<std::string, uint32_t> m_key2addr_size; | |||||
std::unordered_map<std::string, bool> m_key2addr_flag; | |||||
std::mutex m_key2addr_mtx; | |||||
std::condition_variable m_bcast_cv; | |||||
//! barrier | //! barrier | ||||
uint32_t m_barrier_size; | uint32_t m_barrier_size; | ||||
@@ -133,8 +134,8 @@ class GroupClient { | |||||
bool is_root, int rank, | bool is_root, int rank, | ||||
uint64_t comp_node_hash) = 0; | uint64_t comp_node_hash) = 0; | ||||
virtual std::vector<std::string> gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) = 0; | |||||
virtual void bcast_addr(std::string& master_ip, int& port, | |||||
const std::string& key, uint32_t size, uint32_t rank, uint32_t root) = 0; | |||||
virtual void set_output_shape(const std::string& key, | virtual void set_output_shape(const std::string& key, | ||||
const TensorShape& shape) = 0; | const TensorShape& shape) = 0; | ||||
@@ -37,8 +37,8 @@ public: | |||||
int rank, | int rank, | ||||
uint64_t comp_node_hash) override; | uint64_t comp_node_hash) override; | ||||
std::vector<std::string> gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) override; | |||||
void bcast_addr(std::string& master_ip, int& port, const std::string& key, | |||||
uint32_t size, uint32_t rank, uint32_t root) override; | |||||
void set_output_shape(const std::string& key, | void set_output_shape(const std::string& key, | ||||
const TensorShape& shape) override; | const TensorShape& shape) override; | ||||
@@ -16,15 +16,18 @@ message OprRegisterResponse { | |||||
int32 root_rank = 3; | int32 root_rank = 3; | ||||
} | } | ||||
message GatherUidRequest { | |||||
bytes uid = 1; | |||||
string key = 2; | |||||
uint32 size = 3; | |||||
uint32 rank = 4; | |||||
} | |||||
message GatherUidResponse { | |||||
repeated bytes uids = 1; | |||||
message BcastAddrRequest { | |||||
string master_ip = 1; | |||||
int32 port = 2; | |||||
string key = 3; | |||||
uint32 size = 4; | |||||
uint32 rank = 5; | |||||
uint32 root = 6; | |||||
} | |||||
message BcastAddrResponse { | |||||
string master_ip = 1; | |||||
int32 port = 2; | |||||
} | } | ||||
message SetOutputShapeRequest { | message SetOutputShapeRequest { | ||||
@@ -29,13 +29,14 @@ class MockGroupClient final : public opr::GroupClient { | |||||
} | } | ||||
RegisterInfo opr_register(const std::string& key, size_t nr_devices, | RegisterInfo opr_register(const std::string& key, size_t nr_devices, | ||||
bool is_root, int rank, uint64_t comp_node_hash) { | |||||
bool is_root, int rank, uint64_t comp_node_hash) override { | |||||
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); | ||||
} | } | ||||
std::vector<std::string> gather_uid(const std::string& uid, | |||||
const std::string& key, uint32_t size, uint32_t rank) { | |||||
return m_mgr.gather_uid(uid, key, size, rank); | |||||
void bcast_addr(std::string& master_ip, int& port, | |||||
const std::string& key, uint32_t size, | |||||
uint32_t rank, uint32_t root) override { | |||||
return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); | |||||
} | } | ||||
void set_output_shape(const std::string& key, | void set_output_shape(const std::string& key, | ||||
@@ -1 +1 @@ | |||||
Subproject commit d06c215dc1425fa932e20ecfaab7b07c0343a5bc | |||||
Subproject commit e14e4f84c1349598ba17c49923168db47a4e9642 |