Browse Source

Fix Label for dynamic graph.

tags/v1.1.0
unknown 4 years ago
parent
commit
aa6dcd9262
3 changed files with 29 additions and 13 deletions
  1. +21
    -10
      ge/graph/build/label_allocator.cc
  2. +2
    -2
      ge/graph/build/model_builder.cc
  3. +6
    -1
      ge/graph/passes/memcpy_addr_async_pass.cc

+ 21
- 10
ge/graph/build/label_allocator.cc View File

@@ -32,11 +32,6 @@ Status LabelAllocator::AssignFunctionalLabels() {
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


if (compute_graph_->GetGraphUnknownFlag()) {
GELOGD("Graph[%s] is unknown graph, skip label allocator.", compute_graph_->GetName().c_str());
return SUCCESS;
}

// Add label task for sub graph. // Add label task for sub graph.
GELOGI("AssignFunctionalLabels start: %s.", compute_graph_->GetName().c_str()); GELOGI("AssignFunctionalLabels start: %s.", compute_graph_->GetName().c_str());
std::set<NodePtr> functional_nodes; std::set<NodePtr> functional_nodes;
@@ -62,7 +57,7 @@ Status LabelAllocator::AssignFunctionalLabels() {
} }


(void)AttrUtils::SetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_index); (void)AttrUtils::SetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_index);
GELOGI("AssignFunctionalLabels success.");
GELOGI("AssignFunctionalLabels success, Num: %u.", label_index);
return SUCCESS; return SUCCESS;
} }


@@ -72,13 +67,29 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node
return false; return false;
} }


NodePtr parent = graph->GetParentNode();
if (parent == nullptr) {
GELOGE(INTERNAL_ERROR, "ComputeGraph owner not set: %s.", graph->GetName().c_str());
if (graph->GetGraphUnknownFlag()) {
GELOGD("Graph[%s] is unknown graph, skip label allocator.", graph->GetName().c_str());
return true;
}

NodePtr func_node = graph->GetParentNode();
if (func_node == nullptr) {
GELOGE(INTERNAL_ERROR, "Parent functional node not set: %s.", graph->GetName().c_str());
return false; return false;
} }


(void)functional_nodes.insert(parent); // unique functional node.
ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph();
if (owner_graph == nullptr) {
GELOGE(INTERNAL_ERROR, "ComputeGraph owner not set: %s.", func_node->GetName().c_str());
return false;
}

if (owner_graph->GetGraphUnknownFlag()) {
GELOGD("Graph[%s] is unknown graph, skip label allocator.", owner_graph->GetName().c_str());
return true;
}

(void)functional_nodes.insert(func_node); // unique functional node.
return true; return true;
} }
} // namespace ge } // namespace ge

+ 2
- 2
ge/graph/build/model_builder.cc View File

@@ -690,8 +690,8 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) {
GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams");


// Assign functional op labels. // Assign functional op labels.
label_num_ = 0;
(void)AttrUtils::GetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_num_);
auto root_graph = GraphUtils::FindRootGraph(compute_graph_);
(void)AttrUtils::GetInt(*root_graph, ATTR_MODEL_LABEL_NUM, label_num_);


GE_TIMESTAMP_START(AssignMemory); GE_TIMESTAMP_START(AssignMemory);
MemoryAssigner mem_assigner(compute_graph_); MemoryAssigner mem_assigner(compute_graph_);


+ 6
- 1
ge/graph/passes/memcpy_addr_async_pass.cc View File

@@ -25,6 +25,10 @@
namespace ge { namespace ge {
Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);
if (graph->GetGraphUnknownFlag()) {
GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str());
return SUCCESS;
}


int64_t value = 0; int64_t value = 0;
rtError_t rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMCPY, MEMCPY_INFO_SUPPORT_ZEROCOPY, &value); rtError_t rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMCPY, MEMCPY_INFO_SUPPORT_ZEROCOPY, &value);
@@ -201,9 +205,10 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr
const OutDataAnchorPtr &out_data_anchor, const OutDataAnchorPtr &out_data_anchor,
const NodePtr &out_of_user_data) { const NodePtr &out_of_user_data) {
GELOGD("Start CreateMemcpyAddrAsyncNode."); GELOGD("Start CreateMemcpyAddrAsyncNode.");
static uint32_t new_node_index = 0;
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid.");
std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC;
std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++);


OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC); OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC);
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);


Loading…
Cancel
Save