Browse Source

mark attr distance

pull/1573/head
lichun 4 years ago
parent
commit
3c36c8a526
5 changed files with 275 additions and 2 deletions
  1. +13
    -2
      ge/graph/build/memory/block_mem_assigner.cc
  2. +2
    -0
      ge/graph/build/memory/block_mem_assigner.h
  3. +251
    -0
      ge/graph/build/memory/graph_mem_assigner.cc
  4. +7
    -0
      ge/graph/build/memory/graph_mem_assigner.h
  5. +2
    -0
      ge/graph/build/memory/memory_assigner.cc

+ 13
- 2
ge/graph/build/memory/block_mem_assigner.cc View File

@@ -431,7 +431,7 @@ void SetLastUsedInputMemAttr(NodePtr &node, int input_index) {
auto node_op_desc = node->GetOpDesc(); auto node_op_desc = node->GetOpDesc();
if (node_op_desc != nullptr) { if (node_op_desc != nullptr) {
auto input_desc = node_op_desc->MutableInputDesc(input_index); auto input_desc = node_op_desc->MutableInputDesc(input_index);
if (!ge::AttrUtils::SetInt(*input_desc, ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE, true)) {
if (!ge::AttrUtils::SetBool(*input_desc, ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE, true)) {
GELOGW("Set %s input[%d] ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE to true failed.", node_op_desc->GetName().c_str(), GELOGW("Set %s input[%d] ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE to true failed.", node_op_desc->GetName().c_str(),
input_index); input_index);
return; return;
@@ -1493,8 +1493,8 @@ void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock
bool same_stream) { bool same_stream) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null.");
GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory");
GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory");
--to_release->ref_count_; --to_release->ref_count_;
GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory");
if (!same_stream) { if (!same_stream) {
to_release->same_stream_ = false; to_release->same_stream_ = false;
} }
@@ -2160,6 +2160,17 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) {
} }
} }


size_t BlockMemAssigner::GetAnchorDataOffset(const NodeIndexIO &node_index_io, std::string &symbol) {
MemoryBlock *mem_block = nullptr;
if (IsSymbolExist(node_index_io, symbol)) {
mem_block = symbol_blocks_[symbol];
if (mem_block != nullptr) {
return mem_block->HeadOffset();
}
}
return kInvalidOffset;
}

