/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/infershape_pass.h"

#include <functional>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "analyzer/analyzer.h"
#include "framework/common/util.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/operator_factory_impl.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace {
// Attribute name copied from the source tensor desc to the peer input desc in UpdateInputDescAttr.
const char *const kPreOpInputShapeRange = "_pre_op_in_range";
// Per-thread cache of the inference context produced by each already-inferred node.
// Filled by InferShapePass::Infer, read by CreateInferenceContextPtr, emptied by ClearContextMap.
thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}  // namespace

/// Drops every cached inference context of the calling thread.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void InferShapePass::ClearContextMap() { context_map.clear(); }

/// Builds the inference context handed to the infer function of `node`.
///
/// For every data input whose producer has a cached context in `context_map`, the producer's marks are
/// appended and the producer's output handle shapes/types for the connected output are copied to the
/// matching input slot.
///
/// @param context_map  cached contexts keyed by producer node
/// @param node         node the context is created for
/// @return the new context, or nullptr when `node` is null, allocation fails or a cached context is null
InferenceContextPtr CreateInferenceContextPtr(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
                                              const NodePtr &node) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return nullptr;
  }
  InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
  if (inference_context == nullptr) {
    REPORT_CALL_ERROR("E19999", "Failed to alloc InferenceContext, node:%s", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Alloc][InferenceContext] failed.");
    return nullptr;
  }

  auto all_in_data_anchors = node->GetAllInDataAnchors();
  std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
  std::vector<std::string> marks;

  bool has_input_shapes_and_types = false;
  for (const auto &in_anchor : all_in_data_anchors) {
    if (in_anchor == nullptr) {
      continue;
    }
    const auto &out_anchor = in_anchor->GetPeerOutAnchor();
    if (out_anchor == nullptr) {
      continue;
    }

    auto input_node = out_anchor->GetOwnerNode();
    if (input_node == nullptr) {
      continue;
    }

    auto iter = context_map.find(input_node);
    if (iter == context_map.end()) {
      continue;
    }
    const auto &src_context = iter->second;
    GE_IF_BOOL_EXEC(src_context == nullptr, REPORT_INNER_ERROR("E19999", "src_context is null.");
                    GELOGE(GRAPH_FAILED, "[Check][Param] src_context is null."); return nullptr);

    // size() yields size_t, so the matching conversion specifier is %zu.
    GELOGD("node:%s get %zu marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(),
           input_node->GetName().c_str());
    for (const auto &mark : src_context->GetMarks()) {
      marks.push_back(mark);
    }

    const auto output_idx = out_anchor->GetIdx();
    const auto input_idx = in_anchor->GetIdx();
    // The slot vector is sized by the anchor count; never write outside of it.
    if (input_idx < 0 || static_cast<size_t>(input_idx) >= input_shapes_and_types.size()) {
      GELOGW("[%s] Input index %d is out of range, size = %zu", node->GetName().c_str(), input_idx,
             input_shapes_and_types.size());
      continue;
    }

    const auto &output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
    if (output_idx >= 0 && output_idx < static_cast<int>(output_shape_and_type.size())) {
      GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
             node->GetName().c_str(), input_idx);
      input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
      has_input_shapes_and_types = true;
    } else {
      GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
             output_shape_and_type.size());
    }
  }

  if (has_input_shapes_and_types) {
    inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
  }
  inference_context->SetMarks(marks);

  return inference_context;
}

/// Appends a textual form of the shape range and origin shape range of `desc` to `desc_str`.
///
/// Layout: "[" + "{first,second}," per shape-range pair + "]", followed by ",{first,second}," per
/// origin-shape-range pair.
///
/// @param desc      tensor desc to read; it is dereferenced without a null check
/// @param desc_str  string the text is appended to
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "]";
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += ",{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
}

/// Renders shape, format, dtype, their origin variants and the shape range of every non-null input
/// desc of `node` as one string wrapped in "{...}".
///
/// @param node  node whose inputs are described
/// @return the description, or "{}" when the node or its op desc is null
std::string GetInTensorInfoWithString(const ge::NodePtr &node) {
  if (node == nullptr) {
    return "{}";
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    return "{}";
  }
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      // Null descs still consume an index so that "input_<n>" matches the input position.
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << "    ";
    }
    ss << "input_" << in_idx << " " << "tensor: [";
    ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
    std::string range_str;
    SerialShapeRange(input_desc, range_str);
    ss << "(shape_range:" << range_str << ")]";
    in_idx++;
  }
  // Close the brace opened above so the emitted text is balanced.
  ss << "}";
  return ss.str();
}

