Browse Source

fix more static-check (sc) issues

tags/v1.3.0
wjm 4 years ago
parent
commit
d30e88ab79
3 changed files with 68 additions and 63 deletions
  1. +2
    -4
      ge/ge_runtime/task/hccl_task.cc
  2. +63
    -59
      ge/hybrid/node_executor/hccl/hccl_node_executor.cc
  3. +3
    -0
      ge/hybrid/node_executor/hccl/hccl_node_executor.h

+ 2
- 4
ge/ge_runtime/task/hccl_task.cc View File

@@ -154,10 +154,8 @@ bool HcclTask::SetSecondaryStream() {
return false;
}
stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream);
if (stream == nullptr) {
GELOGE(FAILED, "MakeShared failed.");
return false;
}
GE_IF_BOOL_EXEC(stream == nullptr, return false);

secondary_stream_vec[index] = stream;
}
secondary_stream_list_.push_back(stream);


+ 63
- 59
ge/hybrid/node_executor/hccl/hccl_node_executor.cc View File

@@ -20,7 +20,6 @@
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/manager/util/hcom_util.h"
#include "graph/runtime_inference_context.h"
#include "graph/utils/type_utils.h"
#include "graph/types.h"
#include "hccl/hcom.h"
@@ -177,61 +176,8 @@ Status RdmaNodeTask::Init(TaskContext &context) {
return SUCCESS;
}

Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
RuntimeInferenceContext *ctx = nullptr;
GE_CHK_STATUS_RET(
RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));

ge::Tensor remote_tensor;
GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
if (data == nullptr) {
if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
return SUCCESS;
} else {
REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
context.GetNodeItem().NodeType().c_str());
GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
context.GetNodeItem().NodeType().c_str());
return FAILED;
}
}
auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
"and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
"number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
return PARAM_INVALID;
}

if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
size_t remote_size = 0;
for (auto idx = 0; idx < dims.front(); ++idx) {
FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
auto line_idx = idx * kVarTableRowCnt;
remote_size += data[line_idx + kVarTableIdxLen];
}
auto allocator = NpuMemoryAllocator::GetAllocator();
GE_CHECK_NOTNULL(allocator);
AllocationAttr attr;
attr.SetMemType(RDMA_HBM);
for (auto i = 0; i < context.NumOutputs(); ++i) {
GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
}
} else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
AllocationAttr attr;
attr.SetMemType(RDMA_HBM);
GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
}

Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num,
vector<HcomRemoteAccessAddrInfo> &addr_infos) {
TensorValue *tv;
if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) {
tv = context.MutableOutput(local_index_);
@@ -239,7 +185,6 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
tv = context.MutableInput(local_index_);
}
GE_CHECK_NOTNULL(tv);
auto row_num = dims.front();
addr_infos.resize(row_num);
if (skip_flag_) {
int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset");
@@ -250,10 +195,10 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor))
if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) {
REPORT_INNER_ERROR("E19999", "num of offset and remote addr mismatch, check invalid"
"offset size=%zu, remote_addr size=%ld, dtype=%s", offset_tensor.GetSize(), row_num,
"offset size=%zu, remote_addr size=%ld, dtype=%s", offset_tensor.GetSize(), row_num,
TypeUtils::DataTypeToSerialString(data_type).c_str());
GELOGE(PARAM_INVALID, "[Check][Size]num of offset and remote addr mismatch,"
"offset size=%zu, remote_addr size=%ld, dtype=%s",
"offset size=%zu, remote_addr size=%ld, dtype=%s",
offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str());
return PARAM_INVALID;
}
@@ -294,6 +239,65 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
return SUCCESS;
}

// Reads the remote variable-table tensor recorded during Init(), validates its
// shape, allocates RDMA output buffers for read-type nodes, and then delegates
// to SetAddrInfo() to build the per-row remote access descriptors.
//
// @param context     execution context of the current node
// @param addr_infos  out: filled by SetAddrInfo() with one entry per table row
// @return SUCCESS; FAILED when the table has no data and the node is not a
//         scatter type; PARAM_INVALID when the table shape is malformed.
Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
  RuntimeInferenceContext *ctx = nullptr;
  GE_CHK_STATUS_RET(
      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));

  ge::Tensor remote_tensor;
  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
  if (data == nullptr) {
    // Scatter-type nodes may legitimately get an empty table: nothing to transfer.
    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
      return SUCCESS;
    } else {
      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
                         context.GetNodeItem().NodeType().c_str());
      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
             context.GetNodeItem().NodeType().c_str());
      return FAILED;
    }
  }
  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
  // Guard: dims.back() is UB on an empty shape, so read it only after checking.
  const int64_t last_dim = dims.empty() ? 0 : dims.back();
  // BUGFIX: was '&&', which rejected the table only when BOTH the rank AND the
  // row width were wrong — a 2-D table with a wrong row width slipped through
  // and 'data' was then indexed with the wrong stride below. Either mismatch
  // must be fatal, hence '||'.
  if (dims.size() != kVarTableDims || last_dim != kVarTableRowCnt) {
    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
                       dims.size(), kVarTableDims, last_dim, kVarTableRowCnt,
                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
           dims.size(), kVarTableDims, last_dim, kVarTableRowCnt,
           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    // NOTE(review): '%zu' is used for the int64_t row-width values here; the
    // specifier likely should be PRId64/%ld — confirm the constants' types
    // before touching the format strings.
    return PARAM_INVALID;
  }

  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
    // Sum the per-row transfer lengths (column kVarTableIdxLen of each row) to
    // size the output buffers.
    size_t remote_size = 0;
    for (int64_t idx = 0; idx < dims.front(); ++idx) {
      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
      auto line_idx = idx * kVarTableRowCnt;
      remote_size += data[line_idx + kVarTableIdxLen];
    }
    auto allocator = NpuMemoryAllocator::GetAllocator();
    GE_CHECK_NOTNULL(allocator);
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    for (auto i = 0; i < context.NumOutputs(); ++i) {
      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
      // BUGFIX: Create() may fail; don't wrap a null buffer into the output.
      GE_CHECK_NOTNULL(tensor_buffer);
      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
    }
  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
  }

  auto row_num = dims.front();
  return SetAddrInfo(context, ctx, data, row_num, addr_infos);
}

Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName());
auto HcomExecEnqueueRemoteAccess =


+ 3
- 0
ge/hybrid/node_executor/hccl/hccl_node_executor.h View File

@@ -18,6 +18,7 @@
#define HYBRID_HCCL_NODE_EXECUTOR_H_
#include "common/opskernel/ge_task_info.h"
#include "graph/op_desc.h"
#include "graph/runtime_inference_context.h"
#include "hybrid/model/hybrid_model.h"
#include "hybrid/node_executor/node_executor.h"

@@ -53,6 +54,8 @@ class RdmaNodeTask : public NodeTask {
Status Init(TaskContext &context) override;

private:
Status SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num,
vector<HcomRemoteAccessAddrInfo> &addr_infos);
Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos);
std::pair<int64_t, int64_t> remote_index_;
std::pair<int64_t, int64_t> offset_index_;


Loading…
Cancel
Save