diff --git a/ge/ge_runtime/task/hccl_task.cc b/ge/ge_runtime/task/hccl_task.cc
index 2169f96a..65690683 100644
--- a/ge/ge_runtime/task/hccl_task.cc
+++ b/ge/ge_runtime/task/hccl_task.cc
@@ -154,10 +154,8 @@ bool HcclTask::SetSecondaryStream() {
         return false;
       }
       stream = std::make_shared<StreamGuard>(rt_model_handle_, new_stream);
-      if (stream == nullptr) {
-        GELOGE(FAILED, "MakeShared failed.");
-        return false;
-      }
+      GE_IF_BOOL_EXEC(stream == nullptr, return false);
+      secondary_stream_vec[index] = stream;
     }
     secondary_stream_list_.push_back(stream);
 
diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
index 0aee5122..0c6d4eaf 100644
--- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
+++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
@@ -20,7 +20,6 @@
 #include "graph/attr_value.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/util/hcom_util.h"
-#include "graph/runtime_inference_context.h"
 #include "graph/utils/type_utils.h"
 #include "graph/types.h"
 #include "hccl/hcom.h"
@@ -177,61 +176,8 @@ Status RdmaNodeTask::Init(TaskContext &context) {
   return SUCCESS;
 }
 
-Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
-  RuntimeInferenceContext *ctx = nullptr;
-  GE_CHK_STATUS_RET(
-      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));
-
-  ge::Tensor remote_tensor;
-  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
-  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
-  if (data == nullptr) {
-    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
-      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
-      return SUCCESS;
-    } else {
-      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
-                         context.GetNodeItem().NodeType().c_str());
-      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
-             context.GetNodeItem().NodeType().c_str());
-      return FAILED;
-    }
-  }
-  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
-  if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
-    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
-                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
-                       dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
-                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
-    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
-           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
-           dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
-           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
-    return PARAM_INVALID;
-  }
-
-  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
-    size_t remote_size = 0;
-    for (auto idx = 0; idx < dims.front(); ++idx) {
-      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
-      auto line_idx = idx * kVarTableRowCnt;
-      remote_size += data[line_idx + kVarTableIdxLen];
-    }
-    auto allocator = NpuMemoryAllocator::GetAllocator();
-    GE_CHECK_NOTNULL(allocator);
-    AllocationAttr attr;
-    attr.SetMemType(RDMA_HBM);
-    for (auto i = 0; i < context.NumOutputs(); ++i) {
-      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
-      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
-      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
-    }
-  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
-    AllocationAttr attr;
-    attr.SetMemType(RDMA_HBM);
-    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
-  }
-
+Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num,
+                                 vector<HcomRemoteAccessAddrInfo> &addr_infos) {
   TensorValue *tv;
   if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) {
     tv = context.MutableOutput(local_index_);
@@ -239,7 +185,6 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAcces
     tv = context.MutableInput(local_index_);
   }
   GE_CHECK_NOTNULL(tv);
-  auto row_num = dims.front();
   addr_infos.resize(row_num);
   if (skip_flag_) {
     int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset");
@@ -250,10 +195,10 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAcces
     GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor))
     if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) {
       REPORT_INNER_ERROR("E19999", "num of offset and remote addr mismatch, check invalid"
-                         "offset size=%zu, remote_addr size=%ld, dtype=%s", offset_tensor.GetSize(), row_num,
+                         "offset size=%zu, remote_addr size=%ld, dtype=%s", offset_tensor.GetSize(), row_num,
                          TypeUtils::DataTypeToSerialString(data_type).c_str());
       GELOGE(PARAM_INVALID, "[Check][Size]num of offset and remote addr mismatch,"
-             "offset size=%zu, remote_addr size=%ld, dtype=%s",
+             "offset size=%zu, remote_addr size=%ld, dtype=%s",
              offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str());
       return PARAM_INVALID;
     }
@@ -294,6 +239,65 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAcces
   return SUCCESS;
 }
 
+Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
+  RuntimeInferenceContext *ctx = nullptr;
+  GE_CHK_STATUS_RET(
+      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));
+
+  ge::Tensor remote_tensor;
+  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
+  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
+  if (data == nullptr) {
+    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
+      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
+      return SUCCESS;
+    } else {
+      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
+                         context.GetNodeItem().NodeType().c_str());
+      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
+             context.GetNodeItem().NodeType().c_str());
+      return FAILED;
+    }
+  }
+  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
+  if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
+    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
+                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
+                       dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
+                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
+    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
+           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
+           dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
+           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
+    return PARAM_INVALID;
+  }
+
+  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
+    size_t remote_size = 0;
+    for (auto idx = 0; idx < dims.front(); ++idx) {
+      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
+      auto line_idx = idx * kVarTableRowCnt;
+      remote_size += data[line_idx + kVarTableIdxLen];
+    }
+    auto allocator = NpuMemoryAllocator::GetAllocator();
+    GE_CHECK_NOTNULL(allocator);
+    AllocationAttr attr;
+    attr.SetMemType(RDMA_HBM);
+    for (auto i = 0; i < context.NumOutputs(); ++i) {
+      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
+      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
+      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
+    }
+  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
+    AllocationAttr attr;
+    attr.SetMemType(RDMA_HBM);
+    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
+  }
+
+  auto row_num = dims.front();
+  return SetAddrInfo(context, ctx, data, row_num, addr_infos);
+}
+
 Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
   GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName());
   auto HcomExecEnqueueRemoteAccess =
diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/ge/hybrid/node_executor/hccl/hccl_node_executor.h
index 873f259f..9e6d41a4 100644
--- a/ge/hybrid/node_executor/hccl/hccl_node_executor.h
+++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.h
@@ -18,6 +18,7 @@
 #define HYBRID_HCCL_NODE_EXECUTOR_H_
 #include "common/opskernel/ge_task_info.h"
 #include "graph/op_desc.h"
+#include "graph/runtime_inference_context.h"
 #include "hybrid/model/hybrid_model.h"
 #include "hybrid/node_executor/node_executor.h"
 
@@ -53,6 +54,8 @@ class RdmaNodeTask : public NodeTask {
   Status Init(TaskContext &context) override;
 
  private:
+  Status SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num,
+                     vector<HcomRemoteAccessAddrInfo> &addr_infos);
   Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos);
   std::pair<int64_t, int64_t> remote_index_;
   std::pair<int64_t, int64_t> offset_index_;
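
Note on the refactor above (a reviewer sketch, not part of the patch): ExtractTensor now only looks up the remote tensor, validates the variable-table shape and allocates outputs, then delegates the per-row address bookkeeping to the new SetAddrInfo(context, ctx, data, row_num, addr_infos). The standalone C++ sketch below illustrates that split under simplified assumptions; Status, AddrInfo, kRowCnt and the three-value row layout are hypothetical stand-ins, not the real TaskContext / RuntimeInferenceContext / HcomRemoteAccessAddrInfo interfaces.

// Standalone sketch of the ExtractTensor -> SetAddrInfo split (stand-in types only).
#include <cstdint>
#include <vector>

enum Status { SUCCESS = 0, PARAM_INVALID = 1 };

// Hypothetical row layout: { remote_addr, offset, length }, three values per row.
// The real layout is indexed through kVarTableRowCnt / kVarTableIdxLen in the patch.
struct AddrInfo {
  uint64_t remote_addr;
  uint64_t offset;
  uint64_t length;
};

constexpr size_t kRowCnt = 3;

// Helper: convert an already-validated table into addr_infos, mirroring how the
// patch hands SetAddrInfo the raw table pointer plus row_num.
Status SetAddrInfo(const uint64_t *data, int64_t row_num, std::vector<AddrInfo> &addr_infos) {
  addr_infos.resize(static_cast<size_t>(row_num));
  for (int64_t i = 0; i < row_num; ++i) {
    const uint64_t *row = data + static_cast<size_t>(i) * kRowCnt;
    addr_infos[static_cast<size_t>(i)] = {row[0], row[1], row[2]};
  }
  return SUCCESS;
}

// Caller: keep the shape check (the real ExtractTensor also allocates outputs), then delegate.
Status ExtractTensor(const std::vector<uint64_t> &table, std::vector<AddrInfo> &addr_infos) {
  if (table.empty() || table.size() % kRowCnt != 0) {
    return PARAM_INVALID;
  }
  auto row_num = static_cast<int64_t>(table.size() / kRowCnt);
  return SetAddrInfo(table.data(), row_num, addr_infos);
}

Keeping validation in the caller and passing the helper only the raw pointer plus row_num mirrors how the patch threads data and row_num into SetAddrInfo; the real helper additionally keeps the RuntimeInferenceContext parameter because it still reads the offset tensor through ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor).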