/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hybrid/model/node_item.h"

#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"
#include "hybrid/executor/worker/shape_inference_engine.h"
#include "hybrid/node_executor/node_executor.h"

namespace ge {
namespace hybrid {
namespace {
// Upper bound on the number of nodes IsEnterFeedNode() inspects (the node itself plus its ancestors).
const uint8_t kMaxTransCount = 3;
// A node is only walked through by IsEnterFeedNode() when it has exactly this many data inputs.
const uint32_t kTransOpIoSize = 1;
// Attribute whose presence on an op marks it as carrying an original (pre-fusion) graph.
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
// Node type that ParseFusedSubgraph() treats as an output of the fused graph.
const char *const kNodeTypeRetVal = "_RetVal";

// Op types reported as "control flow v2" by IsControlFlowV2Op().
const std::set<std::string> kControlOpTypes{
    IF, STATELESSIF, CASE, WHILE, STATELESSWHILE
};

// Op types that set NodeItem::is_ctrl_flow_op_.
const std::set<std::string> kControlFlowOpTypes{
    STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT,
    LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX
};

// Op types that set NodeItem::is_merge_op_.
const std::set<std::string> kMergeOpTypes{
    MERGE, REFMERGE, STREAMMERGE
};

// Tells whether `node` is an Enter-type node, or is fed by one through a short chain of
// single-input nodes. At most kMaxTransCount nodes are examined, starting with `node` itself.
// NOTE(review): kEnterOpTypes is not defined in this file; presumably it comes from an included header.
//
// @param node  Node to start from (taken by value; the local copy is advanced upstream).
// @return true if a visited node's type is in kEnterOpTypes; false if the chain breaks
//         (a node without exactly one data input), a null node is met, or the limit is reached.
bool IsEnterFeedNode(NodePtr node) {
  // For: Enter -> node
  // For: Enter -> Cast -> node
  // For: Enter -> TransData -> Cast -> node
  for (uint8_t i = 0; i < kMaxTransCount; ++i) {
    // A null node can be neither an Enter node nor walked through.
    if (node == nullptr) {
      return false;
    }

    if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
      GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str());
      return true;
    }

    const auto all_nodes = node->GetInDataNodes();
    if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) {
      return false;
    }
    node = all_nodes.at(0);
  }
  return false;
}

// Records, for one input node of a fused graph, every input tensor desc that consumes its output.
// The consumers' mutable input descs are appended to
// fused_subgraph.input_mapping[<ATTR_NAME_PARENT_NODE_INDEX of op_desc>].
//
// @param node            Input node whose outgoing data edges are followed.
// @param op_desc         Op desc of `node`; must carry ATTR_NAME_PARENT_NODE_INDEX.
// @param fused_subgraph  Mapping container to update.
// @return SUCCESS on success; FAILED when the attribute is missing. A null consumer op desc
//         makes GE_CHECK_NOTNULL return early.
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
  uint32_t parent_index = 0;
  if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    GELOGE(FAILED, "[Invoke][GetInt][%s] Failed to get attr [%s]",
           op_desc.GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    REPORT_CALL_ERROR("E19999", "[%s] Failed to get attr [%s]",
                      op_desc.GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    return FAILED;
  }

  for (auto &node_and_anchor : node.GetOutDataNodesAndAnchors()) {
    auto dst_op_desc = node_and_anchor.first->GetOpDesc();
    GE_CHECK_NOTNULL(dst_op_desc);
    auto in_idx = node_and_anchor.second->GetIdx();
    auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx);
    fused_subgraph.input_mapping[static_cast<int>(parent_index)].emplace_back(tensor_desc);
    GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx);
  }

  return SUCCESS;
}

