From 1c2ffc3ed99f8d150f3b3c21cc15013ac54d5071 Mon Sep 17 00:00:00 2001
From: isaacxr
Date: Tue, 8 Jun 2021 15:35:45 +0800
Subject: [PATCH] alltoall node executor

---
 ge/common/types.cc                                 |   2 +
 ge/hybrid/node_executor/hccl/hccl_node_executor.cc | 124 ++++++++++++-
 ge/hybrid/node_executor/hccl/hccl_node_executor.h  |  16 ++
 inc/framework/common/types.h                       |   2 +
 tests/depends/hccl/src/hccl_stub.cc                |  11 ++
 tests/ut/ge/CMakeLists.txt                         |   1 +
 .../hccl/hccl_node_executor_unittest.cc            | 199 +++++++++++++++++++++
 third_party/fwkacllib/inc/hccl/base.h              |  24 +++
 third_party/fwkacllib/inc/hccl/hcom.h              |   5 +
 9 files changed, 383 insertions(+), 1 deletion(-)
 create mode 100644 tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc

diff --git a/ge/common/types.cc b/ge/common/types.cc
index 33b7f437..ab0b0379 100644
--- a/ge/common/types.cc
+++ b/ge/common/types.cc
@@ -391,6 +391,8 @@ REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
+REGISTER_OPTYPE_DEFINE(HCOMALLTOALLVDYNAMIC, "HcomAllToAllVDynamic");
+REGISTER_OPTYPE_DEFINE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");
 REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign");
 REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp");
 
diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
index c46d5080..4ff55ea1 100644
--- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
+++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
@@ -23,8 +23,8 @@
 #include "graph/runtime_inference_context.h"
 #include "graph/utils/type_utils.h"
 #include "graph/types.h"
-#include "hccl/hcom.h"
 #include "hybrid/executor/hybrid_execution_context.h"
+#include "hccl/hcom.h"
 
 namespace ge {
 namespace {
@@ -32,9 +32,14 @@ constexpr size_t kVarTableDims = 2;
 constexpr size_t kVarTableRowCnt = 3;
 constexpr size_t kVarTableIdxAddr = 1;
 constexpr size_t kVarTableIdxLen = 2;
+// input anchor nums according to IR
+constexpr size_t kAllToAllVInputNums = 5;
+constexpr size_t kGatherAllToAllVInputNums = 4;
+
 const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD };
 const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE };
 const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE };
+const std::set<std::string> kAllToAllTypes = {HCOMALLTOALLVDYNAMIC, HCOMGATHERALLTOALLV};
 }  // namespace
 
 namespace hybrid {
@@ -345,6 +350,121 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
   return SUCCESS;
 }
 
+Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams &params) {
+  void **input_addrs[kAllToAllVInputNums] = {&params.sendbuf, &params.sendcounts, &params.sdispls,
+                                             &params.recvcounts, &params.rdispls};
+  for (size_t i = 0; i < kAllToAllVInputNums; ++i) {
+    auto addr = context.MutableInput(i);
+    GE_CHECK_NOTNULL(addr);
+    *input_addrs[i] = addr->MutableData();
+  }
+  auto recv_tv = context.MutableOutput(0);
+  GE_CHECK_NOTNULL(recv_tv);
+  params.recvbuf = recv_tv->MutableData();
+
+  const NodeItem &node_item = context.GetNodeItem();
+  const OpDescPtr op_desc = node_item.GetOpDesc();
+  auto input_desc = node_item.MutableInputDesc(0);
+  GE_CHECK_NOTNULL(input_desc);
+  ge::DataType src_data_type = input_desc->GetDataType();
+  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type));
+  if (iter == kConstOpHcclDataType.end()) {
+    REPORT_INNER_ERROR("E19999", "%s alltoallv datatype:%s not support.", op_desc->GetName().c_str(),
+                       TypeUtils::DataTypeToSerialString(src_data_type).c_str());
+    GELOGE(PARAM_INVALID, "[Find][DataType]%s alltoallv datatype:%s not support.", op_desc->GetName().c_str(),
+           TypeUtils::DataTypeToSerialString(src_data_type).c_str());
+    return PARAM_INVALID;
+  }
+  params.sendtype = iter->second;
+  params.recvtype = iter->second;
+
+  return SUCCESS;
+}
+
+Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams &params) {
+  void **input_addrs[kGatherAllToAllVInputNums] = {&params.addrInfo, &params.addrInfoCountPerRank,
+                                                   &params.recvcounts, &params.rdispls};
+  for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) {
+    auto addr = context.MutableInput(i);
+    GE_CHECK_NOTNULL(addr);
+    *input_addrs[i] = addr->MutableData();
+  }
+  auto recv_tv = context.MutableOutput(0);
+  GE_CHECK_NOTNULL(recv_tv);
+  params.recvbuf = recv_tv->MutableData();
+  auto gathered_tv = context.MutableOutput(1);
+  GE_CHECK_NOTNULL(gathered_tv);
+  params.gatheredbuf = gathered_tv->MutableData();
+
+  const NodeItem &node_item = context.GetNodeItem();
+  const OpDescPtr op_desc = node_item.GetOpDesc();
+
+  ge::DataType data_type = ge::DT_FLOAT;
+  (void)ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, data_type);
+  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(data_type));
+  if (iter == kConstOpHcclDataType.end()) {
+    REPORT_INNER_ERROR("E19999", "%s received datatype:%s not support.", op_desc->GetName().c_str(),
+                       TypeUtils::DataTypeToSerialString(data_type).c_str());
+    GELOGE(PARAM_INVALID, "[Find][DataType]%s received datatype:%s not support.", op_desc->GetName().c_str(),
+           TypeUtils::DataTypeToSerialString(data_type).c_str());
+    return PARAM_INVALID;
+  }
+  params.recvtype = iter->second;
+
+  int64_t addr_len = 0;
+  (void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len);
+  params.addrLength = static_cast<s32>(addr_len);
+
+  return SUCCESS;
+}
+
+Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
+  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync in.", context.GetNodeName());
+
+  TaskContext *p_ctx = &context;
+  auto callback = [p_ctx, done_callback](HcclResult status) {
+    if (status != HCCL_SUCCESS) {
+      GELOGE(HCCL_E_INTERNAL, "[%s] AllToAllNodeTask execute failed.", p_ctx->GetNodeName());
+      p_ctx->SetStatus(FAILED);
+    }
+    done_callback();
+    GELOGI("[%s] AllToAllNodeTask callback successfully.", p_ctx->GetNodeName());
+  };
+
+  if (context.GetNodeItem().NodeType() == HCOMALLTOALLVDYNAMIC) {
+    auto HcomExecEnqueueAllToAllV = (HcclResult(*)(HcomAllToAllVParams, std::function<void(HcclResult status)>))dlsym(
+        context.handle_, "HcomExecEnqueueAllToAllV");
+    if (HcomExecEnqueueAllToAllV == nullptr) {
+      GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueAllToAllV] for node:%s.", context.GetNodeName());
+      return FAILED;
+    }
+    HcomAllToAllVParams params;
+    GE_CHK_STATUS_RET(BuildAllToAllVparams(context, params));
+    HcclResult hccl_ret = HcomExecEnqueueAllToAllV(params, callback);
+    if (hccl_ret != HCCL_SUCCESS) {
+      GELOGE(HCCL_E_INTERNAL, "AllToAllV task enqueue failed for node [%s].", context.GetNodeName());
+      return HCCL_E_INTERNAL;
+    }
+  } else {
+    auto HcomExecEnqueueGatherAllToAllV =
+        (HcclResult(*)(HcomGatherAllToAllVParams, std::function<void(HcclResult status)>))dlsym(
+            context.handle_, "HcomExecEnqueueGatherAllToAllV");
+    if (HcomExecEnqueueGatherAllToAllV == nullptr) {
+      GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueGatherAllToAllV] for node:%s.", context.GetNodeName());
+      return FAILED;
+    }
+    HcomGatherAllToAllVParams params;
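+    // Fill the GatherAllToAllV parameters (buffers, counts, displacements, recv type) from the task context.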
+    GE_CHK_STATUS_RET(BuildGatherAllToAllParams(context, params));
+    HcclResult hccl_ret = HcomExecEnqueueGatherAllToAllV(params, callback);
+    if (hccl_ret != HCCL_SUCCESS) {
+      GELOGE(HCCL_E_INTERNAL, "GatherAllToAllV task enqueue failed for node [%s].", context.GetNodeName());
+      return HCCL_E_INTERNAL;
+    }
+  }
+  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync success.", context.GetNodeName());
+  return SUCCESS;
+}
+
 Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; }
 
 Status HcclNodeTask::Init(TaskContext &context) {
@@ -375,6 +495,8 @@ Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
   GE_CHECK_NOTNULL(node);
   if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) {
     task = MakeShared<RdmaNodeTask>();
+  } else if (kAllToAllTypes.count(node->GetType()) > 0) {
+    task = MakeShared<AllToAllNodeTask>();
   } else {
     task = MakeShared<HcclNodeTask>();
   }
diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/ge/hybrid/node_executor/hccl/hccl_node_executor.h
index 873f259f..d42ae884 100644
--- a/ge/hybrid/node_executor/hccl/hccl_node_executor.h
+++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.h
@@ -62,6 +62,22 @@ class RdmaNodeTask : public NodeTask {
   bool skip_flag_;
 };
 
+
+class AllToAllNodeTask : public NodeTask {
+ public:
+  AllToAllNodeTask() = default;
+
+  ~AllToAllNodeTask() = default;
+
+  Status UpdateArgs(TaskContext &context) override { return SUCCESS; }
+  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+  Status Init(TaskContext &context) override { return SUCCESS; }
+
+ private:
+  std::mutex hccl_mutex_;
+  std::condition_variable cond_;
+};
+
 class HcclNodeExecutor : public NodeExecutor {
  public:
   Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const;
diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h
index 91759b8f..eea1c2fd 100644
--- a/inc/framework/common/types.h
+++ b/inc/framework/common/types.h
@@ -440,6 +440,8 @@ REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
+REGISTER_OPTYPE_DECLARE(HCOMALLTOALLVDYNAMIC, "HcomAllToAllVDynamic");
+REGISTER_OPTYPE_DECLARE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");
 REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign");
 REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp");
 
diff --git a/tests/depends/hccl/src/hccl_stub.cc b/tests/depends/hccl/src/hccl_stub.cc
index b9b9d4f6..5f5e513c 100644
--- a/tests/depends/hccl/src/hccl_stub.cc
+++ b/tests/depends/hccl/src/hccl_stub.cc
@@ -42,3 +42,14 @@ HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_pt
     HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) {
     return HCCL_SUCCESS;
 }
+
+HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
+                                          std::function<void(HcclResult status)> callback) {
+  return HCCL_SUCCESS;
+}
+
+
diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt
index ec0b146c..e3f93bf5 100755
--- a/tests/ut/ge/CMakeLists.txt
+++ b/tests/ut/ge/CMakeLists.txt
@@ -840,6 +840,7 @@ set(HYBRID_TEST_FILES
     "hybrid/model/hybrid_model_builder_unittest.cc"
     "hybrid/node_executor/rts/rts_node_task_unittest.cc"
     "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc"
+    "hybrid/node_executor/hccl/hccl_node_executor_unittest.cc"
     "hybrid/executor/hybrid_model_async_executor_unittest.cc"
     "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc"
     "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc"
diff --git a/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
new file mode 100644
index 00000000..c50a9f98
--- /dev/null
+++ b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+#include <dlfcn.h>
+
+#define private public
+#define protected public
+#include "hybrid/executor/subgraph_context.h"
+#include "hybrid/node_executor/hccl/hccl_node_executor.h"
+#include "model/ge_root_model.h"
+
+using namespace std;
+using namespace testing;
+
+namespace {
+const string kHcclSoPath = "../build/tests/depends/hccl/libhccl_stub.so";
+}
+namespace ge {
+using namespace hybrid;
+
+class UtestHcclNodeExecutor : public testing::Test {
+ protected:
+  void SetUp() {}
+  void TearDown() {}
+};
+
+static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
+  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
+  op_desc->SetStreamId(0);
+  static int32_t index = 0;
+  op_desc->SetId(index++);
+
+  GeTensorDesc tensor(GeShape(), FORMAT_NHWC, DT_INT64);
+  TensorUtils::SetSize(tensor, 64);
+  vector<int64_t> input_offset;
+  for (int i = 0; i < in_num; i++) {
+    op_desc->AddInputDesc(tensor);
+    input_offset.emplace_back(i * 64);
+  }
+  op_desc->SetInputOffset(input_offset);
+
+  vector<int64_t> output_offset;
+  for (int i = 0; i < out_num; i++) {
+    op_desc->AddOutputDesc(tensor);
+    output_offset.emplace_back(in_num * 64 + i * 64);
+  }
+  op_desc->SetOutputOffset(output_offset);
+
+  op_desc->SetWorkspace({});
+  op_desc->SetWorkspaceBytes({});
+
+  return graph.AddNode(op_desc);
+}
+
+TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) {
+ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+GeModelPtr ge_sub_model = std::make_shared<GeModel>();
+GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
+ge_root_model->SetModelName("test_name");
+ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
+HybridModel hybrid_model(ge_root_model);
+
+NodePtr node = CreateNode(*graph, "gatheralltoallv", HCOMGATHERALLTOALLV, 4, 2);
+
+std::unique_ptr<NodeItem> new_node;
+ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
+NodeItem *node_item = new_node.get();
+hybrid_model.node_items_[node] = std::move(new_node);
+node_item->input_start = 0;
+node_item->output_start = 0;
+
+GraphItem graph_item;
+graph_item.node_items_.emplace_back(node_item);
+graph_item.total_inputs_ = 4;
+graph_item.total_outputs_ = 2;
+
+GraphExecutionContext graph_context;
+SubgraphContext subgraph_context(&graph_item, &graph_context);
+ASSERT_EQ(subgraph_context.Init(), SUCCESS);
+graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
+
+auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
+ASSERT_NE(node_state, nullptr);
+
+auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
+ASSERT_NE(unique_task_context, nullptr);
+auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
+node_state->SetTaskContext(shared_task_context);
+
+for (int i = 0; i < 4; ++i) {
+uint64_t value_0 = 512;
+TensorValue in_tensor0(&value_0, sizeof(value_0));
+subgraph_context.SetInput(*node_item, i, in_tensor0);
+}
+
+uint64_t value_0 = 512;
+TensorValue out_tensor0(&value_0, sizeof(value_0));
+subgraph_context.SetOutput(*node_item, 0, out_tensor0);
+
+uint64_t value_1 = 512;
+TensorValue out_tensor1(&value_1, sizeof(value_1));
+subgraph_context.SetOutput(*node_item, 1, out_tensor1);
+
+NodeTaskPtr task = nullptr;
+HcclNodeExecutor node_executor;
+ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
+ASSERT_NE(task, nullptr);
+
+auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
+ASSERT_NE(handle, nullptr);
+node_state->GetTaskContext()->handle_ = handle;
+std::function<void()> done = []() {};
+ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
+
+if (handle != nullptr) {
+dlclose(handle);
+}
+}
+
+TEST_F(UtestHcclNodeExecutor, alltoallv_execute) {
+ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+GeModelPtr ge_sub_model = std::make_shared<GeModel>();
+GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
+ge_root_model->SetModelName("test_name");
+ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
+HybridModel hybrid_model(ge_root_model);
+
+NodePtr node = CreateNode(*graph, "alltoallv", HCOMALLTOALLVDYNAMIC, 5, 1);
+
+std::unique_ptr<NodeItem> new_node;
+ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
+NodeItem *node_item = new_node.get();
+hybrid_model.node_items_[node] = std::move(new_node);
+node_item->input_start = 0;
+node_item->output_start = 0;
+
+GraphItem graph_item;
+graph_item.node_items_.emplace_back(node_item);
+graph_item.total_inputs_ = 5;
+graph_item.total_outputs_ = 1;
+
+GraphExecutionContext graph_context;
+SubgraphContext subgraph_context(&graph_item, &graph_context);
+ASSERT_EQ(subgraph_context.Init(), SUCCESS);
+graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
+
+auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
+ASSERT_NE(node_state, nullptr);
+
+auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
+ASSERT_NE(unique_task_context, nullptr);
+auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
+node_state->SetTaskContext(shared_task_context);
+
+for (int i = 0; i < 5; ++i) {
+uint64_t value_0 = 512;
+TensorValue in_tensor0(&value_0, sizeof(value_0));
+subgraph_context.SetInput(*node_item, i, in_tensor0);
+}
+
+uint64_t value_1 = 512;
+TensorValue out_tensor0(&value_1, sizeof(value_1));
+subgraph_context.SetOutput(*node_item, 0, out_tensor0);
+
+NodeTaskPtr task = nullptr;
+HcclNodeExecutor node_executor;
+ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
+ASSERT_NE(task, nullptr);
+
+auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
+ASSERT_NE(handle, nullptr);
+node_state->GetTaskContext()->handle_ = handle;
+
+std::function<void()> done = []() {};
+ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
+
+if (handle != nullptr) {
+dlclose(handle);
+}
+}
+}  // namespace ge
+
diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h
index 9facd20c..e57563b3 100644
--- a/third_party/fwkacllib/inc/hccl/base.h
+++ b/third_party/fwkacllib/inc/hccl/base.h
@@ -123,6 +123,30 @@ struct HcomRemoteAccessAddrInfo {
   u64 length;  // Memory Length in Bytes
 };
 
+struct HcomAllToAllVParams {
+  void *sendbuf;
+  void *sendcounts;
+  void *sdispls;
+  HcclDataType sendtype;
+  void *recvbuf;
+  void *recvcounts;
+  void *rdispls;
+  HcclDataType recvtype;
+  const char *group;
+};
+
+struct HcomGatherAllToAllVParams {
+  void *addrInfo;
+  void *addrInfoCountPerRank;
+  void *recvbuf;
+  void *recvcounts;
+  void *rdispls;
+  void *gatheredbuf;
+  s32 addrLength;
+  HcclDataType recvtype;
+  const char *group;
+};
+
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/third_party/fwkacllib/inc/hccl/hcom.h b/third_party/fwkacllib/inc/hccl/hcom.h
index 972f470c..955764d6 100644
--- a/third_party/fwkacllib/inc/hccl/hcom.h
+++ b/third_party/fwkacllib/inc/hccl/hcom.h
@@ -164,6 +164,11 @@ HcclResult HcomExecEnqueueRemoteAccess(const std::string& remoteAccessType,
                                        const std::vector<HcomRemoteAccessAddrInfo>& addrInfos,
                                        std::function<void(HcclResult status)> callback);
 
+HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback);
+
+HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
+                                          std::function<void(HcclResult status)> callback);
+
 /**
  * @brief Register memories and init resources for remote access.
  *