|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "graph/passes/infershape_pass.h"
- #include "common/util/error_manager/error_manager.h"
- #include "framework/common/debug/ge_log.h"
- #include "analyzer/analyzer.h"
- #include "framework/common/util.h"
- #include "graph/common/omg_util.h"
- #include "graph/debug/ge_attr_define.h"
- #include "graph/debug/ge_util.h"
- #include "graph/operator_factory_impl.h"
- #include "graph/utils/graph_utils.h"
- #include "graph/utils/node_utils.h"
- #include "graph/utils/tensor_utils.h"
- #include "graph/utils/type_utils.h"
-
- namespace ge {
- namespace {
- const char *const kPreOpInputShapeRange = "_pre_op_in_range";
- thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void InferShapePass::ClearContextMap() { context_map.clear(); }
-
- InferenceContextPtr CreateInferenceContextPtr(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
- const NodePtr &node) {
- if (node == nullptr) {
- GELOGE(GRAPH_FAILED, "node is null");
- return nullptr;
- }
- InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
- if (inference_context == nullptr) {
- REPORT_CALL_ERROR("E19999", "Failed to alloc InferenceContext, node:%s", node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Alloc][InferenceContext] failed.");
- return nullptr;
- }
-
- auto all_in_data_anchors = node->GetAllInDataAnchors();
- std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
- std::vector<std::string> marks;
-
- bool has_input_shapes_and_types = false;
- for (const auto &in_anchor : all_in_data_anchors) {
- const auto &out_anchor = in_anchor->GetPeerOutAnchor();
- if (out_anchor == nullptr) {
- continue;
- }
-
- auto input_node = out_anchor->GetOwnerNode();
- if (input_node == nullptr) {
- continue;
- }
-
- auto iter = context_map.find(input_node);
- if (iter != context_map.end()) {
- const auto &src_context = iter->second;
- GE_IF_BOOL_EXEC(src_context == nullptr, REPORT_INNER_ERROR("E19999", "src_context is null.");
- GELOGE(GRAPH_FAILED, "[Check][Param] src_context is null."); return nullptr);
- GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(),
- input_node->GetName().c_str());
- for (auto mark : src_context->GetMarks()) {
- marks.push_back(mark);
- }
- auto output_idx = out_anchor->GetIdx();
- auto input_idx = in_anchor->GetIdx();
- auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
- if (output_idx < static_cast<int>(output_shape_and_type.size())) {
- GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
- node->GetName().c_str(), input_idx);
- input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
- has_input_shapes_and_types = true;
- } else {
- GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
- output_shape_and_type.size());
- }
- }
- }
-
- if (has_input_shapes_and_types) {
- inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
- }
- inference_context->SetMarks(marks);
-
- return inference_context;
- }
-
- void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
- desc_str += "[";
- std::vector<std::pair<int64_t, int64_t>> shape_range;
- (void)desc->GetShapeRange(shape_range);
- for (const auto &pair : shape_range) {
- desc_str += "{";
- desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
- desc_str += "},";
- }
- desc_str += "]";
- shape_range.clear();
- (void)desc->GetOriginShapeRange(shape_range);
- for (const auto &pair : shape_range) {
- desc_str += ",{";
- desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
- desc_str += "},";
- }
- }
-
- std::string GetInTensorInfoWithString(const ge::NodePtr &node) {
- ge::OpDescPtr op_desc = node->GetOpDesc();
- std::stringstream ss;
- ss << "{";
- int32_t in_idx = 0;
- for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
- if (input_desc == nullptr) {
- in_idx++;
- continue;
- }
- if (in_idx > 0) {
- ss << " ";
- }
- ss << "input_" << in_idx << " "
- << "tensor: [";
- ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
- ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
- ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
- ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
- ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
- ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
- string range_str;
- SerialShapeRange(input_desc, range_str);
- ss << "(shape_range:" << range_str << ")]";
- in_idx++;
- }
- return ss.str();
- }
-
- void InferShapePass::AnalyzeFailedInfo(const NodePtr &node) {
- auto graph = node->GetOwnerComputeGraph();
- if (graph == nullptr) {
- GELOGW("Owner compute graph of node %s is nullptr", node->GetName().c_str());
- }
- auto root_graph = ge::GraphUtils::FindRootGraph(graph);
- if (root_graph == nullptr) {
- GELOGW("Root compute graph of node %s is nullptr", node->GetName().c_str());
- }
- analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), analyzer::INFER_SHAPE, node,
- "InferShapeFailed!"};
- (void)Analyzer::GetInstance()->DoAnalyze(analyze_info);
- (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), root_graph->GetGraphID());
- REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", node->GetName().c_str(),
- node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
- GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s",
- node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str());
- }
-
- bool InferShapePass::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
- bool changed = false;
- const auto &dst_dims = dst->GetShape().GetDims();
- const auto &src_dims = src->GetShape().GetDims();
- if (dst_dims != src_dims) {
- changed = true;
- }
- return changed;
- }
-
- graphStatus InferShapePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
- dst->SetOriginShape(src->GetOriginShape());
- dst->SetShape(src->MutableShape());
- dst->SetDataType(src->GetDataType());
- dst->SetOriginDataType(src->GetOriginDataType());
- if (src->MutableShape().GetDims() != UNKNOWN_RANK) {
- std::vector<std::pair<int64_t, int64_t>> shape_range;
- (void)src->GetShapeRange(shape_range);
- dst->SetShapeRange(shape_range);
- }
- std::vector<int64_t> pre_op_in_range;
- if (ge::AttrUtils::GetListInt(*src, kPreOpInputShapeRange, pre_op_in_range)) {
- (void)ge::AttrUtils::SetListInt(*dst, kPreOpInputShapeRange, pre_op_in_range);
- }
- ge::TensorUtils::SetRealDimCnt(*dst, static_cast<uint32_t>(src->MutableShape().GetDims().size()));
- return GRAPH_SUCCESS;
- }
-
- graphStatus InferShapePass::Infer(NodePtr &node) {
- bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
- auto opdesc = node->GetOpDesc();
- if (node->Verify() != GRAPH_SUCCESS) {
- REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str());
- return GRAPH_FAILED;
- }
- PrintInOutTensorShape(node, "before_infershape");
- Operator op = OpDescUtils::CreateOperatorFromNode(node);
-
- if (!is_unknown_graph) {
- auto inference_context = CreateInferenceContextPtr(context_map, node);
- GE_CHECK_NOTNULL(inference_context);
- GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
- op.SetInferenceContext(inference_context);
- }
-
- graphStatus status = CallInferShapeFunc(node, op);
- if (status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) {
- REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
- return GRAPH_FAILED;
- }
- if (!is_unknown_graph) {
- auto ctx_after_infer = op.GetInferenceContext();
- if (ctx_after_infer != nullptr) {
- GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
- if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
- GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
- ctx_after_infer->GetMarks().size());
- (void)context_map.emplace(node, ctx_after_infer);
- }
- }
- }
-
- return GRAPH_SUCCESS;
- }
-
- graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
- auto op_desc = node->GetOpDesc();
- const auto &op_type = op_desc->GetType();
- auto ret = op_desc->CallInferFunc(op);
- if (ret == GRAPH_PARAM_INVALID) {
- // Op ir no infer func, try to get infer func from operator factory
- auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
- if (node_op.IsEmpty()) {
- GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
- return ret;
- }
-
- GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
- auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
- node_op.BreakConnect();
- if (temp_op_desc == nullptr) {
- REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr.");
- GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null");
- return GRAPH_FAILED;
- }
- if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
- GELOGW("InferShapeAndType UpdateInputName failed");
- for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
- if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
- break;
- }
- return GRAPH_SUCCESS;
- }
- }
- if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
- GELOGW("InferShapeAndType UpdateOutputName failed");
- }
- op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
- ret = op_desc->CallInferFunc(op);
- GELOGI("op CallInferFunc second. ret: %u", ret);
- }
- return ret;
- }
-
- graphStatus InferShapePass::UpdatePeerInputs(NodePtr &node) {
- bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
- if (is_unknown_graph) {
- PrintInOutTensorShape(node, "after_infershape when running");
- return GRAPH_SUCCESS;
- }
- UpdateInputOutputOriginAttr(node);
- if (NodeUtils::UpdatePeerNodeInputDesc(node) != SUCCESS) {
- return GRAPH_FAILED;
- }
- PrintInOutTensorShape(node, "after_infershape");
- return GRAPH_SUCCESS;
- }
-
- void InferShapePass::UpdateInputOutputOriginAttr(NodePtr &node) {
- auto op_desc = node->GetOpDesc();
- for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
- auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
- if (output_tensor == nullptr) {
- continue;
- }
- if (output_tensor->MutableShape().GetDims().empty()) {
- output_tensor->SetOriginShape(output_tensor->GetShape());
- }
- ge::TensorUtils::SetRealDimCnt(*output_tensor,
- static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size()));
- output_tensor->SetOriginDataType(output_tensor->GetDataType());
- // set output origin shape range
- std::vector<std::pair<int64_t, int64_t>> range;
- (void)output_tensor->GetShapeRange(range);
- output_tensor->SetOriginShapeRange(range);
- GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", node->GetName().c_str(),
- output_tensor->GetOriginShape().GetShapeSize(),
- TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
- TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
- }
- for (const auto &in_anchor : node->GetAllInDataAnchors()) {
- auto input_tensor = op_desc->MutableInputDesc(in_anchor->GetIdx());
- if (input_tensor == nullptr) {
- continue;
- }
- // set input origin shape range
- std::vector<std::pair<int64_t, int64_t>> range;
- (void)input_tensor->GetShapeRange(range);
- input_tensor->SetOriginShapeRange(range);
- }
- }
-
- Status InferShapePass::DoRepassForLoopNode(NodePtr &node) {
- GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
- bool need_repass = false;
- auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
- if (has_attr) {
- if (!OptionExists(kOptimizeAfterSubGraph)) {
- return SUCCESS;
- }
- if (need_repass) {
- AddImmediateRePassNode(node);
- GELOGD("Node %s need repass immediately.", node->GetName().c_str());
- } else {
- // clear attr on while
- node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
- }
- }
- return SUCCESS;
- }
-
- Status InferShapePass::RePassLoopNode(const NodePtr &node) {
- const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
- for (auto &n : node->GetOutDataNodes()) {
- GE_CHECK_NOTNULL(n);
- std::string node_type;
- GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
- if (re_pass_types.count(node_type) > 0) {
- AddImmediateRePassNode(n);
- (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
- GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
- }
- }
- return SUCCESS;
- };
-
- const auto ExProcNode = [&](const std::set<std::string> &proc_types,
- const std::function<void(InferShapePass *, NodePtr)> &proc_func,
- const std::string &info) {
- for (auto &n : node->GetOutDataNodes()) {
- GE_CHECK_NOTNULL(n);
- std::string node_type;
- GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str());
- if (proc_types.count(node_type) > 0) {
- proc_func(this, n);
- GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
- }
- }
- return SUCCESS;
- };
-
- std::string node_type;
- GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
- "[Get][OriginalType] of node:%s failed.", node->GetName().c_str());
- if (kNextIterationOpTypes.count(node_type) > 0) {
- return RePassNode(kMergeOpTypes); // Re-Pass Merge
- }
-
- if (kMergeOpTypes.count(node_type) > 0) {
- if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
- node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
- return RePassNode(kSwitchOpTypes); // Re-Pass Switch
- }
- return SUCCESS;
- }
-
- if (kSwitchOpTypes.count(node_type) > 0) {
- if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
- node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
- return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit
- } else {
- return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit
- }
- }
-
- return SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
- graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
- GE_CHECK_NOTNULL(node);
- GE_CHECK_NOTNULL(node->GetOpDesc());
- InferShapePass pass;
- std::set<NodePtr> unused_changed_nodes;
- return pass.InferAndUpdate(node, true, unused_changed_nodes);
- }
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
- graphStatus InferShapePass::InferShapeAndType(NodePtr &node, bool before_subgraph) {
- GE_CHECK_NOTNULL(node);
- GE_CHECK_NOTNULL(node->GetOpDesc());
- InferShapePass pass;
- std::set<NodePtr> unused_changed_nodes;
- return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes);
- }
-
-
- graphStatus InferShapeForRunning::Infer(NodePtr &node) {
- auto opdesc = node->GetOpDesc();
- vector<ge::DataType> temp_dtype;
- for (auto &tensor_desc : opdesc->GetAllOutputsDescPtr()) {
- temp_dtype.emplace_back(tensor_desc->GetDataType());
- }
- PrintInOutTensorShape(node, "before_infershape when running");
-
- Operator op = OpDescUtils::CreateOperatorFromNode(node);
- graphStatus status = CallInferShapeFuncForRunning(node, op);
- if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
- // ensure the dtype is not changed after infershape in running
- auto after_opdesc = node->GetOpDesc();
- GE_IF_BOOL_EXEC(after_opdesc == nullptr, REPORT_INNER_ERROR("E19999", "param node has no opdesc, check invalid.");
- GELOGE(GRAPH_FAILED, "[Get][OpDesc] after_opdesc is null."); return GRAPH_FAILED);
- auto all_output_tensor = after_opdesc->GetAllOutputsDescPtr();
- for (size_t i = 0; i < all_output_tensor.size(); ++i) {
- if (all_output_tensor.at(i)->GetDataType() != temp_dtype[i]) {
- GELOGD("Op %s output %zu need reset dtype,original dtype is %s, new dtype is %s", node->GetName().c_str(), i,
- TypeUtils::DataTypeToSerialString(all_output_tensor.at(i)->GetDataType()).c_str(),
- TypeUtils::DataTypeToSerialString(temp_dtype[i]).c_str());
- all_output_tensor.at(i)->SetDataType(temp_dtype[i]);
- }
- }
- PrintInOutTensorShape(node, "after_infershape when running");
- return GRAPH_SUCCESS;
- } else {
- REPORT_CALL_ERROR("E19999", "%s call infer function failed.", node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s.", node->GetName().c_str());
- return GRAPH_FAILED;
- }
- }
-
- graphStatus InferShapeForRunning::CallInferShapeFuncForRunning(NodePtr &node, Operator &op) {
- auto op_desc = node->GetOpDesc();
- const auto &op_type = op_desc->GetType();
-
- // Create InferenceContext to avoid null pointer access.
- const static std::set<std::string> force_context_op_types{"Enter", "Switch", "RefSwitch"};
- if (force_context_op_types.count(op_type) > 0) {
- GELOGD("Set InferenceContext for node [%s]", op_desc->GetName().c_str());
- op.SetInferenceContext(std::shared_ptr<InferenceContext>(InferenceContext::Create()));
- }
-
- // Get infer func and execute
- auto ret = op_desc->CallInferFunc(op);
- if (ret == GRAPH_PARAM_INVALID) {
- GELOGD("NodeUtils::GetNodeType return value is: [%s]", NodeUtils::GetNodeType(*node).c_str());
- auto origin_type = NodeUtils::GetNodeType(*node);
- auto infer_func = ge::OperatorFactoryImpl::GetInferShapeFunc(origin_type);
- if (infer_func == nullptr) {
- REPORT_INNER_ERROR("E19999", "Failed to Get InferFunc. type is %s", origin_type.c_str());
- GELOGE(GRAPH_FAILED, "[Get][InferFunc] failed. type is %s", origin_type.c_str());
- return GRAPH_FAILED;
- }
- op_desc->AddInferFunc(infer_func);
- ret = op_desc->CallInferFunc(op);
- GELOGI("op CallInferFunc second. ret: %u", ret);
- }
- return ret;
- }
- bool InferShapeForRunning::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
- bool changed = false;
- const auto &dst_dims = dst->GetShape().GetDims();
- const auto &src_dims = src->GetShape().GetDims();
- if (dst_dims != src_dims) {
- changed = true;
- }
- return changed;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
- graphStatus InferShapeForRunning::InferShapeAndTypeForRunning(NodePtr &node, bool before_subgraph) {
- GE_CHECK_NOTNULL(node);
- GE_CHECK_NOTNULL(node->GetOpDesc());
- InferShapeForRunning pass;
- std::set<NodePtr> unused_changed_nodes;
- return pass.InferAndUpdate(node, before_subgraph, unused_changed_nodes);
- }
- } // namespace ge
|