/// Records an infer-shape failure of `node`: feeds the analyzer (when a root graph is available) and
/// reports/logs the error together with the input tensor description.
///
/// @param node  node whose infer function failed; it is dereferenced without a null check
void InferShapePass::AnalyzeFailedInfo(const NodePtr &node) {
  auto graph = node->GetOwnerComputeGraph();
  if (graph == nullptr) {
    GELOGW("Owner compute graph of node %s is nullptr", node->GetName().c_str());
  }
  auto root_graph = ge::GraphUtils::FindRootGraph(graph);
  if (root_graph == nullptr) {
    // Session and graph ids are read from the root graph, so the analyzer is skipped without one.
    GELOGW("Root compute graph of node %s is nullptr", node->GetName().c_str());
  } else {
    analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), analyzer::INFER_SHAPE, node,
                                    "InferShapeFailed!"};
    (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
    (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), root_graph->GetGraphID());
  }

  // Built once and shared by the error report and the log line.
  const std::string in_tensor_info = GetInTensorInfoWithString(node);
  REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s",
                    node->GetName().c_str(), node->GetType().c_str(), in_tensor_info.c_str());
  GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s",
         node->GetName().c_str(), node->GetType().c_str(), in_tensor_info.c_str());
}

/// @return true when the shape dims of `src` and `dst` differ; both pointers are dereferenced unchecked
bool InferShapePass::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
  return dst->GetShape().GetDims() != src->GetShape().GetDims();
}

/// Copies origin shape, shape, data type, origin data type, shape range (unless the source shape is
/// UNKNOWN_RANK), the kPreOpInputShapeRange attribute (when present) and the real dim count from `src`
/// to `dst`.
///
/// NOTE(review): `changed` is never written by this function; confirm that callers do not rely on it.
///
/// @return always GRAPH_SUCCESS
graphStatus InferShapePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  dst->SetOriginShape(src->GetOriginShape());
  dst->SetShape(src->MutableShape());
  dst->SetDataType(src->GetDataType());
  dst->SetOriginDataType(src->GetOriginDataType());
  if (src->MutableShape().GetDims() != UNKNOWN_RANK) {
    std::vector<std::pair<int64_t, int64_t>> shape_range;
    (void)src->GetShapeRange(shape_range);
    dst->SetShapeRange(shape_range);
  }
  std::vector<int64_t> pre_op_in_range;
  if (ge::AttrUtils::GetListInt(*src, kPreOpInputShapeRange, pre_op_in_range)) {
    (void)ge::AttrUtils::SetListInt(*dst, kPreOpInputShapeRange, pre_op_in_range);
  }
  ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->MutableShape().GetDims().size()));
  return GRAPH_SUCCESS;
}

/// Verifies `node`, runs its infer function and, for graphs without the unknown flag, caches the
/// resulting inference context when it carries output handle shapes/types or marks.
///
/// @return GRAPH_SUCCESS when the infer function returned GRAPH_SUCCESS or GRAPH_PARAM_INVALID,
///         GRAPH_FAILED when the owner graph is null, verification fails or the infer function fails
graphStatus InferShapePass::Infer(NodePtr &node) {
  const auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    REPORT_INNER_ERROR("E19999", "Owner compute graph of node:%s is null.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Get][OwnerComputeGraph] of node:%s failed, graph is null.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  const bool is_unknown_graph = owner_graph->GetGraphUnknownFlag();

  if (node->Verify() != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  PrintInOutTensorShape(node, "before_infershape");

  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  if (!is_unknown_graph) {
    auto inference_context = CreateInferenceContextPtr(context_map, node);
    GE_CHECK_NOTNULL(inference_context);
    GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
    op.SetInferenceContext(inference_context);
  }

  graphStatus status = CallInferShapeFunc(node, op);
  // GRAPH_PARAM_INVALID is tolerated here; every other non-success status is a failure.
  if (status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
    return GRAPH_FAILED;
  }

  if (!is_unknown_graph) {
    auto ctx_after_infer = op.GetInferenceContext();
    if (ctx_after_infer != nullptr) {
      GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
      if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
        GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
               ctx_after_infer->GetMarks().size());
        (void)context_map.emplace(node, ctx_after_infer);
      }
    }
  }
  return GRAPH_SUCCESS;
}