// Records one output node of a fused graph: stores `op_desc` in fused_subgraph.output_mapping
// under the key read from its ATTR_NAME_PARENT_NODE_INDEX attribute.
//
// @param op_desc         Op desc of the output node; must carry ATTR_NAME_PARENT_NODE_INDEX.
// @param fused_subgraph  Mapping container to update.
// @return SUCCESS on success; FAILED when the attribute is missing.
Status ParseOutputMapping(const OpDescPtr &op_desc, FusedSubgraph &fused_subgraph) {
  uint32_t parent_index = 0;
  if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    GELOGE(FAILED, "[Invoke][GetInt][%s] Failed to get attr [%s]",
           op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    REPORT_CALL_ERROR("E19999", "[%s] Failed to get attr [%s].",
                      op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    return FAILED;
  }

  fused_subgraph.output_mapping.emplace(static_cast<int>(parent_index), op_desc);
  return SUCCESS;
}

// Builds node_item.fused_subgraph from the graph stored in the kAttrNameOriginalFusionGraph
// attribute. Does nothing (and succeeds) when the op has no such attribute.
//
// The graph is flagged as unknown-shape, topologically sorted, and its nodes are classified:
//   - DATA nodes            -> input mapping  (ParseInputMapping)
//   - kNodeTypeRetVal nodes -> output mapping (ParseOutputMapping)
//   - everything else       -> appended to FusedSubgraph::nodes
//
// @param node_item  Item to inspect and update.
// @return SUCCESS, or the failing status of any step (propagated by the GE_CHK_*/GE_CHECK_* macros).
Status ParseFusedSubgraph(NodeItem &node_item) {
  if (!node_item.op_desc->HasAttr(kAttrNameOriginalFusionGraph)) {
    return SUCCESS;
  }

  GELOGI("[%s] Start to parse fused subgraph.", node_item.node_name.c_str());
  auto fused_subgraph = std::unique_ptr<FusedSubgraph>(new(std::nothrow)FusedSubgraph());
  GE_CHECK_NOTNULL(fused_subgraph);

  ComputeGraphPtr fused_graph;
  // The return value is ignored on purpose: a failed lookup leaves fused_graph null,
  // which the check below catches.
  (void) AttrUtils::GetGraph(*node_item.op_desc, kAttrNameOriginalFusionGraph, fused_graph);
  GE_CHECK_NOTNULL(fused_graph);

  fused_graph->SetGraphUnknownFlag(true);
  fused_subgraph->graph = fused_graph;
  GE_CHK_GRAPH_STATUS_RET(fused_graph->TopologicalSorting());

  for (auto &node : fused_graph->GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    auto op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    const std::string node_type = NodeUtils::GetNodeType(node);
    if (node_type == DATA) {
      GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph));
    } else if (node_type == kNodeTypeRetVal) {
      GE_CHK_GRAPH_STATUS_RET(ParseOutputMapping(op_desc, *fused_subgraph));
    } else {
      fused_subgraph->nodes.emplace_back(node);
    }
  }

  node_item.fused_subgraph = std::move(fused_subgraph);
  GELOGI("[%s] Done parsing fused subgraph successfully.", node_item.NodeName().c_str());
  return SUCCESS;
}
}  // namespace

// Returns true when `op_type` is one of kControlOpTypes (If / Case / While and their stateless variants).
bool IsControlFlowV2Op(const std::string &op_type) {
  return kControlOpTypes.count(op_type) > 0;
}

// Takes ownership of `node` and caches its op desc (as a raw, non-owning pointer), name and type.
// The node and its op desc are dereferenced without checks; NodeItem::Create() validates both
// before constructing.
NodeItem::NodeItem(NodePtr node) : node(std::move(node)) {
  this->op_desc = this->node->GetOpDesc().get();
  this->node_name = this->node->GetName();
  this->node_type = this->node->GetType();
}

// Factory: validates `node`, allocates a NodeItem and runs Init() on it.
//
// @param node       Graph node to wrap; it and its op desc must be non-null.
// @param node_item  Receives the new item; only assigned when every step succeeded.
// @return SUCCESS, or the status of the failing check / Init().
Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_item) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  std::unique_ptr<NodeItem> instance(new(std::nothrow)NodeItem(node));
  GE_CHECK_NOTNULL(instance);
  GE_CHK_STATUS_RET(instance->Init(), "[Invoke][Init]Failed to init NodeItem [%s] .", node->GetName().c_str());
  node_item = std::move(instance);
  return SUCCESS;
}

