From 0ab8f7c6ab5c84e916c978c6eee6e18729fa79b3 Mon Sep 17 00:00:00 2001 From: chuxing Date: Mon, 2 Nov 2020 19:51:52 +0800 Subject: [PATCH 1/3] shape inference for fused graph with GetInputConst --- ge/hybrid/model/hybrid_model_builder.cc | 82 +++++++++++++++++++++++++++++++++ ge/hybrid/model/hybrid_model_builder.h | 1 + 2 files changed, 83 insertions(+) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 812d822f..4148966f 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -58,6 +58,41 @@ int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { } return var_size; } + +Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set &data_ops) { + for (const auto &node : node_item.fused_subgraph->nodes) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const auto &depends = op_desc->GetOpInferDepends(); + if (depends.empty()) { + continue; + } + + for (auto &input_name : depends) { + auto input_index = op_desc->GetInputIndexByName(input_name); + const auto &in_anchor = ge_node->GetInDataAnchor(input_index); + GE_CHECK_NOTNULL(in_anchor); + const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + const auto &src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto src_op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); + if (src_node->GetType() != DATA_TYPE) { + GELOGE(UNSUPPORTED, + "[%s::%s] Node in fused subgraph can only depend on Data nodes, but depend on %s", + node_item.NodeName().c_str(), + node->GetName().c_str(), + src_node->GetType().c_str()); + return UNSUPPORTED; + } + + data_ops.emplace(src_op_desc.get()); + } + } + + return SUCCESS; +} } // namespace HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) : hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) { @@ -272,6 +307,53 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s node_item.dependents_for_shape_inference.emplace_back(dep_node); } + GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(noe_item)); + return SUCCESS; +} + +Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { + if (node_item.fused_subgraph == nullptr) { + return SUCCESS; + } + + std::set data_ops; + GE_CHK_STATUS_RET_NOLOG(CollectDependenciesForFusedGraph(node_item, data_ops)); + for (auto &op_desc : data_ops) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(INTERNAL_ERROR, + "[%s] Failed to get attr [%s]", + data_op_desc->GetName().c_str(), + ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return INTERNAL_ERROR; + } + + const auto &in_anchor = node_item.node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL(in_anchor); + const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + const auto &src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + NodeItem *src_node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(src_node, &src_node_item)); + op_desc->SetId(src_node_item->op_desc->GetId()); + GELOGD("[%S::%S] Node id was set to that of outer src node's, src_node = %s", + node_item.NodeName().c_str(), + op_desc->GetName().c_str(), + src_node_item->NodeName().c_str()); + src_node_item->has_observer = true; + src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); + + auto &depends = node_item.dependents_for_shape_inference; + if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) { + depends.emplace_back(src_node); + GELOGD("[%s] Dependent added from output of [%s:%d]", + node_item.NodeName().c_str(), + src_node_item->NodeName().c_str(), + peer_out_anchor->GetIdx()); + } + } + return SUCCESS; } diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index d522939e..d78d622b 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -62,6 +62,7 @@ class HybridModelBuilder { Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); + Status ParseDependentForFusedSubgraph(NodeItem &node_item); Status IndexTaskDefs(); Status IndexSpecialNodes(); Status InitRuntimeParams(); From 7f5a15375121fee86e6ccdb098058ef004ec1a12 Mon Sep 17 00:00:00 2001 From: chuxing Date: Mon, 2 Nov 2020 21:46:46 +0800 Subject: [PATCH 2/3] fix typo --- ge/hybrid/model/hybrid_model_builder.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 4148966f..c9401c38 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -15,6 +15,7 @@ */ #include "hybrid/model/hybrid_model_builder.h" +#include #include "common/math/math_util.h" #include "graph/ge_context.h" #include "graph/build/memory/var_mem_assign_util.h" @@ -70,7 +71,7 @@ Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set for (auto &input_name : depends) { auto input_index = op_desc->GetInputIndexByName(input_name); - const auto &in_anchor = ge_node->GetInDataAnchor(input_index); + const auto &in_anchor = node->GetInDataAnchor(input_index); GE_CHECK_NOTNULL(in_anchor); const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out_anchor); @@ -307,7 +308,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s node_item.dependents_for_shape_inference.emplace_back(dep_node); } - GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(noe_item)); + GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item)); return SUCCESS; } @@ -320,10 +321,10 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { GE_CHK_STATUS_RET_NOLOG(CollectDependenciesForFusedGraph(node_item, data_ops)); for (auto &op_desc : data_ops) { uint32_t parent_index = 0; - if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + if (!AttrUtils::GetInt(*op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { GELOGE(INTERNAL_ERROR, "[%s] Failed to get attr [%s]", - data_op_desc->GetName().c_str(), + op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); return INTERNAL_ERROR; } @@ -337,7 +338,7 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { NodeItem *src_node_item = nullptr; GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(src_node, &src_node_item)); op_desc->SetId(src_node_item->op_desc->GetId()); - GELOGD("[%S::%S] Node id was set to that of outer src node's, src_node = %s", + GELOGD("[%s::%s] Node id was set to that of outer src node's, src_node = %s", node_item.NodeName().c_str(), op_desc->GetName().c_str(), src_node_item->NodeName().c_str()); From 1efd304527bd5b8db930b6b66b1936b00d206e2f Mon Sep 17 00:00:00 2001 From: chuxing Date: Wed, 4 Nov 2020 20:02:50 +0800 Subject: [PATCH 3/3] fixing code review --- ge/hybrid/model/hybrid_model_builder.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index c9401c38..f47a02fd 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -71,11 +71,7 @@ Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set for (auto &input_name : depends) { auto input_index = op_desc->GetInputIndexByName(input_name); - const auto &in_anchor = node->GetInDataAnchor(input_index); - GE_CHECK_NOTNULL(in_anchor); - const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - const auto &src_node = peer_out_anchor->GetOwnerNode(); + auto src_node = NodeUtils::GetInDataNodeByIndex(*node, input_index); GE_CHECK_NOTNULL(src_node); auto src_op_desc = src_node->GetOpDesc(); GE_CHECK_NOTNULL(src_op_desc);