/// Calls the infer function registered on the op desc of `node`. When that call returns
/// GRAPH_PARAM_INVALID, an operator of the same type is created through OperatorFactory, its
/// input/output names and infer function are copied onto the op desc, and the call is repeated.
///
/// @return status of the last infer call, GRAPH_FAILED when the factory operator has no op desc, or
///         GRAPH_SUCCESS on the UpdateInputName early-exit path described below
graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();
  auto ret = op_desc->CallInferFunc(op);
  if (ret != GRAPH_PARAM_INVALID) {
    return ret;
  }

  // Op ir no infer func, try to get infer func from operator factory
  auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
  if (node_op.IsEmpty()) {
    GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
    return ret;
  }

  GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
  auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  node_op.BreakConnect();
  if (temp_op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr.");
    GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null");
    return GRAPH_FAILED;
  }
  if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
    GELOGW("InferShapeAndType UpdateInputName failed");
    // NOTE(review): this loop examines at most the first output desc: inference continues when that
    // desc is non-null with empty dims, otherwise GRAPH_SUCCESS is returned right away. With no output
    // descs at all inference continues as well. Confirm that checking only the first output is intended.
    for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
      if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
        break;
      }
      return GRAPH_SUCCESS;
    }
  }
  if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
    GELOGW("InferShapeAndType UpdateOutputName failed");
  }
  op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
  ret = op_desc->CallInferFunc(op);
  GELOGI("op CallInferFunc second. ret: %u", ret);
  return ret;
}

/// After inference: for graphs with the unknown flag only the tensor shapes are printed; otherwise the
/// origin attributes of `node` are refreshed and its output descs are pushed to the peer inputs.
///
/// @return GRAPH_SUCCESS, or GRAPH_FAILED when the owner graph is null or the peer update fails
graphStatus InferShapePass::UpdatePeerInputs(NodePtr &node) {
  const auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    REPORT_INNER_ERROR("E19999", "Owner compute graph of node:%s is null.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Get][OwnerComputeGraph] of node:%s failed, graph is null.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  if (owner_graph->GetGraphUnknownFlag()) {
    PrintInOutTensorShape(node, "after_infershape when running");
    return GRAPH_SUCCESS;
  }
  UpdateInputOutputOriginAttr(node);
  if (NodeUtils::UpdatePeerNodeInputDesc(node) != SUCCESS) {
    return GRAPH_FAILED;
  }
  PrintInOutTensorShape(node, "after_infershape");
  return GRAPH_SUCCESS;
}

/// Refreshes the origin attributes of `node`: every output gets its origin shape (only when the
/// current shape has no dims), real dim count, origin data type and origin shape range updated; every
/// input gets its origin shape range set from its shape range.
void InferShapePass::UpdateInputOutputOriginAttr(NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    if (output_tensor == nullptr) {
      continue;
    }
    if (output_tensor->MutableShape().GetDims().empty()) {
      output_tensor->SetOriginShape(output_tensor->GetShape());
    }
    ge::TensorUtils::SetRealDimCnt(*output_tensor,
                                   static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size()));
    output_tensor->SetOriginDataType(output_tensor->GetDataType());
    // set output origin shape range
    std::vector<std::pair<int64_t, int64_t>> range;
    (void)output_tensor->GetShapeRange(range);
    output_tensor->SetOriginShapeRange(range);
    GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
           node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
           TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
           TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
  }
  for (const auto &in_anchor : node->GetAllInDataAnchors()) {
    auto input_tensor = op_desc->MutableInputDesc(in_anchor->GetIdx());
    if (input_tensor == nullptr) {
      continue;
    }
    // set input origin shape range
    std::vector<std::pair<int64_t, int64_t>> range;
    (void)input_tensor->GetShapeRange(range);
    input_tensor->SetOriginShapeRange(range);
  }
}

