Browse Source

subgraph netoutput memory bugfix

tags/v1.1.0
isaactalx 4 years ago
parent
commit
a9412057df
6 changed files with 72 additions and 14 deletions
  1. +1
    -1
      ge/graph/manager/rdma_pool_allocator.cc
  2. +43
    -0
      ge/hybrid/model/hybrid_model_builder.cc
  3. +1
    -0
      ge/hybrid/model/hybrid_model_builder.h
  4. +1
    -0
      ge/hybrid/model/node_item.h
  5. +11
    -4
      ge/hybrid/node_executor/hccl/hccl_node_executor.cc
  6. +15
    -9
      ge/hybrid/node_executor/task_context.cc

+ 1
- 1
ge/graph/manager/rdma_pool_allocator.cc View File

@@ -202,7 +202,7 @@ Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) {
GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr.");
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_));
base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_));
mem_size = rdma_mem_size_; mem_size = rdma_mem_size_;
return SUCCESS; return SUCCESS;
} }


+ 43
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -701,6 +701,9 @@ Status HybridModelBuilder::LoadGraph() {
GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item),
"[%s] Failed to identify ref outputs.", "[%s] Failed to identify ref outputs.",
parent_node_item->NodeName().c_str()); parent_node_item->NodeName().c_str());
GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item),
"[%s] Failed to identify same outputs.",
parent_node_item->NodeName().c_str());


// if parent is function control op. need add a virtual partitioned call // if parent is function control op. need add a virtual partitioned call
if (parent_node_item->IsControlOp()) { if (parent_node_item->IsControlOp()) {
@@ -1162,6 +1165,46 @@ Status HybridModelBuilder::InitRuntimeParams() {
return SUCCESS; return SUCCESS;
} }


// Detects inputs of the subgraph's NETOUTPUT node that are fed by the very same
// producer output, and records them in node_item.reuse_outputs.
//
// For every connected input anchor of the NETOUTPUT node a key of the form
// "<producer op id>_<producer output index>" is built. The first input index seen
// for a key is remembered; every later input with the same key is stored in
// node_item.reuse_outputs as {later input index -> first input index}.
//
// @param node_item  Node whose subgraph (at kSubgraphIndex) is inspected; its
//                   reuse_outputs map is extended in place.
// @return SUCCESS when the scan completes or when the subgraph has no NETOUTPUT
//         node. Null subgraph / op desc / owner node are routed through
//         GE_CHECK_NOTNULL (presumably returns an error status -- the macro is
//         defined outside this view).
Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) {
  GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str());

  auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex);
  GE_CHECK_NOTNULL(subgraph);

  auto net_output = subgraph->FindFirstNodeMatchType(NETOUTPUT);
  if (net_output == nullptr) {
    // Nothing to analyse: a subgraph without a net output has no shared outputs.
    GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str());
    return SUCCESS;
  }

  auto net_output_desc = net_output->GetOpDesc();
  GE_CHECK_NOTNULL(net_output_desc);

  // Producer key -> index of the first NETOUTPUT input connected to that producer output.
  std::map<std::string, int> first_input_of_producer;
  for (const auto &dst_anchor : net_output->GetAllInDataAnchors()) {
    auto peer_anchor = dst_anchor->GetPeerOutAnchor();
    if (peer_anchor == nullptr) {
      continue;  // Unconnected input, nothing to compare.
    }

    auto src_node = peer_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(src_node);
    auto op_desc = src_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);

    const auto dst_index = dst_anchor->GetIdx();
    const auto src_index = peer_anchor->GetIdx();

    std::string producer_key = std::to_string(op_desc->GetId());
    producer_key += "_";
    producer_key += std::to_string(src_index);

    // emplace leaves the map untouched when the key already exists and hands back
    // the existing entry, so a single call covers both lookup and first insertion.
    const auto inserted = first_input_of_producer.emplace(producer_key, dst_index);
    if (inserted.second) {
      continue;
    }

    const auto first_index = inserted.first->second;
    GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(),
           dst_index,
           first_index,
           src_node->GetName().c_str(),
           src_index);
    node_item.reuse_outputs.emplace(dst_index, first_index);
  }

  return SUCCESS;
}

Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) {
GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str());
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex);


