@@ -20,7 +20,6 @@ | |||
#include "graph/attr_value.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/manager/util/hcom_util.h" | |||
#include "graph/runtime_inference_context.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "graph/types.h" | |||
#include "hccl/hcom.h" | |||
@@ -177,61 +176,8 @@ Status RdmaNodeTask::Init(TaskContext &context) { | |||
return SUCCESS; | |||
} | |||
Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||
RuntimeInferenceContext *ctx = nullptr; | |||
GE_CHK_STATUS_RET( | |||
RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||
ge::Tensor remote_tensor; | |||
GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | |||
auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData()); | |||
if (data == nullptr) { | |||
if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||
GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName()); | |||
return SUCCESS; | |||
} else { | |||
REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||
context.GetNodeItem().NodeType().c_str()); | |||
GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||
context.GetNodeItem().NodeType().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); | |||
if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { | |||
REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" | |||
"and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||
GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," | |||
"number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||
return PARAM_INVALID; | |||
} | |||
if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) { | |||
size_t remote_size = 0; | |||
for (auto idx = 0; idx < dims.front(); ++idx) { | |||
FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | |||
auto line_idx = idx * kVarTableRowCnt; | |||
remote_size += data[line_idx + kVarTableIdxLen]; | |||
} | |||
auto allocator = NpuMemoryAllocator::GetAllocator(); | |||
GE_CHECK_NOTNULL(allocator); | |||
AllocationAttr attr; | |||
attr.SetMemType(RDMA_HBM); | |||
for (auto i = 0; i < context.NumOutputs(); ++i) { | |||
GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size); | |||
auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr); | |||
GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())))); | |||
} | |||
} else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) { | |||
AllocationAttr attr; | |||
attr.SetMemType(RDMA_HBM); | |||
GE_CHK_STATUS_RET(context.AllocateOutputs(&attr)) | |||
} | |||
Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num, | |||
vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||
TensorValue *tv; | |||
if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||
tv = context.MutableOutput(local_index_); | |||
@@ -239,7 +185,6 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||
tv = context.MutableInput(local_index_); | |||
} | |||
GE_CHECK_NOTNULL(tv); | |||
auto row_num = dims.front(); | |||
addr_infos.resize(row_num); | |||
if (skip_flag_) { | |||
int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset"); | |||
@@ -294,6 +239,65 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||
return SUCCESS; | |||
} | |||
Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||
RuntimeInferenceContext *ctx = nullptr; | |||
GE_CHK_STATUS_RET( | |||
RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||
ge::Tensor remote_tensor; | |||
GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | |||
auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData()); | |||
if (data == nullptr) { | |||
if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||
GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName()); | |||
return SUCCESS; | |||
} else { | |||
REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||
context.GetNodeItem().NodeType().c_str()); | |||
GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||
context.GetNodeItem().NodeType().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); | |||
if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { | |||
REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" | |||
"and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||
GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," | |||
"number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||
return PARAM_INVALID; | |||
} | |||
if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) { | |||
size_t remote_size = 0; | |||
for (auto idx = 0; idx < dims.front(); ++idx) { | |||
FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | |||
auto line_idx = idx * kVarTableRowCnt; | |||
remote_size += data[line_idx + kVarTableIdxLen]; | |||
} | |||
auto allocator = NpuMemoryAllocator::GetAllocator(); | |||
GE_CHECK_NOTNULL(allocator); | |||
AllocationAttr attr; | |||
attr.SetMemType(RDMA_HBM); | |||
for (auto i = 0; i < context.NumOutputs(); ++i) { | |||
GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size); | |||
auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr); | |||
GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())))); | |||
} | |||
} else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) { | |||
AllocationAttr attr; | |||
attr.SetMemType(RDMA_HBM); | |||
GE_CHK_STATUS_RET(context.AllocateOutputs(&attr)) | |||
} | |||
auto row_num = dims.front(); | |||
return SetAddrInfo(context, ctx, data, row_num, addr_infos); | |||
} | |||
Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | |||
GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName()); | |||
auto HcomExecEnqueueRemoteAccess = | |||
@@ -18,6 +18,7 @@ | |||
#define HYBRID_HCCL_NODE_EXECUTOR_H_ | |||
#include "common/opskernel/ge_task_info.h" | |||
#include "graph/op_desc.h" | |||
#include "graph/runtime_inference_context.h" | |||
#include "hybrid/model/hybrid_model.h" | |||
#include "hybrid/node_executor/node_executor.h" | |||
@@ -53,6 +54,8 @@ class RdmaNodeTask : public NodeTask { | |||
Status Init(TaskContext &context) override; | |||
private: | |||
Status SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num, | |||
vector<HcomRemoteAccessAddrInfo> &addr_infos); | |||
Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos); | |||
std::pair<int64_t, int64_t> remote_index_; | |||
std::pair<int64_t, int64_t> offset_index_; | |||
@@ -710,6 +710,7 @@ set(PASS_TEST_FILES | |||
"graph/passes/infershape_pass_unittest.cc" | |||
"graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | |||
"graph/passes/multi_batch_clone_pass_unittest.cc" | |||
"graph/passes/subgraph_const_migration_pass_unittest.cc" | |||
"graph/passes/replace_with_empty_const_pass_unittest.cc" | |||
"graph/passes/link_gen_mask_nodes_pass_unittest.cc" | |||
"graph/passes/transpose_transdata_pass_unittest.cc" | |||
@@ -718,7 +719,7 @@ set(PASS_TEST_FILES | |||
"graph/passes/mark_node_unknown_shape_pass_unittest.cc" | |||
"graph/passes/reshape_recovery_pass_unittest.cc" | |||
"graph/passes/cast_remove_pass_unittest.cc" | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/hccl_continuous_pass_unittest.cc" | |||
"graph/passes/hccl_memcpy_pass_unittest.cc" | |||
@@ -843,6 +844,7 @@ set(HYBRID_TEST_FILES | |||
"hybrid/model/hybrid_model_builder_unittest.cc" | |||
"hybrid/node_executor/rts/rts_node_task_unittest.cc" | |||
"hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | |||
"hybrid/node_executor/hccl/hccl_node_executor_unittest.cc" | |||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | |||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | |||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | |||
@@ -0,0 +1,125 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <set> | |||
#include <string> | |||
#include "framework/omg/omg_inner_types.h" | |||
#include "graph/common/local_context.h" | |||
#include "graph/passes/subgraph_const_migration_pass.h" | |||
#include "inc/pass_manager.h" | |||
#include "register/op_registry.h" | |||
namespace ge { | |||
class UtestSubgraphConstMigrationPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
public: | |||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||
for (auto i = 0; i < in_num; ++i) { | |||
op_desc->AddInputDesc(test_desc); | |||
} | |||
for (auto i = 0; i < out_num; ++i) { | |||
op_desc->AddOutputDesc(test_desc); | |||
} | |||
if (type == "Const") { | |||
uint64_t const_value = 101; | |||
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||
AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); | |||
} | |||
return graph->AddNode(op_desc); | |||
} | |||
void make_original_graph(const ComputeGraphPtr &graph) { | |||
auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||
{ | |||
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||
} | |||
auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); | |||
{ | |||
auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); | |||
GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); | |||
} | |||
auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); | |||
{ | |||
auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); | |||
GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); | |||
} | |||
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||
GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||
} | |||
void make_multibatch_graph(const ComputeGraphPtr &graph) { | |||
auto index = MakeNode(graph, 1, 1, "index", "Data"); | |||
auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||
auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||
auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||
auto case1 = MakeNode(graph, 4, 1, "case", "Case"); | |||
GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); | |||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); | |||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); | |||
auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); | |||
GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||
AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); | |||
case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); | |||
ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch"); | |||
make_original_graph(branch); | |||
for (int i = 0; i < 2; ++i) { | |||
std::string name("_ascend_mbatch_batch_" + std::to_string(i)); | |||
std::vector<NodePtr> input_nodes; | |||
std::vector<NodePtr> output_nodes; | |||
ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); | |||
subgraph->SetName(name); | |||
subgraph->SetParentNode(case1); | |||
subgraph->SetParentGraph(graph); | |||
graph->AddSubgraph(subgraph->GetName(), subgraph); | |||
case1->GetOpDesc()->AddSubgraphName(name); | |||
case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); | |||
} | |||
} | |||
}; | |||
TEST_F(UtestSubgraphConstMigrationPass, graph_nullptr) { | |||
PassManager pass_manager; | |||
pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph"); | |||
make_multibatch_graph(graph); | |||
pass_manager.Run(graph); | |||
} | |||
} // namespace ge |
@@ -0,0 +1,108 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gmock/gmock.h> | |||
#include <gtest/gtest.h> | |||
#include <vector> | |||
#define private public | |||
#define protected public | |||
#include "graph/runtime_inference_context.h" | |||
#include "hybrid/executor/subgraph_context.h" | |||
#include "hybrid/node_executor/hccl/hccl_node_executor.h" | |||
#undef protected | |||
#undef private | |||
using namespace std; | |||
using namespace testing; | |||
namespace ge { | |||
using namespace hybrid; | |||
class UtestHcclNodeExecutor : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||
op_desc->SetStreamId(0); | |||
static int32_t index = 0; | |||
op_desc->SetId(index++); | |||
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||
TensorUtils::SetSize(tensor, 64); | |||
vector<int64_t> input_offset; | |||
for (int i = 0; i < in_num; i++) { | |||
op_desc->AddInputDesc(tensor); | |||
input_offset.emplace_back(i * 64); | |||
} | |||
op_desc->SetInputOffset(input_offset); | |||
vector<int64_t> output_offset; | |||
for (int i = 0; i < out_num; i++) { | |||
op_desc->AddOutputDesc(tensor); | |||
output_offset.emplace_back(in_num * 64 + i * 64); | |||
} | |||
op_desc->SetOutputOffset(output_offset); | |||
return graph.AddNode(op_desc); | |||
} | |||
TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0); | |||
std::unique_ptr<NodeItem> new_node; | |||
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||
NodeItem *node_item = new_node.get(); | |||
node_item->input_start = 0; | |||
node_item->output_start = 0; | |||
GraphItem graph_item; | |||
GraphExecutionContext graph_context; | |||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
ASSERT_NE(node_state, nullptr); | |||
RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id)); | |||
RuntimeInferenceContext *ctx = nullptr; | |||
RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx); | |||
Shape s({1, 3}); | |||
TensorDesc tensor_desc(s); | |||
Tensor tensor(tensor_desc); | |||
std::vector<uint8_t> data = {1, 2, 3, 4}; | |||
tensor.SetData(data); | |||
ctx->SetTensor(1, 0, tensor.Clone()); | |||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
vector<HcomRemoteAccessAddrInfo> addr_infos; | |||
shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||
task->remote_index_ = {1, 0}; | |||
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||
Shape s2({1}); | |||
TensorDesc tensor_desc2(s2); | |||
Tensor tensor2(tensor_desc2); | |||
ctx->SetTensor(1, 0, tensor2.Clone()); | |||
task->ExtractTensor(*unique_task_context, addr_infos); | |||
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||
RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); | |||
} | |||
} // namespace ge |