/// Runs RePassLoopNode for `node`, then handles ATTR_NAME_NEED_INFER_AGAIN on the node itself: when
/// the attribute exists and the kOptimizeAfterSubGraph option is set, the node is either re-passed
/// immediately (attribute true) or the attribute is removed (attribute false).
///
/// @return status of RePassLoopNode on failure, otherwise SUCCESS
Status InferShapePass::DoRepassForLoopNode(NodePtr &node) {
  GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
  bool need_repass = false;
  auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
  if (has_attr) {
    if (!OptionExists(kOptimizeAfterSubGraph)) {
      return SUCCESS;
    }
    if (need_repass) {
      AddImmediateRePassNode(node);
      GELOGD("Node %s need repass immediately.", node->GetName().c_str());
    } else {
      // clear attr on while
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
    }
  }
  return SUCCESS;
}

/// Schedules follow-up work on the data successors of `node` depending on the original type of `node`:
///  - next-iteration types: successors of a merge type are re-passed immediately;
///  - merge types carrying ATTR_NAME_NEED_INFER_AGAIN: the attribute is removed and successors of a
///    switch type are re-passed immediately;
///  - switch types: successors of an exit type are resumed when the attribute was present (it is
///    removed), otherwise suspended.
///
/// @return SUCCESS, or the failing status of GetOriginalType / the null check on a successor
Status InferShapePass::RePassLoopNode(const NodePtr &node) {
  // Re-passes every data successor whose original type is in `re_pass_types`.
  const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
    for (auto &n : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(n);
      std::string node_type;
      GE_CHK_STATUS_RET(GetOriginalType(n, node_type),
                        "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
      if (re_pass_types.count(node_type) > 0) {
        AddImmediateRePassNode(n);
        (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
        GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
      }
    }
    return SUCCESS;
  };

  // Applies `proc_func` to every data successor whose original type is in `proc_types`.
  const auto ExProcNode = [&](const std::set<std::string> &proc_types,
                              const std::function<void(InferShapePass *, NodePtr)> &proc_func,
                              const std::string &info) {
    for (auto &n : node->GetOutDataNodes()) {
      GE_CHECK_NOTNULL(n);
      std::string node_type;
      GE_CHK_STATUS_RET(GetOriginalType(n, node_type),
                        "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
      if (proc_types.count(node_type) > 0) {
        proc_func(this, n);
        GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
      }
    }
    return SUCCESS;
  };

  std::string node_type;
  GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
                    "[Get][OriginalType] of node:%s failed.", node->GetName().c_str());
  if (kNextIterationOpTypes.count(node_type) > 0) {
    return RePassNode(kMergeOpTypes);  // Re-Pass Merge
  }

  if (kMergeOpTypes.count(node_type) > 0) {
    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
      return RePassNode(kSwitchOpTypes);  // Re-Pass Switch
    }
    return SUCCESS;
  }

  if (kSwitchOpTypes.count(node_type) > 0) {
    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
      return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume");  // Resume Exit
    } else {
      return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend");  // Suspend Exit
    }
  }

  return SUCCESS;
}

/// Infers `node` with a temporary pass object, passing `true` as the before-subgraph flag.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  InferShapePass pass;
  std::set<NodePtr> unused_changed_nodes;
  return pass.InferAndUpdate(node, true, unused_changed_nodes);
}

/// Infers `node` with a temporary pass object, forwarding `before_subgraph` to InferAndUpdate.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus InferShapePass::InferShapeAndType(NodePtr &node,
                                                                                             bool before_subgraph) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  InferShapePass pass;
  std::set<NodePtr> unused_changed_nodes;
  return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes);
}

