@@ -1938,6 +1938,8 @@ Status GraphMemoryAssigner::AssignBufferPoolMemory() {
return SUCCESS;
}
// if producer and customers in the same stream, or customers on the same stream when producer not assign a stream,
// then return false.
static bool IsOutputVisitedByMultiStream(const NodePtr &peer_out_node, int64_t out_anchor_index) {
GE_IF_BOOL_EXEC(peer_out_node->GetOpDesc() == nullptr, return true);
int64_t unique_stream_id = peer_out_node->GetOpDesc()->GetStreamId();
@@ -1949,7 +1951,7 @@ static bool IsOutputVisitedByMultiStream(const NodePtr &peer_out_node, int64_t o
if (node->GetOpDesc()->GetStreamId() == kInvalidStream) {
continue;
}
if (unique_stream_id == kInvalidStream) {
if (unique_stream_id == kInvalidStream) { // peer_out_node not belong to any stream
unique_stream_id = node->GetOpDesc()->GetStreamId();
continue;
}
@@ -2001,22 +2003,23 @@ static void UpdatePrevNodeInputDesc(const NodePtr &prev_node,
prev_node->GetName().c_str(),
prev_node_input_index);
}
return;
}
static void UpdateCurNodeInputIndex (const NodePtr &cur_node, int64_t cur_node_input_index, int64_t distance) {
static void UpdateCurNodeInputDesc (const NodePtr &cur_node, int64_t cur_node_input_index, int64_t distance) {
GE_IF_BOOL_EXEC(cur_node == nullptr, return);
GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, return);
auto input_desc = cur_node->GetOpDesc()->GetInputDesc(cur_node_input_index);
vector<int64_t> prev_next_distances{distance, -1};
if (!ge::AttrUtils::SetListInt(input_desc, ATTR_NAME_DATA_VISIT_DISTANCE, prev_next_distances)) {
GELOGW("Update [%s] input[%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
GELOGW("Set [%s] input[%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
cur_node->GetOpDesc()->GetName().c_str(),
cur_node_input_index);
return;
}
if (cur_node->GetOpDesc()->UpdateInputDesc(cur_node_input_index, input_desc) != GRAPH_SUCCESS) {
GELOGW("Update[%s] input[%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
GELOGW("Update [%s] input[%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
cur_node->GetOpDesc()->GetName().c_str(),
cur_node_input_index);
return;
@@ -2028,93 +2031,76 @@ static void UpdateCurNodeInputIndex(const NodePtr &cur_node, int64_t cur_node_in
return;
}
void GraphMemoryAssigner::MarkNodeDistanceAttr(const ComputeGraphPtr &compute_graph,
NodePtr &node,
map<size_t, pair<NodePtr, vector<int64_t>>> &mem_block_visit_info,
const map<string, int64_t> &node_index_in_stream) {
GELOGD("cur node name is [%s]", node->GetName().c_str());
GELOGD("Begin to mark node distance attr, node name is [%s]", node->GetName().c_str());
GE_IF_BOOL_EXEC(node == nullptr, return);
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
auto peer_out_node = peer_out_anchor->GetOwnerNode();
GE_IF_BOOL_EXEC(peer_out_node == nullptr, continue);
GELOGD("cur node[%s], cur in_data_anchor[%d], peer_out_node[%s]",
node->GetName().c_str(),
in_data_anchor->GetIdx(),
peer_out_node->GetName().c_str());
// find cur peer_out_node's peer_out_anchor->GetIdx()-th mem_offset(start mem offset)
NodeIndexIO node_index_io(peer_out_node, peer_out_anchor->GetIdx(), kOut);
std::string symbol;
size_t matched_mem_offset = mem_assigner_->GetPriorityAssinger()->GetAnchorDataOffset(node_index_io, symbol);
if (matched_mem_offset != kInvalidOffset) {
GELOGD("matched_mem_offset is [%zu], and new method's matched_mem_offset is [%zu], peer_out_node is [%s]",
matched_mem_offset,
peer_out_node->GetOpDesc()->GetOutputOffset().at(peer_out_anchor->GetIdx()),
peer_out_node->GetName().c_str());
} else { // peer_out_anchor not assign MemoryBlock, we search the peer_out_anchor's data in continous memory
if (matched_mem_offset == kInvalidOffset) {
// peer_out_anchor not assign MemoryBlock, we search the peer_out_anchor's data in continous memory
matched_mem_offset = peer_out_node->GetOpDesc()->GetOutputOffset().at(peer_out_anchor->GetIdx());
}
GELOGD("final matched_mem_offset is [%zu]", matched_mem_offset);
auto iter = mem_block_visit_info.find(matched_mem_offset);
if (iter == mem_block_visit_info.end()) { // cannot find visit info, peer_out_node must be a producer and this data is the first time to be visited.
GELOGD("iter == mem_block_visit_info.end() !!!");
if (IsOutputVisitedByMultiStream(peer_out_node, peer_out_anchor->GetIdx())) { // 生产者的输出多流访问或者生产者在某条流上,但是不在cur_node所在的这条流上
// cannot find visit info, peer_out_node must be a producer and this data is the first time to be visited.
if (iter == mem_block_visit_info.end()) {
if (IsOutputVisitedByMultiStream(peer_out_node, peer_out_anchor->GetIdx())) {
vector<int64_t> temp;
mem_block_visit_info.insert(std::make_pair(matched_mem_offset, std::make_pair(nullptr, temp)));
GELOGD("IsOutputVistedByMultiStream is true, peer_out_node is [%s], stream_id [%d], cur node is [%s], in_data_anchor_index is[%zu], stream_id[%d]",
peer_out_node->GetName().c_str(),
peer_out_node->GetOpDesc()->GetStreamId(),
node->GetName().c_str(),
in_data_anchor->GetIdx(),
node->GetOpDesc()->GetStreamId());
continue;
} else { // 生产者和消费者都在同一条流上,或者所有消费者都在同一条流上且生产者未分配流
GELOGD("IsOutputVisitedByMultiStream is false.");
} else {
vector<int64_t> temp = {-1};
mem_block_visit_info.insert(std::make_pair(matched_mem_offset, std::make_pair(peer_out_node, temp))); //数据生产者的prev_node_index设置成-1
// producer's prev_node_index set to -1 as default
mem_block_visit_info.insert(std::make_pair(matched_mem_offset, std::make_pair(peer_out_node, temp)));
}
} else { // 如果在访问信息中找到了对应的offset,此时需要判断当前的输出节点和访问信息中的那个节点是否在同一个stream上,如果不在同一个stream上就不计算距离
if (mem_block_visit_info[matched_mem_offset].first == nullptr) {
} else {
if (mem_block_visit_info[matched_mem_offset].first == nullptr) { // multi-stream visit, no need to calculate
continue;
}
if (peer_out_node->GetOpDesc()->GetStreamId() != mem_block_visit_info[matched_mem_offset].first->GetOpDesc()->GetStreamId()) {
GELOGD("lckey, cur node[%s], peer_out_node[%s] not in the same stream with node[%s] in visit info.",
node->GetName().c_str(),
peer_out_node->GetName().c_str(),
mem_block_visit_info[matched_mem_offset].first->GetName().c_str());
if (peer_out_node->GetOpDesc()->GetStreamId() !=
mem_block_visit_info[matched_mem_offset].first->GetOpDesc()->GetStreamId()) {
// cur node and peer_out_node not in the same stream, no need to calculate
continue;
}
}
// now mem_block_visit_info must contains that memory_block
// 1. calculate distance, update visit info, update prev_node input desc, update cur node input desc
// steps: calculate distance, update visit info, update prev_node input desc, update cur node input desc
int64_t distance = -1;
auto prev_node = mem_block_visit_info[matched_mem_offset].first;
auto prev_node_input_index_vec = mem_block_visit_info[matched_mem_offset].second;
GE_IF_BOOL_EXEC(prev_node == nullptr, continue); // 生产者的输出多流访问或者生产者在某条流上,但是不在cur node所在的这条流上
if (prev_node_input_index_vec.size() == 1 && prev_node_input_index_vec[0] == -1) { // prev_node是生产者且数据刚刚被生产出来(还没被其他节点访问过)
GE_IF_BOOL_EXEC(prev_node == nullptr, continue);
if (prev_node_input_index_vec.size() == 1 && prev_node_input_index_vec[0] == -1) {
// prev_node is producer and the data is just be produced(not visited by other node)
GE_IF_BOOL_EXEC(prev_node->GetOpDesc() == nullptr, continue);
if (prev_node->GetOpDesc()->GetStreamId() == -1) { // 生产者未分配 stream
if (prev_node->GetOpDesc()->GetStreamId() == -1) { // producer not assigned a stream
distance = 0;
} else if (node_index_in_stream.find(prev_node->GetName()) == node_index_in_stream.end()) {
// prev_node is ge local op
distance = 0;
} else {
distance = node_index_in_stream.at(node->GetName()) - node_index_in_stream.at(prev_node->GetName()) - 1;
}
// 不需要更新prev_node的input desc(因为prev_node设定的input index=-1, 设也设不上去),只需要后面刷一把自己的input_desc即可
// 更新访问信息
mem_block_visit_info[matched_mem_offset].first = node;
mem_block_visit_info[matched_mem_offset].second.clear();
mem_block_visit_info[matched_mem_offset].second.push_back(in_data_anchor->GetIdx());
} else {
} else { // the data is visit by other customer just before.
if (prev_node_input_index_vec.empty()) {
GELOGW("Missing prev node[%s] input index.", prev_node->GetName().c_str());
continue;
}
if (prev_node == node) { // scene: multiple anchors of a node access the same data
vector<int64_t> prev_next_distances;
// 1. do not need to update prev_node's input desc, because it must be updated when the first anchor visit it.
@@ -2136,7 +2122,8 @@ void GraphMemoryAssigner::MarkNodeDistanceAttr(const ComputeGraphPtr &compute_gr
}
mem_block_visit_info[matched_mem_offset].second.push_back(in_data_anchor->GetIdx());
} else {
// 走到此处,prev_node必然是访问者,如果cur node和prev node这两个访问者不在同一条流上,那么之前在IsOutputVisitedByMultiStream是就应该continue了,而不会走到这
// now, prev_node must be customer because if cur node and prev node not on the same stream, then it will be
// continue at IsOutputVisitedByMultiStream.
distance = node_index_in_stream.at(node->GetName()) - node_index_in_stream.at(prev_node->GetName()) - 1;
UpdatePrevNodeInputDesc(prev_node, prev_node_input_index_vec, distance);
mem_block_visit_info[matched_mem_offset].first = node;
@@ -2144,7 +2131,7 @@ void GraphMemoryAssigner::MarkNodeDistanceAttr(const ComputeGraphPtr &compute_gr
mem_block_visit_info[matched_mem_offset].second.push_back(in_data_anchor->GetIdx());
}
}
UpdateCurNodeInputIndex (node, in_data_anchor->GetIdx(), distance);
UpdateCurNodeInputDesc (node, in_data_anchor->GetIdx(), distance);
auto input_desc = node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx());
bool is_end_of_inputmem_lifecycle = false;
@@ -2163,14 +2150,16 @@ void GraphMemoryAssigner::MarkNodeDistanceAttr(const ComputeGraphPtr &compute_gr
}
void GraphMemoryAssigner::MarkDistanceAttr() {
map<size_t, pair<NodePtr, vector<int64_t>>> mem_block_visit_info; // key: mem_offset of the memory which we visited. value: node we visited and input index of this node
map<string, int64_t> node_index_in_stream; // key: node name, value: topo order of node in it's belonged stream(exclude ge_local_op)
map<int64_t, int64_t> stream_nodes_num; // key: stream id, value: cur nodes num in that stream
// key: mem_offset of the memory which we visited. value: node we visited and input index of this node
map<size_t, pair<NodePtr, vector<int64_t>>> mem_block_visit_info;
// key: node name, value: topo order of node in it's belonged stream(exclude ge_local_op)
map<string, int64_t> node_index_in_stream;
// key: stream id, value: cur nodes num in that stream
map<int64_t, int64_t> stream_nodes_num;
for (auto &node : compute_graph_->GetAllNodes()) {
auto node_op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, return);
int64_t stream_id = node_op_desc->GetStreamId();
if (node_op_desc->GetOpKernelLibName() != kEngineNameGeLocal) {
if (stream_nodes_num.find(stream_id) == stream_nodes_num.end()) {
@@ -2182,7 +2171,7 @@ void GraphMemoryAssigner::MarkDistanceAttr() {
MarkNodeDistanceAttr(compute_graph_, node, mem_block_visit_info, node_index_in_stream);
} else {
GELOGD("in GraphMemoryAssigner::MarkDistanceAttr, node[%s] is ge_local_op", node->GetName().c_str());
GELOGD("node[%s] is ge_local_op, no need to calculate distance. ", node->GetName().c_str());
}
}
}