From 064c774f05626ba0a9d2f7dfbf138dfa5df82ded Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 18 Nov 2020 11:38:06 +0800 Subject: [PATCH] feat(imperative): impl hashable for SendRecv and add virtual input for Recv GitOrigin-RevId: 5e8c27ac81b844bbf8720a108da7ad208060dad2 --- imperative/src/impl/ops/io_remote.cpp | 39 +++++++++++++++++++++- .../include/megbrain/imperative/ops/io_remote.h | 14 ++++++++ src/opr-mm/impl/io_remote.cpp | 26 +++++++++++++++ src/opr-mm/include/megbrain/opr/io_remote.h | 11 ++++++ 4 files changed, 89 insertions(+), 1 deletion(-) diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp index 439dc1a3..09608292 100644 --- a/imperative/src/impl/ops/io_remote.cpp +++ b/imperative/src/impl/ops/io_remote.cpp @@ -46,7 +46,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( ssprintf("%s:%d", recv.addr.data(), recv.port)); auto&& graph = inputs[0]->owner_graph(); return graph->insert_opr(std::make_unique( - recv.key, *graph, group_client, OperatorNodeConfig{recv.cn}, + recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, recv.shape, recv.dtype)); } @@ -60,6 +60,43 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) } // anonymous namespace #endif // MGB_ENABLE_OPR_MM +bool RemoteSend::is_same_st(const Hashable& another) const{ + return as_tuple() == another.cast_final().as_tuple(); +} + +size_t RemoteSend::hash() const{ + XXHash xxhash; + auto append = [&xxhash](auto field){ + auto hash_val = HashTrait::eval(field); + xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); + }; + append(key); + append(addr); + append(port); + append(rank_to); + return xxhash.digest(); +} + +bool RemoteRecv::is_same_st(const Hashable& another) const{ + return as_tuple() == another.cast_final().as_tuple(); +} + +size_t RemoteRecv::hash() const{ + XXHash xxhash; + auto append = [&xxhash](auto field){ + auto hash_val = HashTrait::eval(field); + xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); + }; + append(key); + append(addr); + append(port); + append(rank_from); + append(cn.to_string()); + append(dtype.handle()); + append(shape.to_string()); + return xxhash.digest(); +} + MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); diff --git a/imperative/src/include/megbrain/imperative/ops/io_remote.h b/imperative/src/include/megbrain/imperative/ops/io_remote.h index 8de67fb4..9ec6e4f4 100644 --- a/imperative/src/include/megbrain/imperative/ops/io_remote.h +++ b/imperative/src/include/megbrain/imperative/ops/io_remote.h @@ -31,6 +31,13 @@ public: std::string addr; uint32_t port; uint32_t rank_to; + + size_t hash() const override; + bool is_same_st(const Hashable& another) const override; + + auto as_tuple() const{ + return std::tuple(key, addr, port, rank_to); + } }; class RemoteRecv : public OpDefImplBase { @@ -55,6 +62,13 @@ public: CompNode cn; TensorShape shape; DType dtype; + + size_t hash() const override; + bool is_same_st(const Hashable& another) const override; + + auto as_tuple() const{ + return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string()); + } }; } // namespace imperative diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index 53e5b1c3..81dd1a00 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -151,6 +151,23 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph, add_equivalence_component>(this); } +RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, + std::shared_ptr group_client, + const OperatorNodeConfig& config, + const TensorShape& shape, DType dtype) : + Super(&graph, config, "remote_recv", {}), + m_shape(shape), m_dtype(dtype) { + m_key = key; + m_group_client = group_client; + + add_input({var}); + add_output(None) + ->dtype(dtype) + .add_flag(VarNode::Flag::NO_MEM_RECLAIM) + .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC); + add_equivalence_component>(this); +} + SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, std::shared_ptr group_client, const OperatorNodeConfig& config, @@ -160,6 +177,15 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, return opr->output(0); } +SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph, + std::shared_ptr group_client, + const OperatorNodeConfig& config, + const TensorShape& shape, DType dtype) { + auto opr = graph.insert_opr(std::make_unique( + key, var.node(), graph, group_client, config, shape, dtype)); + return opr->output(0); +} + void RemoteRecv::scn_do_execute() { if (!m_init) { auto&& comp_node = output(0)->comp_node(); diff --git a/src/opr-mm/include/megbrain/opr/io_remote.h b/src/opr-mm/include/megbrain/opr/io_remote.h index b365e476..335e2648 100644 --- a/src/opr-mm/include/megbrain/opr/io_remote.h +++ b/src/opr-mm/include/megbrain/opr/io_remote.h @@ -77,12 +77,23 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // { const OperatorNodeConfig& config, const TensorShape& shape, DType dtype); + RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph, + std::shared_ptr group_client, + const OperatorNodeConfig& config, const TensorShape& shape, + DType dtype); + static SymbolVar make( const std::string& key, cg::ComputingGraph& graph, std::shared_ptr group_client, const OperatorNodeConfig& config, const TensorShape& shape, DType dtype); + static SymbolVar make( + const std::string& key, SymbolVar var, cg::ComputingGraph& graph, + std::shared_ptr group_client, + const OperatorNodeConfig& config, const TensorShape& shape, + DType dtype); + private: const TensorShape m_shape; const DType m_dtype;