// Detects optional inputs by comparing GetAllInputsSize() with GetInputsSize(). When they differ,
// sets has_optional_inputs and fills input_desc_indices_ with the original indices of every
// input whose desc is non-null, in ascending order.
void NodeItem::ResolveOptionalInputs() {
  if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) {
    has_optional_inputs = true;
    for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) {
      const auto &input_desc = op_desc->MutableInputDesc(i);
      if (input_desc == nullptr) {
        GELOGD("[%s] Input[%zu] is optional and invalid", NodeName().c_str(), i);
      } else {
        input_desc_indices_.emplace_back(static_cast<uint32_t>(i));
      }
    }
  }
}

// Initializes num_inputs / num_outputs from the op desc, resolves the stage level (`group`)
// and the optional-input index table.
//
// ATTR_STAGE_LEVEL is looked up in this order, the first hit wins:
//   1. the op desc itself, 2. the owner compute graph, 3. the owner graph's parent node.
// `group` is left untouched when none of them carries the attribute.
//
// @return SUCCESS; the GE_CHECK_LE guards return early when a count exceeds INT32_MAX.
Status NodeItem::InitInputsAndOutputs() {
  GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX);
  GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX);
  num_inputs = static_cast<int>(op_desc->GetInputsSize());
  num_outputs = static_cast<int>(op_desc->GetOutputsSize());
  if (AttrUtils::GetInt(op_desc, ::ge::ATTR_STAGE_LEVEL, group)) {
    GELOGD("[%s] Got stage level from op_desc = %d", op_desc->GetName().c_str(), group);
  } else {
    if (node->GetOwnerComputeGraph() != nullptr) {
      if (AttrUtils::GetInt(node->GetOwnerComputeGraph(), ::ge::ATTR_STAGE_LEVEL, group)) {
        GELOGD("[%s] Got stage level from parent graph = %d", op_desc->GetName().c_str(), group);
      } else {
        auto parent_node = node->GetOwnerComputeGraph()->GetParentNode();
        if ((parent_node != nullptr) && (AttrUtils::GetInt(parent_node->GetOpDesc(), ::ge::ATTR_STAGE_LEVEL, group))) {
          GELOGD("[%s] Got stage level from parent node = %d", op_desc->GetName().c_str(), group);
        } else {
          GELOGD("[%s] Node do not set stage level", op_desc->GetName().c_str());
        }
      }
    }
  }

  ResolveOptionalInputs();
  return SUCCESS;
}

// Determines is_dynamic: first from the ATTR_NAME_FORCE_UNKNOWN_SHAPE attribute, and, if that
// did not yield true, from NodeUtils::GetNodeUnknownShapeStatus().
//
// @return SUCCESS, or the failing status of GetNodeUnknownShapeStatus().
Status NodeItem::ResolveDynamicState() {
  // A missing attribute simply leaves is_dynamic at its current value.
  (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic);
  GELOGD("Node name is %s, dynamic state is %d.", this->node_name.c_str(), is_dynamic);
  if (!is_dynamic) {
    GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_dynamic),
                      "[Invoke][GetNodeUnknownShapeStatus][%s] Failed to get shape status.",
                      node->GetName().c_str());
  }
  return SUCCESS;
}

// Classifies every input as static/dynamic (is_input_shape_static_, num_static_input_shapes)
// and determines whether all outputs are static. When all outputs are static, their tensor
// sizes are computed up front via ShapeInferenceEngine::CalcOutputTensorSizes().
//
// @return SUCCESS, or the status of a failed null check / size calculation.
Status NodeItem::ResolveStaticInputsAndOutputs() {
  for (int i = 0; i < num_inputs; ++i) {
    // Data has unconnected input but set by framework
    if (node_type != DATA) {
      // `i` counts valid inputs only; map it back to the anchor index when optional inputs exist.
      int origin_index = i;
      if (has_optional_inputs) {
        origin_index = input_desc_indices_[i];
      }
      auto in_data_anchor = node->GetInDataAnchor(origin_index);
      GE_CHECK_NOTNULL(in_data_anchor);

      // If no node was connected to the current input anchor
      // increase num_static_input_shapes in case dead wait in ShapeInferenceState::AwaitShapesReady
      if (in_data_anchor->GetPeerOutAnchor() == nullptr ||
          in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() == nullptr) {
        num_static_input_shapes++;
        is_input_shape_static_.push_back(true);
        GELOGW("[%s] Peer node of input[%d] is empty", NodeName().c_str(), i);
        continue;
      }
    }
    const auto &input_desc = MutableInputDesc(i);
    GE_CHECK_NOTNULL(input_desc);
    if (input_desc->MutableShape().IsUnknownShape()) {
      is_input_shape_static_.push_back(false);
    } else {
      num_static_input_shapes++;
      is_input_shape_static_.push_back(true);
      GELOGD("[%s] The shape of input[%d] is static. shape = [%s]",
             NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str());
    }
  }

  // A single unknown output shape makes the whole node's output non-static.
  for (int i = 0; i < num_outputs; ++i) {
    const auto &output_desc = op_desc->MutableOutputDesc(i);
    GE_CHECK_NOTNULL(output_desc);
    if (output_desc->MutableShape().IsUnknownShape()) {
      is_output_shape_static = false;
      break;
    }
  }

  if (is_output_shape_static) {
    GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this));
  }
  return SUCCESS;
}

