From 0ab8f7c6ab5c84e916c978c6eee6e18729fa79b3 Mon Sep 17 00:00:00 2001 From: chuxing Date: Mon, 2 Nov 2020 19:51:52 +0800 Subject: [PATCH] shape inference for fused graph with GetInputConst --- ge/hybrid/model/hybrid_model_builder.cc | 82 +++++++++++++++++++++++++++++++++ ge/hybrid/model/hybrid_model_builder.h | 1 + 2 files changed, 83 insertions(+) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 812d822f..4148966f 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -58,6 +58,41 @@ int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { } return var_size; } + +Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set &data_ops) { + for (const auto &node : node_item.fused_subgraph->nodes) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const auto &depends = op_desc->GetOpInferDepends(); + if (depends.empty()) { + continue; + } + + for (auto &input_name : depends) { + auto input_index = op_desc->GetInputIndexByName(input_name); + const auto &in_anchor = ge_node->GetInDataAnchor(input_index); + GE_CHECK_NOTNULL(in_anchor); + const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + const auto &src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto src_op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); + if (src_node->GetType() != DATA_TYPE) { + GELOGE(UNSUPPORTED, + "[%s::%s] Node in fused subgraph can only depend on Data nodes, but depend on %s", + node_item.NodeName().c_str(), + node->GetName().c_str(), + src_node->GetType().c_str()); + return UNSUPPORTED; + } + + data_ops.emplace(src_op_desc.get()); + } + } + + return SUCCESS; +} } // namespace HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) : hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) { @@ -272,6 +307,53 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s node_item.dependents_for_shape_inference.emplace_back(dep_node); } + GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(noe_item)); + return SUCCESS; +} + +Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { + if (node_item.fused_subgraph == nullptr) { + return SUCCESS; + } + + std::set data_ops; + GE_CHK_STATUS_RET_NOLOG(CollectDependenciesForFusedGraph(node_item, data_ops)); + for (auto &op_desc : data_ops) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(INTERNAL_ERROR, + "[%s] Failed to get attr [%s]", + data_op_desc->GetName().c_str(), + ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return INTERNAL_ERROR; + } + + const auto &in_anchor = node_item.node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL(in_anchor); + const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + const auto &src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + NodeItem *src_node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(src_node, &src_node_item)); + op_desc->SetId(src_node_item->op_desc->GetId()); + GELOGD("[%S::%S] Node id was set to that of outer src node's, src_node = %s", + node_item.NodeName().c_str(), + op_desc->GetName().c_str(), + src_node_item->NodeName().c_str()); + src_node_item->has_observer = true; + src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); + + auto &depends = node_item.dependents_for_shape_inference; + if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) { + depends.emplace_back(src_node); + GELOGD("[%s] Dependent added from output of [%s:%d]", + node_item.NodeName().c_str(), + src_node_item->NodeName().c_str(), + peer_out_anchor->GetIdx()); + } + } + return SUCCESS; } diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index d522939e..d78d622b 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -62,6 +62,7 @@ class HybridModelBuilder { Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); + Status ParseDependentForFusedSubgraph(NodeItem &node_item); Status IndexTaskDefs(); Status IndexSpecialNodes(); Status InitRuntimeParams();