/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/infershape_pass.h"

#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "analyzer/analyzer.h"
#include "framework/common/util.h"
#include "graph/shape_refiner.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "external/graph/operator_factory.h"

namespace ge {
namespace {
constexpr int kSwitchExitAnchorIndex = 0;
constexpr int kSwitchPredAnchorIndex = 1;

void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "]";
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += ",{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
}

void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) {
  dst->SetOriginShape(src->GetOriginShape());
  dst->SetShape(src->GetShape());
  dst->SetDataType(src->GetDataType());
  dst->SetOriginDataType(src->GetOriginDataType());
  vector<std::pair<int64_t, int64_t>> src_shape_range;
  src->GetShapeRange(src_shape_range);
  dst->SetShapeRange(src_shape_range);
  dst->SetOriginShapeRange(src_shape_range);
  ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->GetShape().GetDims().size()));
}
}  // namespace

std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const {
  std::stringstream ss;
  ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),";
  ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),";
  ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),";
  ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),";
  ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),";
  ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),";
  string range_str;
  SerialShapeRange(tensor_desc, range_str);
  ss << "(shape_range:" << range_str << ")";
  return ss.str();
}

Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) {
  if (node->GetType() != SWITCH) {
    return SUCCESS;
  }
  auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex);
  GE_CHECK_NOTNULL(pred_node);
  if (pred_node->GetType() != LOOPCOND) {
    return SUCCESS;
  }
  for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) {
    GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(),
           anchor_2_node.second->GetType().c_str());
    auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()];
    if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) {
      suspend_nodes.nodes.push(anchor_2_node.second);
      AddNodeSuspend(anchor_2_node.second);
    }
  }
  return SUCCESS;
}

Status InferShapePass::Infer(NodePtr &node) {
  auto ret = InferShapeAndType(node);
  if (ret != GRAPH_SUCCESS) {
    auto graph = node->GetOwnerComputeGraph();
    GE_CHECK_NOTNULL(graph);
    auto root_graph = ge::GraphUtils::FindRootGraph(graph);
    GE_CHECK_NOTNULL(root_graph);
    analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(),
                                    analyzer::INFER_SHAPE, node, "InferShapeFailed!"};
    (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
    (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), root_graph->GetGraphID());
    REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed", node->GetName().c_str(),
                      node->GetType().c_str());
    GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed", node->GetName().c_str(),
           node->GetType().c_str());
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  return SUCCESS;
}

graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
  auto ret = SuspendV1LoopExitNodes(node);
  if (ret != SUCCESS) {
    GELOGE(ret, "Suspend V1 loop exit nodes failed.");
    return ret;
  }
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  auto opdesc = node->GetOpDesc();
  if (node->Verify() != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  Operator op = OpDescUtils::CreateOperatorFromNode(node);

  if (!is_unknown_graph) {
    auto inference_context = ShapeRefiner::CreateInferenceContext(node);
    GE_CHECK_NOTNULL(inference_context);
    std::vector<AscendString> marks;
    inference_context->GetMarks(marks);
    GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size());
    op.SetInferenceContext(inference_context);
  }

  graphStatus status = CallInferShapeFunc(node, op);
  if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
    // node like netoutput return param_invalid, but valid ?
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  UpdateCurNodeOutputDesc(node);
  if (!is_unknown_graph) {
    auto ctx_after_infer = op.GetInferenceContext();
    if (ctx_after_infer != nullptr) {
      std::vector<AscendString> marks;
      ctx_after_infer->GetMarks(marks);
      GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size());
      if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) {
        GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), marks.size());
        ShapeRefiner::PushToContextMap(node, ctx_after_infer);
      }
    }
  }
  return (status == GRAPH_NODE_NEED_REPASS) ?
                                              GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS;
}

void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    GE_IF_BOOL_EXEC(output_tensor == nullptr, continue);
    GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(),
                    output_tensor->SetOriginShape(output_tensor->GetShape()));
    ge::TensorUtils::SetRealDimCnt(*output_tensor,
                                   static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size()));
    output_tensor->SetOriginDataType(output_tensor->GetDataType());
    // set output origin shape range
    std::vector<std::pair<int64_t, int64_t>> range;
    (void)output_tensor->GetShapeRange(range);
    output_tensor->SetOriginShapeRange(range);
    GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
           node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
           TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
           TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  }
}

bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
  // check shape range
  vector<std::pair<int64_t, int64_t>> src_shape_range;
  vector<std::pair<int64_t, int64_t>> dst_shape_range;
  src->GetShapeRange(src_shape_range);
  dst->GetShapeRange(dst_shape_range);
  if (src_shape_range.size() != dst_shape_range.size()) {
    GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(),
           dst_shape_range.size());
    return false;
  }
  for (size_t i = 0; i < src_shape_range.size(); ++i) {
    if (src_shape_range[i].first != dst_shape_range[i].first ||
        src_shape_range[i].second != dst_shape_range[i].second) {
      GELOGI("Current dim %zu. Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.",
             i, src_shape_range[i].first, src_shape_range[i].second,
             dst_shape_range[i].first, dst_shape_range[i].second);
      return false;
    }
  }

  // check shape
  auto src_shape = src->GetShape();
  auto dst_shape = dst->GetShape();
  if (src_shape.GetDims() != dst_shape.GetDims() ||
      src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() ||
      src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) {
    GELOGD(
        "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; "
        "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.",
        src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(),
        TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(),
        TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(),
        dst_shape.ToString().c_str(), dst->GetOriginShape().ToString().c_str(),
        TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(),
        TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str());
    return false;
  }
  return true;
}

graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  changed = !SameTensorDesc(src, dst);

  // refresh src itself
  src->SetOriginShape(src->GetShape());
  src->SetOriginDataType(src->GetDataType());
  TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size()));
  vector<std::pair<int64_t, int64_t>> src_shape_range;
  src->GetShapeRange(src_shape_range);
  src->SetOriginShapeRange(src_shape_range);

  if (!changed) {
    GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
    return SUCCESS;
  }
  UpdateShapeAndDType(src, dst);
  GELOGD(
      "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."
"To dst Node: shape: [%s], datatype: %s, original datatype is %s.", src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); return SUCCESS; } graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { auto op_desc = node->GetOpDesc(); const auto &op_type = op_desc->GetType(); auto ret = op_desc->CallInferFunc(op); if (ret == GRAPH_PARAM_INVALID) { // Op ir no infer func, try to get infer func from operator factory auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str()); if (node_op.IsEmpty()) { GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); return ret; } GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); node_op.BreakConnect(); if (temp_op_desc == nullptr) { REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); return GRAPH_FAILED; } if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { GELOGW("InferShapeAndType UpdateInputName failed"); for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { break; } return GRAPH_SUCCESS; } } if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { GELOGW("InferShapeAndType UpdateOutputName failed"); } op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); ret = op_desc->CallInferFunc(op); GELOGI("op CallInferFunc second. ret: %u", ret); } return ret; } graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) { GELOGD("Enter update parent node shape for class branch op process"); // check sub_graph shape.If not same ,do unknown shape process auto ref_out_tensor = src.at(0); ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape(); for (auto &tensor : src) { if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s", ref_out_tensor_shape.ToString().c_str()); GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output"); return GRAPH_FAILED; } auto shape = tensor->MutableShape(); if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { GELOGD("Shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); ref_out_tensor_shape = GeShape(UNKNOWN_RANK); break; } for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { continue; } GELOGD("j: %zu ,shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); } } UpdateShapeAndDType(ref_out_tensor, dst); return GRAPH_SUCCESS; } graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, GeTensorDescPtr &dst) { // check sub_graph shape. Get max for update. 
  if (src.empty()) {
    GELOGI("Src subgraph shape is empty.");
    return SUCCESS;
  }

  int64_t max_size = 0;
  size_t max_shape_index = 0;
  auto &ref_out_tensor = src.at(0);
  for (size_t j = 0; j < src.size(); ++j) {
    auto &tensor = src.at(j);
    if (ref_out_tensor->GetDataType() != tensor->GetDataType()) {
      REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output");
      GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output");
      return GRAPH_FAILED;
    }

    auto shape = tensor->MutableShape();
    int64_t size = 1;
    for (auto dim : shape.GetDims()) {
      if (dim != 0 && INT64_MAX / dim < size) {
        REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str());
        GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
        return PARAM_INVALID;
      }
      size *= dim;
    }

    if (size > max_size) {
      max_size = size;
      max_shape_index = j;
    }
  }
  UpdateShapeAndDType(src.at(max_shape_index), dst);
  return GRAPH_SUCCESS;
}

Status InferShapePass::OnSuspendNodesLeaked() {
  auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName());
  if (iter == graphs_2_suspend_nodes_.end()) {
    GELOGI("Current graph %s no suspend node.", GetCurrentGraphName().c_str());
    return SUCCESS;
  }

  if (!iter->second.nodes.empty()) {
    AddNodeResume(iter->second.PopSuspendedNode());
  }
  return SUCCESS;
}
}  // namespace ge