// Chooses shape_inference_type: DEPEND_COMPUTE for control-flow-v2 ops and for dynamic
// PARTITIONEDCALL nodes; otherwise the value of ATTR_NAME_UNKNOWN_SHAPE_TYPE (0 when absent).
void NodeItem::ResolveUnknownShapeType() {
  if (IsControlFlowV2Op() || (is_dynamic && node_type == PARTITIONEDCALL)) {
    shape_inference_type = DEPEND_COMPUTE;
  } else {
    int32_t unknown_shape_type_val = 0;
    (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val);
    shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
  }
}

// One-time initialization performed by Create(): classifies the node by type, initializes
// input/output counts, resolves dynamic state and shape inference type and, for dynamic nodes,
// the per-input static flags and the fused subgraph. Finally allocates copy_mu_.
//
// @return SUCCESS, or the status of the first failing step.
Status NodeItem::Init() {
  is_ctrl_flow_v2_op_ = ge::hybrid::IsControlFlowV2Op(node_type);
  is_ctrl_flow_op_ = kControlFlowOpTypes.count(node_type) > 0;
  is_merge_op_ = kMergeOpTypes.count(node_type) > 0;
  // A node with no incoming nodes at all is treated as a root node.
  is_root_node_ = node->GetInAllNodes().empty();

  GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs());
  GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState());
  ResolveUnknownShapeType();
  if (is_dynamic) {
    GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs());
    GE_CHK_STATUS_RET(ParseFusedSubgraph(*this),
                      "[Invoke][ParseFusedSubgraph][%s] Failed to parse fused subgraph", node_name.c_str());
  }

  copy_mu_ = MakeShared<std::mutex>();
  GE_CHECK_NOTNULL(copy_mu_);

  return SUCCESS;
}

// Returns true when NodeExecutorManager resolves this node to the HCCL executor type.
bool NodeItem::IsHcclOp() const {
  return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
}

// Builds a one-line, human-readable summary of this item: identity, dynamic/static flags,
// input/output ranges, shape-inference dependencies and, per output, its (consumer:index) pairs.
std::string NodeItem::DebugString() const {
  std::stringstream ss;
  ss << "Node: ";
  ss << "id = " << node_id;
  ss << ", name = [" << node->GetName();
  ss << "], type = " << node->GetType();
  ss << ", is_dynamic = " << (is_dynamic ? "True" : "False");
  ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False");
  ss << ", unknown_shape_op_type = " << shape_inference_type;
  ss << ", stage = " << group;
  ss << ", input_start = " << input_start;
  ss << ", num_inputs = " << num_inputs;
  ss << ", output_start = " << output_start;
  ss << ", num_outputs = " << num_outputs;
  ss << ", dependent_nodes = [";
  for (const auto &dep_node : dependents_for_shape_inference) {
    ss << dep_node->GetName() << ", ";
  }
  ss << "]";
  int index = 0;
  for (auto &items : outputs) {
    ss << ", output[" << index++ << "]: ";
    for (auto &item : items) {
      ss << "(" << item.second->NodeName() << ":" << item.first << "), ";
    }
  }

  return ss.str();
}

