@@ -391,6 +391,8 @@ REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite");
REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
REGISTER_OPTYPE_DEFINE(HCOMALLTOALLVDYNAMIC, "HcomAllToAllVDynamic");
REGISTER_OPTYPE_DEFINE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");
REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign");
REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp");
@@ -23,8 +23,8 @@
#include "graph/runtime_inference_context.h"
#include "graph/utils/type_utils.h"
#include "graph/types.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hccl/hcom.h"

namespace ge {
namespace {
@@ -32,9 +32,14 @@ constexpr size_t kVarTableDims = 2;
constexpr size_t kVarTableRowCnt = 3;
constexpr size_t kVarTableIdxAddr = 1;
constexpr size_t kVarTableIdxLen = 2;
// Input anchor counts as defined in the op IR.
constexpr size_t kAllToAllVInputNums = 5;
constexpr size_t kGatherAllToAllVInputNums = 4;
const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD };
const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE };
const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE };
const std::set<std::string> kAllToAllTypes = { HCOMALLTOALLVDYNAMIC, HCOMGATHERALLTOALLV };
}  // namespace

namespace hybrid {
@@ -345,6 +350,132 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
  return SUCCESS;
}
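// Fills HcomAllToAllVParams from the task context: the five inputs defined in the IR
// (send buffer, send counts, send displacements, recv counts, recv displacements) plus
// the single recv output buffer. Send and recv share the datatype of input 0.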
Status BuildAllToAllVParams(TaskContext &context, HcomAllToAllVParams &params, std::string &group) {
  void **input_addrs[kAllToAllVInputNums] = {&params.sendbuf, &params.sendcounts, &params.sdispls,
                                             &params.recvcounts, &params.rdispls};
  for (size_t i = 0; i < kAllToAllVInputNums; ++i) {
    auto addr = context.MutableInput(i);
    GE_CHECK_NOTNULL(addr);
    *input_addrs[i] = addr->MutableData();
  }
  auto recv_tv = context.MutableOutput(0);
  GE_CHECK_NOTNULL(recv_tv);
  params.recvbuf = recv_tv->MutableData();

  const NodeItem &node_item = context.GetNodeItem();
  const OpDescPtr op_desc = node_item.GetOpDesc();
  auto input_desc = node_item.MutableInputDesc(0);
  GE_CHECK_NOTNULL(input_desc);
  ge::DataType src_data_type = input_desc->GetDataType();
  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type));
  if (iter == kConstOpHcclDataType.end()) {
    REPORT_INNER_ERROR("E19999", "%s alltoallv datatype:%s is not supported.", op_desc->GetName().c_str(),
                       TypeUtils::DataTypeToSerialString(src_data_type).c_str());
    GELOGE(PARAM_INVALID, "[Find][DataType]%s alltoallv datatype:%s is not supported.", op_desc->GetName().c_str(),
           TypeUtils::DataTypeToSerialString(src_data_type).c_str());
    return PARAM_INVALID;
  }
  params.sendtype = iter->second;
  params.recvtype = iter->second;

  // `group` is owned by the caller so that params.group stays valid until the task is
  // enqueued; a function-local string would leave params.group dangling on return.
  (void)ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_GROUP, group);
  params.group = group.c_str();
  return SUCCESS;
}
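// Fills HcomGatherAllToAllVParams from the task context: the four inputs defined in the
// IR (address info, address-info count per rank, recv counts, recv displacements) plus
// the recv and gathered output buffers. Here the recv datatype comes from a node
// attribute rather than from an input tensor.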
Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams &params, std::string &group) {
  void **input_addrs[kGatherAllToAllVInputNums] = {&params.addrInfo, &params.addrInfoCountPerRank,
                                                   &params.recvcounts, &params.rdispls};
  for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) {
    auto addr = context.MutableInput(i);
    GE_CHECK_NOTNULL(addr);
    *input_addrs[i] = addr->MutableData();
  }
  auto recv_tv = context.MutableOutput(0);
  GE_CHECK_NOTNULL(recv_tv);
  params.recvbuf = recv_tv->MutableData();
  auto gathered_tv = context.MutableOutput(1);
  GE_CHECK_NOTNULL(gathered_tv);
  params.gatheredbuf = gathered_tv->MutableData();

  const NodeItem &node_item = context.GetNodeItem();
  const OpDescPtr op_desc = node_item.GetOpDesc();
  ge::DataType data_type = ge::DT_FLOAT;
  (void)ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, data_type);
  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(data_type));
  if (iter == kConstOpHcclDataType.end()) {
    REPORT_INNER_ERROR("E19999", "%s receive datatype:%s is not supported.", op_desc->GetName().c_str(),
                       TypeUtils::DataTypeToSerialString(data_type).c_str());
    GELOGE(PARAM_INVALID, "[Find][DataType]%s receive datatype:%s is not supported.", op_desc->GetName().c_str(),
           TypeUtils::DataTypeToSerialString(data_type).c_str());
    return PARAM_INVALID;
  }
  params.recvtype = iter->second;

  int64_t addr_len = 0;  // Defaults to 0 if the attribute is absent.
  (void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len);
  params.addrLength = static_cast<int>(addr_len);

  // As above, the caller owns `group` to keep params.group valid.
  (void)ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_GROUP, group);
  params.group = group.c_str();
  return SUCCESS;
}
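// Resolves the HCCL enqueue entry point from the shared-library handle at run time and
// submits the AllToAllV (or GatherAllToAllV) task; completion is signalled through the
// HCCL callback, which forwards to done_callback.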
Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync in.", context.GetNodeName());
  TaskContext *p_ctx = &context;
  auto callback = [p_ctx, done_callback](HcclResult status) {
    if (status != HCCL_SUCCESS) {
      GELOGE(HCCL_E_INTERNAL, "[%s] AllToAllNodeTask execute failed.", p_ctx->GetNodeName());
      p_ctx->SetStatus(FAILED);
    }
    done_callback();
    GELOGI("[%s] AllToAllNodeTask callback successfully.", p_ctx->GetNodeName());
  };

  if (context.GetNodeItem().NodeType() == HCOMALLTOALLVDYNAMIC) {
    auto HcomExecEnqueueAllToAllV = (HcclResult(*)(HcomAllToAllVParams, std::function<void(HcclResult status)>))dlsym(
        context.handle_, "HcomExecEnqueueAllToAllV");
    if (HcomExecEnqueueAllToAllV == nullptr) {
      GELOGE(FAILED, "Failed to resolve symbol [HcomExecEnqueueAllToAllV] for node:%s.", context.GetNodeName());
      if (dlclose(context.handle_) != 0) {
        GELOGW("Failed to close handle %s.", dlerror());
      }
      return FAILED;
    }
    HcomAllToAllVParams params;
    std::string group;
    GE_CHK_STATUS_RET(BuildAllToAllVParams(context, params, group));
    HcclResult hccl_ret = HcomExecEnqueueAllToAllV(params, callback);
    if (hccl_ret != HCCL_SUCCESS) {
      GELOGE(HCCL_E_INTERNAL, "AllToAllV task enqueue failed for node [%s].", context.GetNodeName());
      return HCCL_E_INTERNAL;
    }
  } else {
    auto HcomExecEnqueueGatherAllToAllV =
        (HcclResult(*)(HcomGatherAllToAllVParams, std::function<void(HcclResult status)>))dlsym(
            context.handle_, "HcomExecEnqueueGatherAllToAllV");
    if (HcomExecEnqueueGatherAllToAllV == nullptr) {
      GELOGE(FAILED, "Failed to resolve symbol [HcomExecEnqueueGatherAllToAllV] for node:%s.", context.GetNodeName());
      if (dlclose(context.handle_) != 0) {
        GELOGW("Failed to close handle %s.", dlerror());
      }
      return FAILED;
    }
    HcomGatherAllToAllVParams params;
    std::string group;
    GE_CHK_STATUS_RET(BuildGatherAllToAllParams(context, params, group));
    HcclResult hccl_ret = HcomExecEnqueueGatherAllToAllV(params, callback);
    if (hccl_ret != HCCL_SUCCESS) {
      GELOGE(HCCL_E_INTERNAL, "GatherAllToAllV task enqueue failed for node [%s].", context.GetNodeName());
      return HCCL_E_INTERNAL;
    }
  }
  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync success.", context.GetNodeName());
  return SUCCESS;
}
Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; }

Status HcclNodeTask::Init(TaskContext &context) {
@@ -375,6 +506,8 @@ Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
  GE_CHECK_NOTNULL(node);
  if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) {
    task = MakeShared<RdmaNodeTask>();
  } else if (kAllToAllTypes.count(node->GetType()) > 0) {
    task = MakeShared<AllToAllNodeTask>();
  } else {
    task = MakeShared<HcclNodeTask>();
  }
@@ -62,6 +62,22 @@ class RdmaNodeTask : public NodeTask {
  bool skip_flag_;
};
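// Task for HcomAllToAllVDynamic / HcomGatherAllToAllV nodes. The HCCL enqueue entry
// points are resolved from the shared-library handle at execution time, so no per-node
// initialization or argument update is needed here.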
class AllToAllNodeTask : public NodeTask {
 public:
  AllToAllNodeTask() = default;
  ~AllToAllNodeTask() = default;

  Status UpdateArgs(TaskContext &context) override { return SUCCESS; }
  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
  Status Init(TaskContext &context) override { return SUCCESS; }

 private:
  std::mutex hccl_mutex_;
  std::condition_variable cond_;
};
class HcclNodeExecutor : public NodeExecutor {
 public:
  Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const;
@@ -440,6 +440,8 @@ REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite");
REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
REGISTER_OPTYPE_DECLARE(HCOMALLTOALLVDYNAMIC, "HcomAllToAllVDynamic");
REGISTER_OPTYPE_DECLARE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");
REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign");
REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp");
@@ -42,3 +42,14 @@ HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_pt
                               HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) {
  return HCCL_SUCCESS;
}
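// Test stubs for the new enqueue entry points: they report success immediately without
// performing any real communication, which is all the executor unit tests need.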
HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback) {
  return HCCL_SUCCESS;
}

HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
                                          std::function<void(HcclResult status)> callback) {
  return HCCL_SUCCESS;
}
@@ -840,6 +840,7 @@ set(HYBRID_TEST_FILES
    "hybrid/model/hybrid_model_builder_unittest.cc"
    "hybrid/node_executor/rts/rts_node_task_unittest.cc"
    "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc"
    "hybrid/node_executor/hccl/hccl_node_executor_unittest.cc"
    "hybrid/executor/hybrid_model_async_executor_unittest.cc"
    "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc"
    "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc"
@@ -0,0 +1,199 @@
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <dlfcn.h>  // dlopen/dlclose are used to load the HCCL stub below
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>

#define private public
#define protected public
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/hccl/hccl_node_executor.h"
#include "model/ge_root_model.h"

using namespace std;
using namespace testing;

namespace {
const string kHcclSoPath = "../build/tests/depends/hccl/libhccl_stub.so";
}

namespace ge {
using namespace hybrid;

class UtestHcclNodeExecutor : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};
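// Builds a node with `in_num` int64 inputs and `out_num` int64 outputs, with offsets
// laid out back to back, so it can be wired into a hybrid GraphItem.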
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  op_desc->SetStreamId(0);
  static int32_t index = 0;
  op_desc->SetId(index++);

  GeTensorDesc tensor(GeShape(), FORMAT_NHWC, DT_INT64);
  TensorUtils::SetSize(tensor, 64);
  vector<int64_t> input_offset;
  for (int i = 0; i < in_num; i++) {
    op_desc->AddInputDesc(tensor);
    input_offset.emplace_back(i * 64);
  }
  op_desc->SetInputOffset(input_offset);

  vector<int64_t> output_offset;
  for (int i = 0; i < out_num; i++) {
    op_desc->AddOutputDesc(tensor);
    output_offset.emplace_back(in_num * 64 + i * 64);
  }
  op_desc->SetOutputOffset(output_offset);

  op_desc->SetWorkspace({});
  op_desc->SetWorkspaceBytes({});
  return graph.AddNode(op_desc);
}
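// Drives a HcomGatherAllToAllV node end to end: LoadTask should pick AllToAllNodeTask,
// and ExecuteAsync should enqueue through the stubbed HCCL library loaded via dlopen.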
TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  GeModelPtr ge_sub_model = std::make_shared<GeModel>();
  GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
  ge_root_model->SetModelName("test_name");
  ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
  HybridModel hybrid_model(ge_root_model);

  NodePtr node = CreateNode(*graph, "gatheralltoallv", HCOMGATHERALLTOALLV, 4, 2);
  std::unique_ptr<NodeItem> new_node;
  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
  NodeItem *node_item = new_node.get();
  hybrid_model.node_items_[node] = std::move(new_node);
  node_item->input_start = 0;
  node_item->output_start = 0;

  GraphItem graph_item;
  graph_item.node_items_.emplace_back(node_item);
  graph_item.total_inputs_ = 4;
  graph_item.total_outputs_ = 2;

  GraphExecutionContext graph_context;
  SubgraphContext subgraph_context(&graph_item, &graph_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
  ASSERT_NE(node_state, nullptr);
  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
  ASSERT_NE(unique_task_context, nullptr);
  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
  node_state->SetTaskContext(shared_task_context);

  // Keep the backing value alive for the whole test and bind each input anchor to it.
  uint64_t in_value = 512;
  for (int i = 0; i < 4; ++i) {
    TensorValue in_tensor(&in_value, sizeof(in_value));
    subgraph_context.SetInput(*node_item, i, in_tensor);
  }
  uint64_t value_0 = 512;
  TensorValue out_tensor0(&value_0, sizeof(value_0));
  subgraph_context.SetOutput(*node_item, 0, out_tensor0);
  uint64_t value_1 = 512;
  TensorValue out_tensor1(&value_1, sizeof(value_1));
  subgraph_context.SetOutput(*node_item, 1, out_tensor1);

  NodeTaskPtr task = nullptr;
  HcclNodeExecutor node_executor;
  ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
  ASSERT_NE(task, nullptr);

  auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
  ASSERT_NE(handle, nullptr);
  node_state->GetTaskContext()->handle_ = handle;

  std::function<void()> done = []() {};
  ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
  if (handle != nullptr) {
    dlclose(handle);
  }
}
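// Same flow for HcomAllToAllVDynamic: five inputs, one output, exercising the dynamic
// AllToAllV enqueue path inside AllToAllNodeTask::ExecuteAsync.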
TEST_F(UtestHcclNodeExecutor, alltoallv_execute) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  GeModelPtr ge_sub_model = std::make_shared<GeModel>();
  GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
  ge_root_model->SetModelName("test_name");
  ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
  HybridModel hybrid_model(ge_root_model);

  NodePtr node = CreateNode(*graph, "alltoallv", HCOMALLTOALLVDYNAMIC, 5, 1);
  std::unique_ptr<NodeItem> new_node;
  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
  NodeItem *node_item = new_node.get();
  hybrid_model.node_items_[node] = std::move(new_node);
  node_item->input_start = 0;
  node_item->output_start = 0;

  GraphItem graph_item;
  graph_item.node_items_.emplace_back(node_item);
  graph_item.total_inputs_ = 5;
  graph_item.total_outputs_ = 1;

  GraphExecutionContext graph_context;
  SubgraphContext subgraph_context(&graph_item, &graph_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
  ASSERT_NE(node_state, nullptr);
  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
  ASSERT_NE(unique_task_context, nullptr);
  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
  node_state->SetTaskContext(shared_task_context);

  // Bind each of the five input anchors to a live backing value.
  uint64_t in_value = 512;
  for (int i = 0; i < 5; ++i) {
    TensorValue in_tensor(&in_value, sizeof(in_value));
    subgraph_context.SetInput(*node_item, i, in_tensor);
  }
  uint64_t value_1 = 512;
  TensorValue out_tensor0(&value_1, sizeof(value_1));
  subgraph_context.SetOutput(*node_item, 0, out_tensor0);

  NodeTaskPtr task = nullptr;
  HcclNodeExecutor node_executor;
  ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
  ASSERT_NE(task, nullptr);

  auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
  ASSERT_NE(handle, nullptr);
  node_state->GetTaskContext()->handle_ = handle;

  std::function<void()> done = []() {};
  ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
  if (handle != nullptr) {
    dlclose(handle);
  }
}
}  // namespace ge
@@ -123,6 +123,30 @@ struct HcomRemoteAccessAddrInfo {
  u64 length;  // Memory Length in Bytes
};
struct HcomAllToAllVParams {
  void *sendbuf;
  void *sendcounts;
  void *sdispls;
  HcclDataType sendtype;
  void *recvbuf;
  void *recvcounts;
  void *rdispls;
  HcclDataType recvtype;
  const char *group;
};
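// Parameter block for the GatherAllToAllV enqueue. addrInfo/addrInfoCountPerRank describe
// the per-rank records to gather; in the executor, recvbuf and gatheredbuf are bound to
// outputs 0 and 1 of the op respectively.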
struct HcomGatherAllToAllVParams {
  void *addrInfo;
  void *addrInfoCountPerRank;
  void *recvbuf;
  void *recvcounts;
  void *rdispls;
  void *gatheredbuf;
  s32 addrLength;
  HcclDataType recvtype;
  const char *group;
};

#ifdef __cplusplus
}
#endif  // __cplusplus
@@ -164,6 +164,11 @@ HcclResult HcomExecEnqueueRemoteAccess(const std::string& remoteAccessType,
                                       const std::vector<HcomRemoteAccessAddrInfo>& addrInfos,
                                       std::function<void(HcclResult status)> callback);

HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback);

HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
                                          std::function<void(HcclResult status)> callback);

/**
 * @brief Register memories and init resources for remote access.
 *