+ 1
- 0
ge/hybrid/model/hybrid_model_builder.h View File

@@ -59,6 +59,7 @@ class HybridModelBuilder {
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
Status LoadTasks(); Status LoadTasks();
Status IdentifyVariableOutputs(NodeItem &node_item); Status IdentifyVariableOutputs(NodeItem &node_item);
Status IdentifySameInputs(NodeItem &node_item);
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);


+ 1
- 0
ge/hybrid/model/node_item.h View File

@@ -83,6 +83,7 @@ struct NodeItem {
const NodeExecutor *node_executor = nullptr; const NodeExecutor *node_executor = nullptr;
std::map<int, ge::NodePtr> ref_outputs; std::map<int, ge::NodePtr> ref_outputs;
std::map<int, int> reuse_inputs; std::map<int, int> reuse_inputs;
std::map<int, int> reuse_outputs;


std::vector<bool> is_input_shape_static; std::vector<bool> is_input_shape_static;
bool is_output_shape_static = true; bool is_output_shape_static = true;


+ 11
- 4
ge/hybrid/node_executor/hccl/hccl_node_executor.cc View File

@@ -189,13 +189,20 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
} }
GE_CHECK_NOTNULL(tv); GE_CHECK_NOTNULL(tv);
auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData())); auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData()));
addr_infos.resize(dims.front());
for (auto idx = 0; idx < dims.front(); ++idx) {
auto row_num = dims.front();
addr_infos.resize(row_num);
auto device_len = tv->GetSize() / row_num;
if (device_len <= 0 || device_len > data[kVarTableIdxLen]) {
GELOGE(FAILED, "Local embedding length is out of range.");
return FAILED;
}

for (auto idx = 0; idx < row_num; ++idx) {
FMK_INT64_MULCHECK(idx, kVarTableRowCnt); FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
auto line_idx = idx * kVarTableRowCnt; auto line_idx = idx * kVarTableRowCnt;
addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr, addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr,
data[line_idx + kVarTableIdxLen]};
local_addr += data[line_idx + kVarTableIdxLen];
device_len};
local_addr += device_len;
} }


return SUCCESS; return SUCCESS;


+ 15
- 9
ge/hybrid/node_executor/task_context.cc View File

@@ -221,16 +221,22 @@ Status TaskContext::AllocateOutput(int index,
GE_CHECK_NOTNULL(ref_tensor); GE_CHECK_NOTNULL(ref_tensor);
outputs_start_[index] = *ref_tensor; outputs_start_[index] = *ref_tensor;
} else { } else {
auto reuse_input = node_item_->reuse_inputs.find(index);
if (reuse_input != node_item_->reuse_inputs.end()) {
GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second);
outputs_start_[index] = inputs_start_[reuse_input->second];
auto reuse_output_it = node_item_->reuse_outputs.find(index);
if (reuse_output_it != node_item_->reuse_outputs.end()) {
GELOGD("[%s] reuse output [%d] with output [%d]", GetNodeName(), index, reuse_output_it->second);
outputs_start_[index] = outputs_start_[reuse_output_it->second];
} else { } else {
GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr));
GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu",
node_item_->NodeName().c_str(),
index,
outputs_start_[index].GetSize());
auto reuse_input = node_item_->reuse_inputs.find(index);
if (reuse_input != node_item_->reuse_inputs.end()) {
GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second);
outputs_start_[index] = inputs_start_[reuse_input->second];
} else {
GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr));
GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu",
node_item_->NodeName().c_str(),
index,
outputs_start_[index].GetSize());
}
} }
} }




Loading…
Cancel
Save