// Forces this item into dynamic mode: marks every input shape as non-static, resets the
// static-input counter and drops kernel_task if it does not support dynamic shapes.
void NodeItem::SetToDynamic() {
  num_static_input_shapes = 0;
  is_dynamic = true;
  for (size_t i = 0; i < is_input_shape_static_.size(); ++i) {
    is_input_shape_static_[i] = false;
  }
  if (kernel_task != nullptr && !kernel_task->IsSupportDynamicShape()) {
    GELOGD("[%s] Dynamic shape is not supported, clear node task.", node_name.c_str());
    kernel_task = nullptr;
  }
}

// Lock-free lookup of the desc for valid-input `index`; callers in this file hold mu_.
// Without optional inputs the index is forwarded to the op desc unchanged. With optional
// inputs it is range-checked against num_inputs and translated through input_desc_indices_.
//
// @return The mutable tensor desc, or nullptr for an out-of-range index (optional-input case).
GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
  if (!has_optional_inputs) {
    return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
  }

  if (index < 0 || index >= num_inputs) {
    GELOGE(PARAM_INVALID, "[Check][Param:index][%s] Invalid input index, num inputs = %d, index = %d",
           node_name.c_str(), num_inputs, index);
    REPORT_INNER_ERROR("E19999", "Invalid input index, node:%s num inputs = %d, index = %d",
                       node_name.c_str(), num_inputs, index);
    return nullptr;
  }

  return op_desc->MutableInputDesc(input_desc_indices_[index]);
}

// Returns the mutable desc of input `index`, looked up while holding mu_.
GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
  std::lock_guard<std::mutex> lk(mu_);
  return DoGetInputDesc(index);
}

// Copies the desc of input `index` into `tensor_desc` while holding mu_.
//
// @return SUCCESS; GE_CHECK_NOTNULL returns early when no desc exists for `index`.
Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
  std::lock_guard<std::mutex> lk(mu_);
  auto input_desc = DoGetInputDesc(index);
  GE_CHECK_NOTNULL(input_desc);
  tensor_desc = *input_desc;
  return SUCCESS;
}

// Copies the desc of output `index` into `tensor_desc` while holding mu_.
//
// @return SUCCESS; GE_CHECK_NOTNULL returns early when no desc exists for `index`.
Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
  std::lock_guard<std::mutex> lk(mu_);
  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
  GE_CHECK_NOTNULL(output_desc);
  tensor_desc = *output_desc;
  return SUCCESS;
}

// Returns the mutable desc of output `index`, looked up while holding mu_.
GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
  std::lock_guard<std::mutex> lk(mu_);
  return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
}

// Overwrites the desc of input `index` with a copy of `tensor_desc` while holding mu_.
//
// @return SUCCESS; GE_CHECK_NOTNULL returns early when no desc exists for `index`.
Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
  std::lock_guard<std::mutex> lk(mu_);
  auto input_desc = DoGetInputDesc(index);
  GE_CHECK_NOTNULL(input_desc);
  *input_desc = tensor_desc;
  return SUCCESS;
}

// Translates an original input index into its position among the valid inputs
// (the inverse of the input_desc_indices_ table). Identity when there are no optional inputs.
//
// @param index            Original input index on the op desc.
// @param canonical_index  Receives the position among valid inputs.
// @return SUCCESS, or INTERNAL_ERROR when `index` is not a valid input.
Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
  if (!has_optional_inputs) {
    canonical_index = index;
    return SUCCESS;
  }

  auto iter = std::find(input_desc_indices_.begin(), input_desc_indices_.end(), index);
  if (iter == input_desc_indices_.end()) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Param:index]input index:%u not in input_desc_indices_, check Invalid, node:%s",
           index, node_name.c_str());
    REPORT_INNER_ERROR("E19999", "input index:%u not in input_desc_indices_, check Invalid, node:%s",
                       index, node_name.c_str());
    return INTERNAL_ERROR;
  }

  canonical_index = static_cast<int>(iter - input_desc_indices_.begin());
  GELOGD("[%s] Canonicalize input index from [%u] to [%d]", node_name.c_str(), index, canonical_index);
  return SUCCESS;
}

