/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/node_utils.h"
#include "graph/utils/graph_utils.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/types.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

namespace ge {
// Per-node lists of send / receive event ids, filled by AddSendEventId / AddRecvEventId.
// NOTE(review): the element types must match the declarations in node_utils.h - confirm.
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};

// Op type names grouped by kind; looked up with the string returned by GetType().
const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
const std::set<std::string> kCaseOpTypes = {"Case"};
const std::set<std::string> kForOpTypes = {"For"};

namespace {
/// Return true if any dimension of `shape` equals UNKNOWN_DIM or UNKNOWN_DIM_NUM.
bool ShapeHasUnknownDim(const GeShape &shape) {
  for (const auto &dim : shape.GetDims()) {
    if ((dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM)) {
      return true;
    }
  }
  return false;
}
}  // namespace

///
/// @brief Check whether any input or output tensor of an op has an unknown dimension.
/// @param [in] desc: op description to inspect
/// @return true if at least one dimension of any input/output shape is UNKNOWN_DIM or
///         UNKNOWN_DIM_NUM; false otherwise, including when `desc` is null
///
bool OpShapeIsUnknown(const OpDescPtr &desc) {
  if (desc == nullptr) {
    return false;
  }
  for (const auto &ptr : desc->GetAllInputsDescPtr()) {
    // Null tensor descriptions are skipped instead of being dereferenced.
    if ((ptr != nullptr) && ShapeHasUnknownDim(ptr->GetShape())) {
      return true;
    }
  }
  for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
    if ((ptr != nullptr) && ShapeHasUnknownDim(ptr->GetShape())) {
      return true;
    }
  }
  return false;
}