Status BlockMemAssigner::Assign() { Status BlockMemAssigner::Assign() {
vector<int64_t> ranges; vector<int64_t> ranges;
if (GetMemoryRanges(ranges) != SUCCESS) { if (GetMemoryRanges(ranges) != SUCCESS) {


+ 2
- 0
ge/graph/build/memory/block_mem_assigner.h View File

@@ -239,6 +239,8 @@ class BlockMemAssigner : public MemAssigner {
void SetOpMemOffset(bool is_zero_copy); void SetOpMemOffset(bool is_zero_copy);


std::string GetMaxBatchLabel() const { return max_batch_label_; } std::string GetMaxBatchLabel() const { return max_batch_label_; }

size_t GetAnchorDataOffset(const NodeIndexIO &node_index_io, std::string &symbol);
protected: protected:
/// ///
/// @ingroup domi /// @ingroup domi


+ 251
- 0
ge/graph/build/memory/graph_mem_assigner.cc View File

@@ -36,6 +36,8 @@ namespace {
const int kAllInputAddrIsAtomic = -1; const int kAllInputAddrIsAtomic = -1;
const int kVirtualInputNodeMemoryReuse = 0; const int kVirtualInputNodeMemoryReuse = 0;
const int kVirtualOutputNodeMemoryReuse = 1; const int kVirtualOutputNodeMemoryReuse = 1;
const int64_t kInvalidStream = -1;
const char *const kEngineNameGeLocal = "DNN_VM_GE_LOCAL_OP_STORE";
// One state per bit cannot be repeated // One state per bit cannot be repeated
enum ContinuousType { kTypeInput = 1, kTypeInputNoPadding = 2, kTypeOutput = 4, kTypeOutputNoPadding = 8 }; enum ContinuousType { kTypeInput = 1, kTypeInputNoPadding = 2, kTypeOutput = 4, kTypeOutputNoPadding = 8 };


@@ -1935,4 +1937,253 @@ Status GraphMemoryAssigner::AssignBufferPoolMemory() {
compute_graph_->GetName().c_str(), mem_type, buffer_pool_mem_assigner.GetMemOffset()); compute_graph_->GetName().c_str(), mem_type, buffer_pool_mem_assigner.GetMemOffset());
return SUCCESS; return SUCCESS;
} }

static bool IsOutputVisitedByMultiStream(const NodePtr &peer_out_node, int64_t out_anchor_index) {
GE_IF_BOOL_EXEC(peer_out_node->GetOpDesc() == nullptr, return true);
int64_t unique_stream_id = peer_out_node->GetOpDesc()->GetStreamId();

GE_IF_BOOL_EXEC(peer_out_node->GetOutDataAnchor(out_anchor_index) == nullptr, return true);
for (const auto &in_data_anchor : peer_out_node->GetOutDataAnchor(out_anchor_index)->GetPeerInDataAnchors()) {
auto node = in_data_anchor->GetOwnerNode();
GE_IF_BOOL_EXEC(node == nullptr || node->GetOpDesc() == nullptr, continue);
if (node->GetOpDesc()->GetStreamId() == kInvalidStream) {
continue;
}
if (unique_stream_id == kInvalidStream) {
unique_stream_id = node->GetOpDesc()->GetStreamId();
continue;
}
if (node->GetOpDesc()->GetStreamId() != unique_stream_id) {
return true;
}
}
return false;
}

static void UpdatePrevNodeInputDesc(const NodePtr &prev_node,
const vector<int64_t> &prev_node_input_index_vec,
int64_t distance) {
GE_IF_BOOL_EXEC(prev_node == nullptr, return);
auto prev_node_op_desc = prev_node->GetOpDesc();
GE_IF_BOOL_EXEC(prev_node_op_desc == nullptr, return);

for (const auto prev_node_input_index : prev_node_input_index_vec) {
auto input_desc = prev_node_op_desc->GetInputDesc(prev_node_input_index);
vector<int64_t> prev_next_distances;
if (!ge::AttrUtils::GetListInt(input_desc, ATTR_NAME_DATA_VISIT_DISTANCE, prev_next_distances)) {
GELOGW("Get [%s] input [%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed",
prev_node_op_desc->GetName().c_str(),
prev_node_input_index);
continue;
}

if (prev_next_distances.size() == 2) {
prev_next_distances[1] = distance;
} else {
GELOGW("Size of prev_next_distances is not 2.");
continue;
}
if (!ge::AttrUtils::SetListInt(input_desc, ATTR_NAME_DATA_VISIT_DISTANCE, prev_next_distances)) {
GELOGW("Set [%s] input [%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
prev_node_op_desc->GetName().c_str(),
prev_node_input_index);
continue;
}

if (prev_node_op_desc->UpdateInputDesc(prev_node_input_index, input_desc) != GRAPH_SUCCESS) {
GELOGW("Update [%s] input [%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
prev_node_op_desc->GetName().c_str(),
prev_node_input_index);
continue;
}
GELOGD("Set the next distance[%lld] to node[%s], input index[%lld]",
distance,
prev_node->GetName().c_str(),
prev_node_input_index);
}
}

static void UpdateCurNodeInputIndex(const NodePtr &cur_node, int64_t cur_node_input_index, int64_t distance) {
GE_IF_BOOL_EXEC(cur_node == nullptr, return);
GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, return);
auto input_desc = cur_node->GetOpDesc()->GetInputDesc(cur_node_input_index);
vector<int64_t> prev_next_distances{distance, -1};

if (!ge::AttrUtils::SetListInt(input_desc, ATTR_NAME_DATA_VISIT_DISTANCE, prev_next_distances)) {
GELOGW("Update[%s] input[%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
cur_node->GetOpDesc()->GetName().c_str(),
cur_node_input_index);
return;
}
if (cur_node->GetOpDesc()->UpdateInputDesc(cur_node_input_index, input_desc) != GRAPH_SUCCESS) {
GELOGW("Update[%s] input[%lld] ATTR_NAME_DATA_VISIT_DISTANCE failed.",
cur_node->GetOpDesc()->GetName().c_str(),
cur_node_input_index);
return;
}
GELOGD("Set the prev distance[%lld] to node[%s], input index[%lld]",
distance,
cur_node->GetName().c_str(),
cur_node_input_index);
return;
}

void GraphMemoryAssigner::MarkNodeDistanceAttr(const ComputeGraphPtr &compute_graph,
NodePtr &node,
map<size_t, pair<NodePtr, vector<int64_t>>> &mem_block_visit_info,
const map<string, int64_t> &node_index_in_stream) {
GELOGD("cur node name is [%s]", node->GetName().c_str());
GE_IF_BOOL_EXEC(node == nullptr, return);
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
auto peer_out_node = peer_out_anchor->GetOwnerNode();
GE_IF_BOOL_EXEC(peer_out_node == nullptr, continue);
GELOGD("cur node[%s], cur in_data_anchor[%d], peer_out_node[%s]",
node->GetName().c_str(),
in_data_anchor->GetIdx(),
peer_out_node->GetName().c_str());

// find cur peer_out_node's peer_out_anchor->GetIdx()-th mem_offset(start mem offset)
NodeIndexIO node_index_io(peer_out_node, peer_out_anchor->GetIdx(), kOut);
std::string symbol;
size_t matched_mem_offset = mem_assigner_->GetPriorityAssinger()->GetAnchorDataOffset(node_index_io, symbol);
if (matched_mem_offset != kInvalidOffset) {
GELOGD("matched_mem_offset is [%zu], and new method's matched_mem_offset is [%zu], peer_out_node is [%s]",
matched_mem_offset,
peer_out_node->GetOpDesc()->GetOutputOffset().at(peer_out_anchor->GetIdx()),
peer_out_node->GetName().c_str());
} else { // peer_out_anchor not assign MemoryBlock, we search the peer_out_anchor's data in continous memory
matched_mem_offset = peer_out_node->GetOpDesc()->GetOutputOffset().at(peer_out_anchor->GetIdx());
}
GELOGD("final matched_mem_offset is [%zu]", matched_mem_offset);

auto iter = mem_block_visit_info.find(matched_mem_offset);
if (iter == mem_block_visit_info.end()) { // cannot find visit info, peer_out_node must be a producer and this data is the first time to be visited.
GELOGD("iter == mem_block_visit_info.end() !!!");
if (IsOutputVisitedByMultiStream(peer_out_node, peer_out_anchor->GetIdx())) { // 生产者的输出多流访问或者生产者在某条流上,但是不在cur_node所在的这条流上
vector<int64_t> temp;
mem_block_visit_info.insert(std::make_pair(matched_mem_offset, std::make_pair(nullptr, temp)));
GELOGD("IsOutputVistedByMultiStream is true, peer_out_node is [%s], stream_id [%d], cur node is [%s], in_data_anchor_index is[%zu], stream_id[%d]",
peer_out_node->GetName().c_str(),
peer_out_node->GetOpDesc()->GetStreamId(),
node->GetName().c_str(),
in_data_anchor->GetIdx(),
node->GetOpDesc()->GetStreamId());
continue;
} else { // 生产者和消费者都在同一条流上,或者所有消费者都在同一条流上且生产者未分配流
GELOGD("IsOutputVisitedByMultiStream is false.");
vector<int64_t> temp = {-1};
mem_block_visit_info.insert(std::make_pair(matched_mem_offset, std::make_pair(peer_out_node, temp))); //数据生产者的prev_node_index设置成-1
}
} else { // 如果在访问信息中找到了对应的offset,此时需要判断当前的输出节点和访问信息中的那个节点是否在同一个stream上,如果不在同一个stream上就不计算距离
if (mem_block_visit_info[matched_mem_offset].first == nullptr) {
continue;
}
if (peer_out_node->GetOpDesc()->GetStreamId() != mem_block_visit_info[matched_mem_offset].first->GetOpDesc()->GetStreamId()) {
GELOGD("lckey, cur node[%s], peer_out_node[%s] not in the same stream with node[%s] in visit info.",
node->GetName().c_str(),
peer_out_node->GetName().c_str(),
mem_block_visit_info[matched_mem_offset].first->GetName().c_str());
continue;
}
}

// now mem_block_visit_info must contains that memory_block
// 1. calculate distance, update visit info, update prev_node input desc, update cur node input desc
int64_t distance = -1;
auto prev_node = mem_block_visit_info[matched_mem_offset].first;
auto prev_node_input_index_vec = mem_block_visit_info[matched_mem_offset].second;
GE_IF_BOOL_EXEC(prev_node == nullptr, continue); // 生产者的输出多流访问或者生产者在某条流上,但是不在cur node所在的这条流上
if (prev_node_input_index_vec.size() == 1 && prev_node_input_index_vec[0] == -1) { // prev_node是生产者且数据刚刚被生产出来(还没被其他节点访问过)
GE_IF_BOOL_EXEC(prev_node->GetOpDesc() == nullptr, continue);
if (prev_node->GetOpDesc()->GetStreamId() == -1) { // 生产者未分配stream
distance = 0;
} else if (node_index_in_stream.find(prev_node->GetName()) == node_index_in_stream.end()) {
distance = 0;
} else {
distance = node_index_in_stream.at(node->GetName()) - node_index_in_stream.at(prev_node->GetName()) - 1;
}
// 不需要更新prev_node的input desc(因为prev_node设定的input index=-1, 设也设不上去),只需要后面刷一把自己的input_desc即可
// 更新访问信息
mem_block_visit_info[matched_mem_offset].first = node;
mem_block_visit_info[matched_mem_offset].second.clear();
mem_block_visit_info[matched_mem_offset].second.push_back(in_data_anchor->GetIdx());
} else {
if (prev_node_input_index_vec.empty()) {
GELOGW("Missing prev node[%s] input index.", prev_node->GetName().c_str());
continue;
}

if (prev_node == node) { // scene: multiple anchors of a node access the same data
vector<int64_t> prev_next_distances;
// 1. do not need to update prev_node's input desc, because it must be updated when the first anchor visit it.
// 2. we use the same distance of previous anchor to keep this value the same.
// 3. update visit info's prev_node_input_index_vec
if (prev_node->GetOpDesc() == nullptr) {
continue;
}
auto input_desc = prev_node->GetOpDesc()->GetInputDesc(prev_node_input_index_vec[0]);
if (!ge::AttrUtils::GetListInt(input_desc, ATTR_NAME_DATA_VISIT_DISTANCE, prev_next_distances)) {
GELOGW("Get ATTR_NAME_DATA_VISIT_DISTANCE failed.");
continue;
}
if (prev_next_distances.size() != 2) {
GELOGW("Size of prev_next_distance is not 2.");
continue;
} else {
distance = prev_next_distances[0]; //use the same prev_distance of previous anchor
}
mem_block_visit_info[matched_mem_offset].second.push_back(in_data_anchor->GetIdx());
} else {
// 走到此处,prev_node必然是访问者,如果cur node和prev node这两个访问者不在同一条流上,那么之前在IsOutputVisitedByMultiStream是就应该continue了,而不会走到这
distance = node_index_in_stream.at(node->GetName()) - node_index_in_stream.at(prev_node->GetName()) - 1;
UpdatePrevNodeInputDesc(prev_node, prev_node_input_index_vec, distance);
mem_block_visit_info[matched_mem_offset].first = node;
mem_block_visit_info[matched_mem_offset].second.clear();
mem_block_visit_info[matched_mem_offset].second.push_back(in_data_anchor->GetIdx());
}
}
UpdateCurNodeInputIndex(node, in_data_anchor->GetIdx(), distance);

auto input_desc = node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx());
bool is_end_of_inputmem_lifecycle = false;

// if is_end_of_inputmem_lifecycle is true, indicating that cur node is the last customer of this data,
// then we need to delete the visit info of the block in case that the memblock be reused and visited.
if (ge::AttrUtils::GetBool(input_desc, ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE, is_end_of_inputmem_lifecycle) &&
is_end_of_inputmem_lifecycle) {
GELOGD("ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE is true");
auto iter2 = mem_block_visit_info.find(matched_mem_offset);
if (iter2 != mem_block_visit_info.end()) {
mem_block_visit_info.erase(iter2);
}
}
}
}