// Tells whether the shape of input `index` is static. Always true for non-dynamic nodes.
// For dynamic nodes an out-of-range index (including a negative one, which wraps to a huge
// unsigned value in the comparison) is reported as an error and yields false.
bool NodeItem::IsInputShapeStatic(int index) const {
  if (!is_dynamic) {
    return true;
  }

  if (static_cast<size_t>(index) >= is_input_shape_static_.size()) {
    GELOGE(PARAM_INVALID, "[Check][Param:index]Input index(%d) out of range: [0, %zu)",
           index, is_input_shape_static_.size());
    REPORT_INNER_ERROR("E19999", "Input index(%d) out of range: [0, %zu).", index, is_input_shape_static_.size());
    return false;
  }

  return is_input_shape_static_[index];
}

// Registers a data edge from this item to `node_item` at the receiver's input `anchor_index`.
// Updates data_send_ here and data_recv_ on the receiver; additionally records the edge in the
// receiver's root_data_ when this item is a root node, and in enter_data_ when this item is an
// Enter feed node and the receiver is not a STREAMMERGE.
//
// @param node_item     Receiving item; dereferenced without a null check.
// @param anchor_index  Input anchor index on the receiver.
void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
  data_send_.emplace(node_item);
  node_item->data_recv_[this] = anchor_index;
  if (is_root_node_) {
    auto &data_anchors = node_item->root_data_[this];
    data_anchors.emplace(anchor_index);
  }
  // If Enter feed Not Merge, take as root Node.
  if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) {
    auto &data_anchors = node_item->enter_data_[this];
    data_anchors.emplace(anchor_index);
  }
  GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}

// Registers a control edge from this item to `node_item`.
// The receiver joins switch_groups_[switch_index] when that group exists, otherwise ctrl_send_.
// The receiver's ctrl_recv_ is always updated; root_ctrl_ / enter_ctrl_ are updated when this
// item is a root node / an Enter feed node (the latter only for receivers that are neither
// STREAMMERGE nor STREAMACTIVE).
//
// @param node_item     Receiving item; dereferenced without a null check.
// @param switch_index  Index of the switch group, or any value >= switch_groups_.size() for none.
void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
  if (switch_index < switch_groups_.size()) {
    auto &switch_group = switch_groups_[switch_index];
    switch_group.emplace(node_item);
  } else {
    ctrl_send_.insert(node_item);
  }

  node_item->ctrl_recv_.emplace(this);
  if (is_root_node_) {
    node_item->root_ctrl_.emplace(this);
  }
  // If Enter feed control signal, take as root Node.
  if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
    node_item->enter_ctrl_.emplace(this);
  }
  GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}

// Registers `node_item` as a controller of merge branch `merge_index` of this item: adds it to
// switch_groups_[merge_index] and adds this item to the controller's ctrl_send_.
// An out-of-range `merge_index` is logged as an error and otherwise ignored.
void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) {
  if (merge_index >= switch_groups_.size()) {
    GELOGE(FAILED, "[%s] group size: %zu, merge index: %u", NodeName().c_str(), switch_groups_.size(), merge_index);
    return;
  }

  // this is StreamMerge node, node_item is StreamActive node.
  auto &switch_group = switch_groups_[merge_index];
  switch_group.emplace(node_item);

  node_item->ctrl_send_.emplace(this);
  GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str());
}

// Returns the number of items in switch_groups_[merge_index] for a STREAMMERGE node;
// 0 for any other node type or an out-of-range index.
size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const {
  return ((node_type == STREAMMERGE) && (merge_index < switch_groups_.size())) ?
         switch_groups_[merge_index].size() : 0;
}

// Scoped guard that locks `mutex` for its lifetime, or does nothing when `mutex` is null.
//
// @param mutex  Mutex to lock; may be nullptr. Not owned.
// @param name   Label used only in the lock/unlock debug logs.
OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) {
  if (mu_ != nullptr) {
    GELOGD("lock for %s", name_.c_str());
    mu_->lock();
  }
}

// Unlocks the mutex taken by the constructor, if any.
OptionalMutexGuard::~OptionalMutexGuard() {
  if (mu_ != nullptr) {
    GELOGD("unlock for %s", name_.c_str());
    mu_->unlock();
    mu_ = nullptr;
  }
}
}  // namespace hybrid
}  // namespace ge