|
|
@@ -1983,7 +1983,7 @@ void GraphMemoryAssigner::UpdatePrevNodeInputDesc(const NodePtr &prev_node, |
|
|
|
if (prev_next_distances.size() == kPrevNextDistanceNum) { |
|
|
|
prev_next_distances[1] = distance; |
|
|
|
} else { |
|
|
|
GELOGW("Size of prev_next_distances is not 2."); |
|
|
|
GELOGW("Size of prev_next_distances is not %d.", kPrevNextDistanceNum); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!ge::AttrUtils::SetListInt(input_desc, ATTR_NAME_DATA_VISIT_DISTANCE, prev_next_distances)) { |
|
|
@@ -2034,19 +2034,6 @@ void GraphMemoryAssigner::UpdateCurNodeInputDesc(const NodePtr &cur_node, |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
size_t GraphMemoryAssigner::GetMemoryOffset(const HybridMemAssignerPtr &mem_assigner, |
|
|
|
const NodePtr &peer_out_node, |
|
|
|
const OutDataAnchorPtr &peer_out_anchor) { |
|
|
|
NodeIndexIO node_index_io(peer_out_node, peer_out_anchor->GetIdx(), kOut); |
|
|
|
string symbol; |
|
|
|
size_t matched_mem_offset = mem_assigner->GetPriorityAssinger()->GetAnchorDataOffset(node_index_io, symbol); |
|
|
|
if (matched_mem_offset == kInvalidOffset) { |
|
|
|
// peer_out_anchor not assign MemoryBlock, we search the peer_out_anchor's data in continous memory |
|
|
|
matched_mem_offset = peer_out_node->GetOpDesc()->GetOutputOffset().at(peer_out_anchor->GetIdx()); |
|
|
|
} |
|
|
|
return matched_mem_offset; |
|
|
|
} |
|
|
|
|
|
|
|
void GraphMemoryAssigner::CheckNeedCalcDistAndUpdateVisitInfo( |
|
|
|
map<size_t, pair<NodePtr, vector<int64_t>>> &mem_block_visit_info, |
|
|
|
const size_t matched_mem_offset, |
|
|
@@ -2101,10 +2088,13 @@ void GraphMemoryAssigner::CalcDistanceAndUpdateDesc(map<size_t, pair<NodePtr, ve |
|
|
|
GE_IF_BOOL_EXEC(prev_node->GetOpDesc() == nullptr, is_need_skip = true; return); |
|
|
|
if (prev_node->GetOpDesc()->GetStreamId() == -1) { // producer not assigned a stream |
|
|
|
distance = 0; |
|
|
|
} else if (node_index_in_stream.find(prev_node->GetName()) == node_index_in_stream.end()) { |
|
|
|
distance = 0; |
|
|
|
} else { |
|
|
|
distance = node_index_in_stream.at(node->GetName()) - node_index_in_stream.at(prev_node->GetName()) - 1; |
|
|
|
auto iter = node_index_in_stream.find(prev_node->GetName()); |
|
|
|
if (iter == node_index_in_stream.end()) { |
|
|
|
distance = 0; |
|
|
|
} else { |
|
|
|
distance = node_index_in_stream.at(node->GetName()) - iter->second - 1; |
|
|
|
} |
|
|
|
} |
|
|
|
mem_block_visit_info[matched_mem_offset].first = node; |
|
|
|
mem_block_visit_info[matched_mem_offset].second.clear(); |
|
|
@@ -2125,7 +2115,7 @@ void GraphMemoryAssigner::CalcDistanceAndUpdateDesc(map<size_t, pair<NodePtr, ve |
|
|
|
return; |
|
|
|
} |
|
|
|
if (prev_next_distances.size() != kPrevNextDistanceNum) { |
|
|
|
GELOGW("Size of prev_next_distance is not 2."); |
|
|
|
GELOGW("Size of prev_next_distance is not %d.", kPrevNextDistanceNum); |
|
|
|
is_need_skip = true; |
|
|
|
return; |
|
|
|
} else { |
|
|
@@ -2148,6 +2138,7 @@ void GraphMemoryAssigner::DeleteVisitInfoWhenLifecycleEnded( |
|
|
|
const size_t matched_mem_offset, |
|
|
|
const NodePtr &node, |
|
|
|
const InDataAnchorPtr &in_data_anchor) { |
|
|
|
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, return); |
|
|
|
auto input_desc = node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); |
|
|
|
bool is_end_of_inputmem_lifecycle = false; |
|
|
|
// if is_end_of_inputmem_lifecycle is true, indicating that cur node is the last customer of this data, |
|
|
@@ -2177,7 +2168,8 @@ void GraphMemoryAssigner::MarkNodeDistanceAttr(const ComputeGraphPtr &compute_gr |
|
|
|
auto peer_out_node = peer_out_anchor->GetOwnerNode(); |
|
|
|
GE_IF_BOOL_EXEC(peer_out_node == nullptr, continue); |
|
|
|
|
|
|
|
auto matched_mem_offset = GetMemoryOffset(mem_assigner_, peer_out_node, peer_out_anchor); |
|
|
|
GE_IF_BOOL_EXEC(peer_out_node->GetOpDesc() == nullptr, continue); |
|
|
|
auto matched_mem_offset = peer_out_node->GetOpDesc()->GetOutputOffset().at(peer_out_anchor->GetIdx()); |
|
|
|
|
|
|
|
bool is_need_calc_distance = false; |
|
|
|
CheckNeedCalcDistAndUpdateVisitInfo(mem_block_visit_info, matched_mem_offset, peer_out_node, |
|
|
|