diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index eaa6afb6..42e9fcd8 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -126,20 +126,6 @@ Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set return SUCCESS; } - -bool CheckHasHostMem(NodeItem &node_item) { - if (node_item.NodeType() == DATA) { - auto op_desc = node_item.GetOpDesc(); - if (op_desc == nullptr) { - return false; - } - auto tensor = op_desc->MutableInputDesc(0); - if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { - return true; - } - } - return false; -} } // namespace HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) : hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) { @@ -298,6 +284,47 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt return SUCCESS; } +Status HybridModelBuilder::ParseDependencies(NodeItem &node_item, const std::vector &dependencies, + std::set &dependent_for_shape_inference) { + auto &ge_node = node_item.node; + for (const auto &input_name : dependencies) { + int input_index = node_item.op_desc->GetInputIndexByName(input_name); + if (input_index < 0) { + GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", + node_item.NodeName().c_str(), input_name.c_str()); + REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", + node_item.NodeName().c_str(), input_name.c_str()); + return INTERNAL_ERROR; + } + + const auto &in_anchor = ge_node->GetInDataAnchor(input_index); + GE_CHECK_NOTNULL(in_anchor); + const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + const auto &src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto src_node_item = MutableNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + if (src_node_item->NodeType() == DATA) { + auto op_desc = src_node_item->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto tensor = op_desc->MutableInputDesc(0); + if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { + GELOGD("Skip d2h memcpy, get hostmem from node %s.", src_node_item->NodeName().c_str()); + continue; + } + } + src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); + dependent_for_shape_inference.emplace(src_node); + host_input_value_dependencies_[&node_item].emplace_back(peer_out_anchor->GetIdx(), src_node_item); + GELOGD("[%s] Dependent added from output of [%s:%d]", + node_item.NodeName().c_str(), + src_node_item->NodeName().c_str(), + peer_out_anchor->GetIdx()); + } + return SUCCESS; +} + Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { std::set dependent_for_shape_inference; std::set dependent_for_execution; @@ -361,35 +388,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s src_node_item->NodeName().c_str()); } - for (const auto &input_name : dependencies) { - int input_index = node_item.op_desc->GetInputIndexByName(input_name); - if (input_index < 0) { - GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", - node_item.NodeName().c_str(), input_name.c_str()); - REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", - node_item.NodeName().c_str(), input_name.c_str()); - return INTERNAL_ERROR; - } - - const auto &in_anchor = ge_node->GetInDataAnchor(input_index); - GE_CHECK_NOTNULL(in_anchor); - const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - const auto &src_node = peer_out_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(src_node); - auto src_node_item = MutableNodeItem(src_node); - GE_CHECK_NOTNULL(src_node_item); - GE_IF_BOOL_EXEC(CheckHasHostMem(*src_node_item), - GELOGD("Skip d2h memcpy, get hostmem from node %s.", src_node_item->NodeName().c_str()); - continue;) - src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); - dependent_for_shape_inference.emplace(src_node); - host_input_value_dependencies_[&node_item].emplace_back(peer_out_anchor->GetIdx(), src_node_item); - GELOGD("[%s] Dependent added from output of [%s:%d]", - node_item.NodeName().c_str(), - src_node_item->NodeName().c_str(), - peer_out_anchor->GetIdx()); - } + GE_CHK_STATUS_RET(ParseDependencies(node_item, dependencies, dependent_for_shape_inference)); GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference)); for (const auto &dep_node : dependent_for_shape_inference) { diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 3e467dc8..c612c713 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -66,6 +66,8 @@ class HybridModelBuilder { Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); Status CollectParallelGroups(NodeItem *node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); + Status ParseDependencies(NodeItem &node_item, const std::vector &dependencies, + std::set &dependent_for_shape_inference); Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set &dependencies); Status ParseDependentByParallelGroup(); Status IndexTaskDefs(); diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 4eae475d..0cf249ff 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -670,3 +670,30 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) { ASSERT_TRUE(model.GetNodeItem(node)->has_observer); ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); } + +TEST_F(UtestGeHybrid, TestParseDependencies) { + // make graph + ut::GraphBuilder graph_builder = ut::GraphBuilder("graph"); + auto data = graph_builder.AddNode("Data", "Data", 0, 1); + auto netoutput = graph_builder.AddNode("Netoutput", "NetOutput", 1, 0); + graph_builder.AddDataEdge(data, 0, netoutput, 0); + auto graph = graph_builder.GetGraph(); + + GeRootModelPtr root_model = MakeShared(graph); + HybridModel model(root_model); + HybridModelBuilder builder(model); + + std::unique_ptr node_item; + NodeItem::Create(netoutput, node_item); + + std::vector deps; + deps.push_back("data"); + auto op_desc = netoutput->GetOpDesc(); + op_desc->input_name_idx_["Data"] = 0; + auto tensor = std::make_shared(); + auto tensor_desc = op_desc->MutableInputDesc(0); + AttrUtils::SetTensor(tensor_desc, "_value", tensor); + + std::set dependent_for_shape_inference; + ASSERT_EQ(builder.ParseDependencies(*node_item, deps, dependent_for_shape_inference), SUCCESS); +} \ No newline at end of file