/// Runs the infer function of `node` and afterwards restores the data type every output desc had
/// before the call, so the infer function cannot change output dtypes.
///
/// @return GRAPH_SUCCESS when the infer function returned GRAPH_SUCCESS or GRAPH_PARAM_INVALID,
///         GRAPH_FAILED when the op desc is null or the infer function fails
graphStatus InferShapeForRunning::Infer(NodePtr &node) {
  auto opdesc = node->GetOpDesc();
  if (opdesc == nullptr) {
    REPORT_INNER_ERROR("E19999", "param node has no opdesc, check invalid.");
    GELOGE(GRAPH_FAILED, "[Get][OpDesc] opdesc is null.");
    return GRAPH_FAILED;
  }
  // Snapshot of the output dtypes; a null desc keeps its slot so indices stay aligned with the outputs.
  std::vector<ge::DataType> temp_dtype;
  for (const auto &tensor_desc : opdesc->GetAllOutputsDescPtr()) {
    temp_dtype.emplace_back(tensor_desc == nullptr ? DT_UNDEFINED : tensor_desc->GetDataType());
  }
  PrintInOutTensorShape(node, "before_infershape when running");

  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  graphStatus status = CallInferShapeFuncForRunning(node, op);
  if (status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
    return GRAPH_FAILED;
  }

  // ensure the dtype is not changed after infershape in running
  auto after_opdesc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(after_opdesc == nullptr, REPORT_INNER_ERROR("E19999", "param node has no opdesc, check invalid.");
                  GELOGE(GRAPH_FAILED, "[Get][OpDesc] after_opdesc is null."); return GRAPH_FAILED);
  auto all_output_tensor = after_opdesc->GetAllOutputsDescPtr();
  // Bounded by both sizes so a changed output count can never index past the snapshot.
  for (size_t i = 0; i < all_output_tensor.size() && i < temp_dtype.size(); ++i) {
    const auto &output_tensor = all_output_tensor.at(i);
    if (output_tensor == nullptr) {
      continue;
    }
    if (output_tensor->GetDataType() != temp_dtype[i]) {
      GELOGD("Op %s output %zu need reset dtype,original dtype is %s, new dtype is %s", node->GetName().c_str(), i,
             TypeUtils::DataTypeToSerialString(output_tensor->GetDataType()).c_str(),
             TypeUtils::DataTypeToSerialString(temp_dtype[i]).c_str());
      output_tensor->SetDataType(temp_dtype[i]);
    }
  }
  PrintInOutTensorShape(node, "after_infershape when running");
  return GRAPH_SUCCESS;
}

/// Calls the infer function registered on the op desc of `node`. When that call returns
/// GRAPH_PARAM_INVALID, the infer function registered in OperatorFactoryImpl for the node type is
/// attached to the op desc and the call is repeated.
///
/// @return status of the last infer call, or GRAPH_FAILED when no infer function is registered
graphStatus InferShapeForRunning::CallInferShapeFuncForRunning(NodePtr &node, Operator &op) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();

  // Create InferenceContext to avoid null pointer access.
  const static std::set<std::string> force_context_op_types{"Enter", "Switch", "RefSwitch"};
  if (force_context_op_types.count(op_type) > 0) {
    GELOGD("Set InferenceContext for node [%s]", op_desc->GetName().c_str());
    op.SetInferenceContext(std::shared_ptr<InferenceContext>(InferenceContext::Create()));
  }

  // Get infer func and execute
  auto ret = op_desc->CallInferFunc(op);
  if (ret != GRAPH_PARAM_INVALID) {
    return ret;
  }

  // Resolved once and used for both the log line and the factory lookup.
  const auto origin_type = NodeUtils::GetNodeType(*node);
  GELOGD("NodeUtils::GetNodeType return value is: [%s]", origin_type.c_str());
  auto infer_func = ge::OperatorFactoryImpl::GetInferShapeFunc(origin_type);
  if (infer_func == nullptr) {
    REPORT_INNER_ERROR("E19999", "Failed to Get InferFunc. type is %s", origin_type.c_str());
    GELOGE(GRAPH_FAILED, "[Get][InferFunc] failed. type is %s", origin_type.c_str());
    return GRAPH_FAILED;
  }
  op_desc->AddInferFunc(infer_func);
  ret = op_desc->CallInferFunc(op);
  GELOGI("op CallInferFunc second. ret: %u", ret);
  return ret;
}

/// @return true when the shape dims of `src` and `dst` differ; both pointers are dereferenced unchecked
bool InferShapeForRunning::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
  return dst->GetShape().GetDims() != src->GetShape().GetDims();
}

/// Infers `node` with a temporary InferShapeForRunning object, forwarding `before_subgraph` to
/// InferAndUpdate.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
InferShapeForRunning::InferShapeAndTypeForRunning(NodePtr &node, bool before_subgraph) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  InferShapeForRunning pass;
  std::set<NodePtr> unused_changed_nodes;
  return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes);
}
}  // namespace ge