From bd3b9cb6d45a5bdbb4e8dabcb03dedc29dd33479 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 9 Sep 2020 14:39:15 +0800 Subject: [PATCH] fix(mge/oprmm): fix grad for collective comm GitOrigin-RevId: 8e28f46c905002857431bf725ef414afc4a48704 --- imperative/src/impl/ops/collective_comm.cpp | 18 +++++++++ .../include/megbrain/imperative/ops/tensor_manip.h | 47 ++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 7d23b764..4c23be99 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -49,8 +49,26 @@ cg::OperatorNodeBase* apply_on_var_node( dev_buffer_arr, config, disable)); } +std::tuple split_address(const std::string& address_and_port){ + auto index = address_and_port.find_last_of(':'); + mgb_assert(index != std::string::npos, "missing ':' in server address"); + return {address_and_port.substr(0, index), address_and_port.substr(index+1)}; +} + +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node) { + auto&& comm = node->cast_final_safe(); + auto&& group_client = comm.group_client(); + auto [addr, port] = split_address(group_client->get_addr()); + auto comp_node = node->config().get_single_comp_node().to_string_logical(); + return std::make_shared( + comm.key(), comm.nr_devices(), comm.rank(), comm.is_root(), + comm.local_grad(), addr, std::stoi(port), comm.param().mode, + comm.dtype(), comm.backend(), comp_node); +} + OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) .apply_on_var_node(apply_on_var_node) + .make_from_op_node(make_from_op_node) .fallback(); } // anonymous namespace diff --git a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h b/imperative/src/include/megbrain/imperative/ops/tensor_manip.h index d00918a0..23a5bc31 100644 --- a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h +++ b/imperative/src/include/megbrain/imperative/ops/tensor_manip.h @@ -13,6 +13,8 @@ #include "megbrain/imperative/op_def.h" +#include "megbrain/utils/hash.h" + namespace mgb::imperative { class GetVarShape : public OpDefImplBase { @@ -41,6 +43,33 @@ public: std::vector offsets; std::vector> shapes; + + size_t hash() const override { + XXHash builder; + for (auto&& offset : offsets) { + builder.update(&offset, sizeof(offset)); + } + auto&& offset_cnt = offsets.size(); + builder.update(&offset_cnt, sizeof(offset_cnt)); + for (auto&& shape : shapes) { + for (auto&& dim_len : shape) { + builder.update(&dim_len, sizeof(dim_len)); + } + auto&& dim_cnt = shape.size(); + builder.update(&dim_cnt, sizeof(dim_cnt)); + } + auto&& shape_cnt = shapes.size(); + builder.update(&shape_cnt, sizeof(shape_cnt)); + return builder.digest(); + } + + bool is_same_st(const Hashable& rhs) const override { + auto* pps = rhs.try_cast_final(); + if(pps == nullptr){ + return false; + } + return offsets == pps->offsets && shapes == pps->shapes; + } }; class ParamPackConcat : public OpDefImplBase { @@ -53,6 +82,24 @@ public: : offsets(offsets_) {} std::vector offsets; + + size_t hash() const override { + XXHash builder; + for (auto&& offset : offsets) { + builder.update(&offset, sizeof(offset)); + } + auto&& offset_cnt = offsets.size(); + builder.update(&offset_cnt, sizeof(offset_cnt)); + return builder.digest(); + } + + bool is_same_st(const Hashable& rhs) const override { + auto* ppc = rhs.try_cast_final(); + if(ppc == nullptr){ + return false; + } + return offsets == ppc->offsets; + } }; } // namespace mgb::imperative