void GraphMemoryAssigner::MarkDistanceAttr() {
map<size_t, pair<NodePtr, vector<int64_t>>> mem_block_visit_info; // key: mem_offset of the memory which we visited. value: node we visited and input index of this node
map<string, int64_t> node_index_in_stream; // key: node name, value: topo order of node in it's belonged stream(exclude ge_local_op)
map<int64_t, int64_t> stream_nodes_num; // key: stream id, value: cur nodes num in that stream

for (auto &node : compute_graph_->GetAllNodes()) {
auto node_op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, return);

int64_t stream_id = node_op_desc->GetStreamId();
if (node_op_desc->GetOpKernelLibName() != kEngineNameGeLocal) {
if (stream_nodes_num.find(stream_id) == stream_nodes_num.end()) {
stream_nodes_num.insert(std::make_pair(stream_id, 1));
} else {
++stream_nodes_num[stream_id];
}
node_index_in_stream.insert(std::make_pair(node->GetName(), stream_nodes_num[stream_id] - 1));

MarkNodeDistanceAttr(compute_graph_, node, mem_block_visit_info, node_index_in_stream);
} else {
GELOGD("in GraphMemoryAssigner::MarkDistanceAttr, node[%s] is ge_local_op", node->GetName().c_str());
}
}
}
} // namespace ge } // namespace ge

+ 7
- 0
ge/graph/build/memory/graph_mem_assigner.h View File

@@ -118,6 +118,13 @@ class GraphMemoryAssigner {


ge::Status AssignReferenceMemory(); ge::Status AssignReferenceMemory();


void MarkDistanceAttr();

void MarkNodeDistanceAttr(const ComputeGraphPtr &compute_graph,
NodePtr &node,
map<size_t, pair<NodePtr, vector<int64_t>>> &mem_block_visit_info,
const map<string, int64_t> &node_index_in_stream);

private: private:
/// ///
/// @ingroup ge_graph /// @ingroup ge_graph


+ 2
- 0
ge/graph/build/memory/memory_assigner.cc View File

@@ -65,6 +65,8 @@ Status MemoryAssigner::AssignMemory(bool is_loop_graph, map<int64_t, size_t> &me
GELOGE(FAILED, "CheckOffset Fail!"); GELOGE(FAILED, "CheckOffset Fail!");
return FAILED; return FAILED;
} }

graph_mem_assigner.MarkDistanceAttr();
return SUCCESS; return SUCCESS;
} }
} // namespace ge } // namespace ge

Loading…
Cancel
Save