///
/// @brief Record a send event id for a node.
/// @param [in] node: key of the record; must not be null
/// @param [in] event_id: id appended to the node's send list
/// @return GRAPH_SUCCESS once the id is stored; the status produced by GE_CHECK_NOTNULL if `node` is null
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
                                                                                     const uint32_t &event_id) {
  GE_CHECK_NOTNULL(node);
  // operator[] creates an empty list the first time a node is seen.
  map_send_info_[node].push_back(event_id);
  return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY
GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node, const uint32_t &event_id) { GE_CHECK_NOTNULL(node); map_recv_info_[node].push_back(event_id); return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector &vec_send) { GE_CHECK_NOTNULL(node); auto find = map_send_info_.find(node); if (find == map_send_info_.end()) { return GRAPH_FAILED; } else { vec_send = find->second; return GRAPH_SUCCESS; } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv) { GE_CHECK_NOTNULL(node); auto find = map_recv_info_.find(node); if (find == map_recv_info_.end()) { return GRAPH_FAILED; } else { vec_recv = find->second; return GRAPH_SUCCESS; } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() { map_send_info_.clear(); return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() { map_recv_info_.clear(); return GRAPH_SUCCESS; } graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) { GE_CHECK_NOTNULL(src); NodePtr cur_ptr; if (depth < 1) { return GRAPH_FAILED; } for (int i = 0; i < depth; i++) { if (src->GetOutDataNodes().size() != 1) { return GRAPH_FAILED; } cur_ptr = src->GetOutDataNodes().at(0); GE_CHECK_NOTNULL(cur_ptr); } dst = cur_ptr; return GRAPH_SUCCESS; } graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, InControlAnchorPtr &in_control) { GE_CHECK_NOTNULL(node_ptr); for (const auto &p : node_ptr->GetAllOutDataAnchors()) { GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr"); for (const auto &p_in : p->GetPeerInControlAnchors()) { GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr"); out_data = p; in_control = p_in; return GRAPH_SUCCESS; } } 
return GRAPH_FAILED; } graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, "node or in_data_anchor is nullptr"); bool find_flag = false; uint32_t index = 0; vector::iterator it = node_ptr->in_data_anchors_.end(); for (const auto &tmp : node_ptr->in_data_anchors_) { if (tmp == in_data_anchor) { find_flag = true; auto iter = node_ptr->in_data_anchors_.begin() + index; if (iter != node_ptr->in_data_anchors_.end()) { it = node_ptr->in_data_anchors_.erase(iter); } break; } index++; } for (; it != node_ptr->in_data_anchors_.end(); ++it) { (*it)->SetIdx(index); index++; } if (!find_flag) { return GRAPH_FAILED; } return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) { GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr"); GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed"); return GRAPH_SUCCESS; } graphStatus NodeUtils::SetAllAnchorStatus(Node &node) { node.anchor_status_updated_ = true; return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) { GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr"); return IsAnchorStatusSet(*node_ptr); } bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; } graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) { if ((origin_node == nullptr) || (new_node == nullptr)) { return GRAPH_FAILED; } auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors(); auto new_out_data_anchors = new_node->GetAllOutDataAnchors(); if (origin_out_data_anchors.size() != new_out_data_anchors.size()) { return GRAPH_FAILED; } for (size_t i = 0; i < origin_out_data_anchors.size(); 
++i) { for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) { GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, "unlink peer_anchor failed"); GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed"); } for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) { GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, "unlink peer_anchor failed"); GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed"); } } auto origin_out_control_anchor = origin_node->GetOutControlAnchor(); GE_CHECK_NOTNULL(origin_out_control_anchor); auto new_out_control_anchor = new_node->GetOutControlAnchor(); GE_CHECK_NOTNULL(new_out_control_anchor); for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) { GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed"); } for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) { GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed"); } origin_out_control_anchor->UnlinkAll(); return GRAPH_SUCCESS; } bool NodeUtils::IsConst(const Node &node) { auto src_node_type = node.GetType(); bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP)); return is_const; } void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) { if (node_ptr == nullptr) { GELOGE(GRAPH_FAILED, "node is null"); return; } UpdateIsInputConst(*node_ptr); } /// /// update is_input_const /// @param node /// @return void /// void NodeUtils::UpdateIsInputConst(Node &node) { std::vector is_input_const; size_t anchor_num = node.GetAllInDataAnchors().size(); for (size_t i = 0; i < anchor_num; i++) { auto in_anchor = 
node.GetInDataAnchor(static_cast(i)); if (in_anchor == nullptr) { is_input_const.push_back(false); continue; } auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { is_input_const.push_back(false); continue; } auto src_node = peer_out_anchor->GetOwnerNode(); if (src_node == nullptr) { is_input_const.push_back(false); continue; } if (IsConst(*(src_node))) { is_input_const.push_back(true); } else { is_input_const.push_back(false); } } if (node.GetOpDesc() == nullptr) { GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr"); return; } node.GetOpDesc()->SetIsInputConst(is_input_const); } void NodeUtils::UnlinkAll(const Node &node) { for (const auto &anchor : node.GetAllOutAnchors()) { anchor->UnlinkAll(); } for (const auto &anchor : node.GetAllInAnchors()) { anchor->UnlinkAll(); } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) { if (node_ptr == nullptr) { GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); return GRAPH_FAILED; } auto op_desc = node_ptr->GetOpDesc(); if (op_desc == nullptr) { return GRAPH_FAILED; } bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); if (is_unknown_graph) { return GRAPH_SUCCESS; } for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); output_tensor->SetOriginShape(output_tensor->GetShape()); output_tensor->SetOriginDataType(output_tensor->GetDataType()); GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); for (const auto &peer_anchor : 
out_anchor->GetPeerInDataAnchors()) { if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); continue; } auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); if (peer_input_desc == nullptr) { GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); continue; } GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), output_tensor->GetDataType(), output_tensor->GetOriginDataType()); peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); peer_input_desc->SetShape(output_tensor->GetShape()); peer_input_desc->SetDataType(output_tensor->GetDataType()); peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); std::vector> shape_range; (void)output_tensor->GetShapeRange(shape_range); peer_input_desc->SetShapeRange(shape_range); ge::TensorUtils::SetRealDimCnt(*peer_input_desc, static_cast(output_tensor->GetShape().GetDims().size())); GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); } } return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, uint32_t index) { if (node == nullptr) { GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); return GRAPH_FAILED; } GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); OpDescPtr op_desc = node->op_; for (size_t i = op_desc->GetInputsSize(); i < index; ++i) { if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Add input desc failed"); return GRAPH_FAILED; } auto anchor = ComGraphMakeShared(node, i); if (anchor == nullptr) { 
GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); return GRAPH_FAILED; } node->in_data_anchors_.push_back(anchor); } return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, uint32_t index) { if (node == nullptr) { GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); return GRAPH_FAILED; } OpDescPtr op_desc = node->op_; op_desc->RemoveInputDesc(index); while (node->in_data_anchors_.size() > index) { node->in_data_anchors_.pop_back(); } return GRAPH_SUCCESS; } bool NodeUtils::IsInNodesEmpty(const Node &node) { for (const auto &in_anchor : node.in_data_anchors_) { if (in_anchor != nullptr) { auto out_anchor = in_anchor->GetPeerOutAnchor(); if (out_anchor != nullptr) { if (out_anchor->GetOwnerNode() != nullptr) { return false; } } } } if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); for (const auto &out_control_anchor : peer_out_control_anchors) { if (out_control_anchor != nullptr) { if (out_control_anchor->GetOwnerNode() != nullptr) { return false; } } } } return true; } GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) { auto desc = node.GetOpDesc(); if (desc == nullptr) { return GeTensorDesc(); } return desc->GetOutputDesc(index); } GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) { auto desc = node.GetOpDesc(); if (desc == nullptr) { return GeTensorDesc(); } return desc->GetInputDesc(index); } graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) { auto desc = node.GetOpDesc(); if (desc == nullptr) { return GRAPH_PARAM_INVALID; } auto output_desc = desc->MutableOutputDesc(index); if (output_desc == nullptr) { return GRAPH_PARAM_INVALID; } output_desc->SetShape(shape); return GRAPH_SUCCESS; } graphStatus NodeUtils::UpdateInputShape(const Node &node, 
uint32_t index, const GeShape &shape) { auto desc = node.GetOpDesc(); if (desc == nullptr) { return GRAPH_PARAM_INVALID; } auto input_desc = desc->MutableInputDesc(index); if (input_desc == nullptr) { return GRAPH_PARAM_INVALID; } input_desc->SetShape(shape); return GRAPH_SUCCESS; } graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { auto desc = node.GetOpDesc(); GE_CHECK_NOTNULL(desc); // check self is_unknow = OpShapeIsUnknown(desc); if (is_unknow) { return GRAPH_SUCCESS; } auto sub_graph_names = desc->GetSubgraphInstanceNames(); if (sub_graph_names.empty()) { return GRAPH_SUCCESS; } else { auto owner_graph = node.GetOwnerComputeGraph(); GE_CHECK_NOTNULL(owner_graph); auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); if (root_graph == nullptr) { GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); return GRAPH_PARAM_INVALID; } for (auto &sub_graph_name : sub_graph_names) { auto sub_graph = root_graph->GetSubgraph(sub_graph_name); GE_CHECK_NOTNULL(sub_graph); for (const auto &node_ptr : sub_graph->GetDirectNode()) { auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); if (status != GRAPH_SUCCESS) { GE_LOGE("get node unknown shape status failed!"); return status; } if (is_unknow) { return GRAPH_SUCCESS; } } } } return GRAPH_SUCCESS; } std::string NodeUtils::GetNodeType(const Node &node) { if (node.GetType() != FRAMEWORKOP) { return node.GetType(); } std::string type; (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); return type; } ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { auto op_desc = node.GetOpDesc(); if (op_desc == nullptr) { return nullptr; } auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); if (root_graph == nullptr) { return nullptr; } return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); } graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const 
ComputeGraphPtr &subgraph) { if (subgraph == nullptr) { GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); return GRAPH_PARAM_INVALID; } auto op_desc = node.GetOpDesc(); if (op_desc == nullptr) { return GRAPH_PARAM_INVALID; } auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); if (root_graph == nullptr) { GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); return GRAPH_PARAM_INVALID; } auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); if (ret != GRAPH_SUCCESS) { GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); return ret; } subgraph->SetParentNode(node.shared_from_this()); subgraph->SetParentGraph(node.GetOwnerComputeGraph()); return root_graph->AddSubgraph(subgraph); } /// /// Check if node is input of subgraph /// @param [in] node /// @return bool /// bool NodeUtils::IsSubgraphInput(const NodePtr &node) { if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { return false; } auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); if (parent_op_desc == nullptr) { return false; } if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { bool is_unknown_shape = false; (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); if (is_unknown_shape) return false; } if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { return false; } return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); } /// /// Check if node is output of subgraph /// @param [in] node /// @return bool /// bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { if ((node 
== nullptr) || (node->GetOpDesc() == nullptr) || (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { return false; } auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); if (parent_op_desc == nullptr) { return false; } if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { bool is_unknown_shape = false; (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); if (is_unknown_shape) return false; } if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { return false; } for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { return true; } } return false; } /// /// @brief Get subgraph original input node. /// @param [in] node /// @return Node /// NodePtr NodeUtils::GetParentInput(const NodePtr &node) { GE_CHECK_NOTNULL_EXEC(node, return nullptr); uint32_t parent_index = 0; if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { return nullptr; } // Subgraph Data Node, check for constant input. 
const ComputeGraphPtr &graph = node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL_EXEC(graph, return nullptr); const NodePtr &parent_node = graph->GetParentNode(); GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); return peer_out_anchor->GetOwnerNode(); } /// /// @brief Check is varying_input for while node /// @param [in] node: Data node for subgraph /// @return bool /// bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { if (node == nullptr) { return false; } if (node->GetType() != DATA) { return false; // not input_node for subgraph } const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); if (parent_node == nullptr) { return false; // root graph } if (kWhileOpTypes.count(parent_node->GetType()) == 0) { return false; // not input_node for while subgraph } uint32_t index_i = 0; if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); return false; } bool varying_flag = true; for (const auto &item : node->GetOutDataNodesAndAnchors()) { if (item.first->GetType() != NETOUTPUT) { continue; } OpDescPtr op_desc = item.first->GetOpDesc(); uint32_t index_o = 0; if ((op_desc == nullptr) || !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { continue; // input for while-cond subgraph } if (index_i != index_o) { continue; // varying input for while-body subgraph } varying_flag = false; break; } return varying_flag; } /// /// @brief Get subgraph input is constant. 
/// @param [in] node /// @param [out] string /// @return bool /// bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) { GE_CHECK_NOTNULL_EXEC(in_node, return false); if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { op_type = in_node->GetType(); return true; } if (in_node->GetType() == DATA) { std::string const_type; if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { return false; } if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) { op_type = const_type; return true; } } return false; } /// /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. /// @param [in] node /// @return return GRAPH_SUCCESS if remove successfully, other for failed. /// Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); auto subgraph_names = op_desc->GetSubgraphInstanceNames(); if (subgraph_names.empty()) { return GRAPH_SUCCESS; } else { auto owner_graph = node->GetOwnerComputeGraph(); GE_CHECK_NOTNULL(owner_graph); auto root_graph = GraphUtils::FindRootGraph(owner_graph); GE_CHECK_NOTNULL(root_graph); std::unordered_set subgraph_to_remove; for (auto &subgraph_name : subgraph_names) { std::deque queue; queue.push_back(subgraph_name); subgraph_to_remove.insert(subgraph_name); op_desc->RemoveSubgraphInstanceName(subgraph_name); while (!queue.empty()) { auto graph_name = queue.front(); queue.pop_front(); auto subgraph = root_graph->GetSubgraph(graph_name); GE_CHECK_NOTNULL(subgraph); for (const auto &sub_node : subgraph->GetDirectNode()) { auto sub_op_desc = sub_node->GetOpDesc(); GE_CHECK_NOTNULL(sub_op_desc); auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); // Subgraph and all nodes in it will be removed later, // no need to remove 'SubgraphInstanceName' in op desc here. 
for (auto &name : sub_names) { if (subgraph_to_remove.insert(name).second) { queue.push_back(name); } } } } } // Remove subgraph from root_graph for (const auto &name : subgraph_to_remove) { GELOGI("Remove subgraph:%s.", name.c_str()); root_graph->RemoveSubgraph(name); } } return GRAPH_SUCCESS; } /// /// @brief Get subgraph input data node by index. /// @param [in] node /// @return Node /// vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { vector in_data_node_vec; auto op_desc = node.GetOpDesc(); GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); auto subgraph_names = op_desc->GetSubgraphInstanceNames(); if (subgraph_names.empty()) { GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); return in_data_node_vec; } auto compute_graph = node.GetOwnerComputeGraph(); for (const std::string &instance_name : subgraph_names) { auto subgraph = compute_graph->GetSubgraph(instance_name); for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { int parent_index = -1; if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); if (parent_index == index) { in_data_node_vec.emplace_back(node_in_subgraph); } } } } return in_data_node_vec; } /// /// @brief Get subgraph input data node by index. 
/// @param [in] node /// @return Node /// vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { vector out_data_node_vec; auto op_desc = node.GetOpDesc(); GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); auto subgraph_names = op_desc->GetSubgraphInstanceNames(); if (subgraph_names.empty()) { GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); return out_data_node_vec; } auto compute_graph = node.GetOwnerComputeGraph(); for (const std::string &instance_name : subgraph_names) { auto subgraph = compute_graph->GetSubgraph(instance_name); for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { out_data_node_vec.emplace_back(node_in_subgraph); } } } return out_data_node_vec; } NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { if (node.GetInDataAnchor(index) == nullptr) { return nullptr; } if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { return nullptr; } return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); } vector NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { vector out_data_nodes; auto out_data_anchor = node.GetOutDataAnchor(index); if (out_data_anchor == nullptr) { return out_data_nodes; } for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { if (peer_in_anchor == nullptr) { continue; } if (peer_in_anchor->GetOwnerNode() == nullptr) { continue; } out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); } return out_data_nodes; } } // namespace ge