GitOrigin-RevId: 5e8c27ac81
release-1.1
@@ -46,7 +46,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||||
ssprintf("%s:%d", recv.addr.data(), recv.port)); | ssprintf("%s:%d", recv.addr.data(), recv.port)); | ||||
auto&& graph = inputs[0]->owner_graph(); | auto&& graph = inputs[0]->owner_graph(); | ||||
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | ||||
recv.key, *graph, group_client, OperatorNodeConfig{recv.cn}, | |||||
recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, | |||||
recv.shape, recv.dtype)); | recv.shape, recv.dtype)); | ||||
} | } | ||||
@@ -60,6 +60,43 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | |||||
} // anonymous namespace | } // anonymous namespace | ||||
#endif // MGB_ENABLE_OPR_MM | #endif // MGB_ENABLE_OPR_MM | ||||
bool RemoteSend::is_same_st(const Hashable& another) const{ | |||||
return as_tuple() == another.cast_final<RemoteSend>().as_tuple(); | |||||
} | |||||
size_t RemoteSend::hash() const{ | |||||
XXHash xxhash; | |||||
auto append = [&xxhash](auto field){ | |||||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||||
}; | |||||
append(key); | |||||
append(addr); | |||||
append(port); | |||||
append(rank_to); | |||||
return xxhash.digest(); | |||||
} | |||||
bool RemoteRecv::is_same_st(const Hashable& another) const{ | |||||
return as_tuple() == another.cast_final<RemoteRecv>().as_tuple(); | |||||
} | |||||
size_t RemoteRecv::hash() const{ | |||||
XXHash xxhash; | |||||
auto append = [&xxhash](auto field){ | |||||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||||
}; | |||||
append(key); | |||||
append(addr); | |||||
append(port); | |||||
append(rank_from); | |||||
append(cn.to_string()); | |||||
append(dtype.handle()); | |||||
append(shape.to_string()); | |||||
return xxhash.digest(); | |||||
} | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | ||||
@@ -31,6 +31,13 @@ public: | |||||
std::string addr; | std::string addr; | ||||
uint32_t port; | uint32_t port; | ||||
uint32_t rank_to; | uint32_t rank_to; | ||||
size_t hash() const override; | |||||
bool is_same_st(const Hashable& another) const override; | |||||
auto as_tuple() const{ | |||||
return std::tuple(key, addr, port, rank_to); | |||||
} | |||||
}; | }; | ||||
class RemoteRecv : public OpDefImplBase<RemoteRecv> { | class RemoteRecv : public OpDefImplBase<RemoteRecv> { | ||||
@@ -55,6 +62,13 @@ public: | |||||
CompNode cn; | CompNode cn; | ||||
TensorShape shape; | TensorShape shape; | ||||
DType dtype; | DType dtype; | ||||
size_t hash() const override; | |||||
bool is_same_st(const Hashable& another) const override; | |||||
auto as_tuple() const{ | |||||
return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string()); | |||||
} | |||||
}; | }; | ||||
} // namespace imperative | } // namespace imperative | ||||
@@ -151,6 +151,23 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||||
add_equivalence_component<ScalarHash<void*>>(this); | add_equivalence_component<ScalarHash<void*>>(this); | ||||
} | } | ||||
RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | |||||
const OperatorNodeConfig& config, | |||||
const TensorShape& shape, DType dtype) : | |||||
Super(&graph, config, "remote_recv", {}), | |||||
m_shape(shape), m_dtype(dtype) { | |||||
m_key = key; | |||||
m_group_client = group_client; | |||||
add_input({var}); | |||||
add_output(None) | |||||
->dtype(dtype) | |||||
.add_flag(VarNode::Flag::NO_MEM_RECLAIM) | |||||
.add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC); | |||||
add_equivalence_component<ScalarHash<void*>>(this); | |||||
} | |||||
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | ||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config, | const OperatorNodeConfig& config, | ||||
@@ -160,6 +177,15 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||||
return opr->output(0); | return opr->output(0); | ||||
} | } | ||||
SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | |||||
const OperatorNodeConfig& config, | |||||
const TensorShape& shape, DType dtype) { | |||||
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>( | |||||
key, var.node(), graph, group_client, config, shape, dtype)); | |||||
return opr->output(0); | |||||
} | |||||
void RemoteRecv::scn_do_execute() { | void RemoteRecv::scn_do_execute() { | ||||
if (!m_init) { | if (!m_init) { | ||||
auto&& comp_node = output(0)->comp_node(); | auto&& comp_node = output(0)->comp_node(); | ||||
@@ -77,12 +77,23 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | |||||
const OperatorNodeConfig& config, const TensorShape& shape, | const OperatorNodeConfig& config, const TensorShape& shape, | ||||
DType dtype); | DType dtype); | ||||
RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | |||||
const OperatorNodeConfig& config, const TensorShape& shape, | |||||
DType dtype); | |||||
static SymbolVar make( | static SymbolVar make( | ||||
const std::string& key, cg::ComputingGraph& graph, | const std::string& key, cg::ComputingGraph& graph, | ||||
std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
const OperatorNodeConfig& config, const TensorShape& shape, | const OperatorNodeConfig& config, const TensorShape& shape, | ||||
DType dtype); | DType dtype); | ||||
static SymbolVar make( | |||||
const std::string& key, SymbolVar var, cg::ComputingGraph& graph, | |||||
std::shared_ptr<GroupClient> group_client, | |||||
const OperatorNodeConfig& config, const TensorShape& shape, | |||||
DType dtype); | |||||
private: | private: | ||||
const TensorShape m_shape; | const TensorShape m_shape; | ||||
const DType m_dtype; | const DType m_dtype; | ||||