GitOrigin-RevId: 5e8c27ac81
release-1.1
@@ -46,7 +46,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
ssprintf("%s:%d", recv.addr.data(), recv.port)); | |||
auto&& graph = inputs[0]->owner_graph(); | |||
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | |||
recv.key, *graph, group_client, OperatorNodeConfig{recv.cn}, | |||
recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, | |||
recv.shape, recv.dtype)); | |||
} | |||
@@ -60,6 +60,43 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | |||
} // anonymous namespace | |||
#endif // MGB_ENABLE_OPR_MM | |||
bool RemoteSend::is_same_st(const Hashable& another) const{ | |||
return as_tuple() == another.cast_final<RemoteSend>().as_tuple(); | |||
} | |||
size_t RemoteSend::hash() const{ | |||
XXHash xxhash; | |||
auto append = [&xxhash](auto field){ | |||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
}; | |||
append(key); | |||
append(addr); | |||
append(port); | |||
append(rank_to); | |||
return xxhash.digest(); | |||
} | |||
bool RemoteRecv::is_same_st(const Hashable& another) const{ | |||
return as_tuple() == another.cast_final<RemoteRecv>().as_tuple(); | |||
} | |||
size_t RemoteRecv::hash() const{ | |||
XXHash xxhash; | |||
auto append = [&xxhash](auto field){ | |||
auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
}; | |||
append(key); | |||
append(addr); | |||
append(port); | |||
append(rank_from); | |||
append(cn.to_string()); | |||
append(dtype.handle()); | |||
append(shape.to_string()); | |||
return xxhash.digest(); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | |||
@@ -31,6 +31,13 @@ public: | |||
std::string addr; | |||
uint32_t port; | |||
uint32_t rank_to; | |||
size_t hash() const override; | |||
bool is_same_st(const Hashable& another) const override; | |||
auto as_tuple() const{ | |||
return std::tuple(key, addr, port, rank_to); | |||
} | |||
}; | |||
class RemoteRecv : public OpDefImplBase<RemoteRecv> { | |||
@@ -55,6 +62,13 @@ public: | |||
CompNode cn; | |||
TensorShape shape; | |||
DType dtype; | |||
size_t hash() const override; | |||
bool is_same_st(const Hashable& another) const override; | |||
auto as_tuple() const{ | |||
return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string()); | |||
} | |||
}; | |||
} // namespace imperative | |||
@@ -151,6 +151,23 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, | |||
add_equivalence_component<ScalarHash<void*>>(this); | |||
} | |||
// Constructor overload that also registers `var` as an input of the operator.
// The single output is created with the requested dtype and flagged
// NO_MEM_RECLAIM and DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC.
RemoteRecv::RemoteRecv(const std::string& key, VarNode* var,
                       cg::ComputingGraph& graph,
                       std::shared_ptr<GroupClient> group_client,
                       const OperatorNodeConfig& config,
                       const TensorShape& shape, DType dtype)
        : Super(&graph, config, "remote_recv", {}),
          m_shape(shape),
          m_dtype(dtype) {
    m_key = key;
    m_group_client = group_client;

    add_input({var});

    auto&& out_var = add_output(None);
    out_var->dtype(dtype)
            .add_flag(VarNode::Flag::NO_MEM_RECLAIM)
            .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);

    // NOTE(review): the equivalence component is built from `this`, which
    // presumably keeps distinct instances from being merged -- confirm.
    add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, | |||
@@ -160,6 +177,15 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, | |||
return opr->output(0); | |||
} | |||
// Factory overload taking an input var: builds a RemoteRecv on top of
// var.node(), inserts it into `graph`, and returns the operator's output 0.
SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var,
                           cg::ComputingGraph& graph,
                           std::shared_ptr<GroupClient> group_client,
                           const OperatorNodeConfig& config,
                           const TensorShape& shape, DType dtype) {
    auto input_node = var.node();
    return graph
            .insert_opr(std::make_unique<RemoteRecv>(
                    key, input_node, graph, group_client, config, shape,
                    dtype))
            ->output(0);
}
void RemoteRecv::scn_do_execute() { | |||
if (!m_init) { | |||
auto&& comp_node = output(0)->comp_node(); | |||
@@ -77,12 +77,23 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { | |||
const OperatorNodeConfig& config, const TensorShape& shape, | |||
DType dtype); | |||
// Constructor overload that additionally takes `var`; the definition adds it
// as an input of the operator and creates one output with the given dtype.
RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, const TensorShape& shape, | |||
DType dtype); | |||
// Factory overload without an input var; returns a SymbolVar.
// NOTE(review): the visible tail of the definition returns output(0) of the
// created operator -- confirm against the full definition.
static SymbolVar make( | |||
const std::string& key, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, const TensorShape& shape, | |||
DType dtype); | |||
// Factory overload taking an input `var`: the definition constructs a
// RemoteRecv from var.node(), inserts it into `graph`, and returns output(0).
static SymbolVar make( | |||
const std::string& key, SymbolVar var, cg::ComputingGraph& graph, | |||
std::shared_ptr<GroupClient> group_client, | |||
const OperatorNodeConfig& config, const TensorShape& shape, | |||
DType dtype); | |||
private: | |||
const TensorShape m_shape; | |||
const DType m_dtype; | |||