|
|
@@ -20,7 +20,6 @@ |
|
|
|
#include "graph/attr_value.h" |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
#include "graph/manager/util/hcom_util.h" |
|
|
|
#include "graph/runtime_inference_context.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "graph/types.h" |
|
|
|
#include "hccl/hcom.h" |
|
|
@@ -177,61 +176,8 @@ Status RdmaNodeTask::Init(TaskContext &context) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { |
|
|
|
RuntimeInferenceContext *ctx = nullptr; |
|
|
|
GE_CHK_STATUS_RET( |
|
|
|
RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); |
|
|
|
|
|
|
|
ge::Tensor remote_tensor; |
|
|
|
GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); |
|
|
|
auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData()); |
|
|
|
if (data == nullptr) { |
|
|
|
if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { |
|
|
|
GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName()); |
|
|
|
return SUCCESS; |
|
|
|
} else { |
|
|
|
REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s", |
|
|
|
context.GetNodeItem().NodeType().c_str()); |
|
|
|
GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s", |
|
|
|
context.GetNodeItem().NodeType().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); |
|
|
|
if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" |
|
|
|
"and shape dims back:%zu not equal expect:%zu, node:%s(%s)", |
|
|
|
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, |
|
|
|
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); |
|
|
|
GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," |
|
|
|
"number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", |
|
|
|
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, |
|
|
|
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) { |
|
|
|
size_t remote_size = 0; |
|
|
|
for (auto idx = 0; idx < dims.front(); ++idx) { |
|
|
|
FMK_INT64_MULCHECK(idx, kVarTableRowCnt); |
|
|
|
auto line_idx = idx * kVarTableRowCnt; |
|
|
|
remote_size += data[line_idx + kVarTableIdxLen]; |
|
|
|
} |
|
|
|
auto allocator = NpuMemoryAllocator::GetAllocator(); |
|
|
|
GE_CHECK_NOTNULL(allocator); |
|
|
|
AllocationAttr attr; |
|
|
|
attr.SetMemType(RDMA_HBM); |
|
|
|
for (auto i = 0; i < context.NumOutputs(); ++i) { |
|
|
|
GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size); |
|
|
|
auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr); |
|
|
|
GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())))); |
|
|
|
} |
|
|
|
} else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) { |
|
|
|
AllocationAttr attr; |
|
|
|
attr.SetMemType(RDMA_HBM); |
|
|
|
GE_CHK_STATUS_RET(context.AllocateOutputs(&attr)) |
|
|
|
} |
|
|
|
|
|
|
|
Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num, |
|
|
|
vector<HcomRemoteAccessAddrInfo> &addr_infos) { |
|
|
|
TensorValue *tv; |
|
|
|
if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) { |
|
|
|
tv = context.MutableOutput(local_index_); |
|
|
@@ -239,7 +185,6 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess |
|
|
|
tv = context.MutableInput(local_index_); |
|
|
|
} |
|
|
|
GE_CHECK_NOTNULL(tv); |
|
|
|
auto row_num = dims.front(); |
|
|
|
addr_infos.resize(row_num); |
|
|
|
if (skip_flag_) { |
|
|
|
int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset"); |
|
|
@@ -250,10 +195,10 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess |
|
|
|
GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor)) |
|
|
|
if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) { |
|
|
|
REPORT_INNER_ERROR("E19999", "num of offset and remote addr mismatch, check invalid" |
|
|
|
"offset size=%zu, remote_addr size=%ld, dtype=%s", offset_tensor.GetSize(), row_num, |
|
|
|
"offset size=%zu, remote_addr size=%ld, dtype=%s", offset_tensor.GetSize(), row_num, |
|
|
|
TypeUtils::DataTypeToSerialString(data_type).c_str()); |
|
|
|
GELOGE(PARAM_INVALID, "[Check][Size]num of offset and remote addr mismatch," |
|
|
|
"offset size=%zu, remote_addr size=%ld, dtype=%s", |
|
|
|
"offset size=%zu, remote_addr size=%ld, dtype=%s", |
|
|
|
offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
@@ -294,6 +239,65 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
  // Look up the runtime inference context so we can read tensors produced earlier in execution.
  RuntimeInferenceContext *ctx = nullptr;
  GE_CHK_STATUS_RET(
      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));

  // The remote tensor is the variable address table; each row describes one remote access
  // (layout assumed from the kVarTable* index constants — confirm against their definitions).
  ge::Tensor remote_tensor;
  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
  if (data == nullptr) {
    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
      // Scatter-type nodes tolerate an empty table: there is simply nothing to read/write.
      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
      return SUCCESS;
    } else {
      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
                         context.GetNodeItem().NodeType().c_str());
      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
             context.GetNodeItem().NodeType().c_str());
      return FAILED;
    }
  }

  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
  // Fix: validate with `||` instead of `&&`. The table is malformed when EITHER the rank or
  // the row width is wrong; the old `&&` let a wrong-rank tensor pass whenever its last dim
  // happened to match, and evaluated dims.back() on a possibly empty vector (UB).
  if (dims.empty() || dims.size() != kVarTableDims || dims.back() != kVarTableRowCnt) {
    // Avoid calling dims.back() on an empty shape when formatting the error.
    const auto dim_back = dims.empty() ? 0 : dims.back();
    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
                       dims.size(), kVarTableDims, dim_back, kVarTableRowCnt,
                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
           dims.size(), kVarTableDims, dim_back, kVarTableRowCnt,
           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    return PARAM_INVALID;
  }

  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
    // Sum the per-row lengths to size one RDMA buffer per output.
    size_t remote_size = 0;
    // Use int64_t for the index to match dims.front() (was `auto idx = 0`, i.e. int).
    for (int64_t idx = 0; idx < dims.front(); ++idx) {
      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
      auto line_idx = idx * kVarTableRowCnt;
      remote_size += data[line_idx + kVarTableIdxLen];
    }
    auto allocator = NpuMemoryAllocator::GetAllocator();
    GE_CHECK_NOTNULL(allocator);
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    for (auto i = 0; i < context.NumOutputs(); ++i) {
      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
      // Guard against allocation failure before handing the raw pointer to shared_ptr.
      GE_CHECK_NOTNULL(tensor_buffer);
      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
    }
  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
    // Ref-read reuses the framework's output allocation, pinned to RDMA HBM memory.
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
  }

  // One addr-info entry per table row; SetAddrInfo fills addr_infos from the table.
  auto row_num = dims.front();
  return SetAddrInfo(context, ctx, data, row_num, addr_infos);
}
|
|
|
|
|
|
|
Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { |
|
|
|
GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName()); |
|
|
|
auto HcomExecEnqueueRemoteAccess = |
|
|
|