|
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "infer_base_pass.h"
- #include "common/ge/ge_util.h"
- #include "common/util/error_manager/error_manager.h"
- #include "framework/common/debug/ge_log.h"
- #include "framework/common/util.h"
- #include "graph/debug/ge_attr_define.h"
- #include "graph/debug/ge_util.h"
- #include "graph/utils/graph_utils.h"
- #include "graph/utils/node_utils.h"
- #include "graph/utils/tensor_utils.h"
- #include "graph/utils/type_utils.h"
-
- namespace ge {
- namespace {
// Renders a list of dimensions as text: an opening '[', every dimension followed by a
// single space, then a closing ']'. An empty list yields "[]".
std::string Serial(const std::vector<int64_t> &dims) {
  std::string text("[");
  for (auto it = dims.cbegin(); it != dims.cend(); ++it) {
    text.append(std::to_string(*it));
    text.push_back(' ');
  }
  text.push_back(']');
  return text;
}
- void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
- desc_str += "[";
- std::vector<std::pair<int64_t, int64_t>> shape_range;
- (void)desc->GetShapeRange(shape_range);
- for (const auto &pair : shape_range) {
- desc_str += "{";
- desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
- desc_str += "},";
- }
- desc_str += "]";
- shape_range.clear();
- (void)desc->GetOriginShapeRange(shape_range);
- for (const auto &pair : shape_range) {
- desc_str += ",{";
- desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
- desc_str += "},";
- }
- }
- void SerialValueRange(const GeTensorDescPtr &desc, std::string &desc_str) {
- desc_str += "[";
- std::vector<std::pair<int64_t, int64_t>> value_range;
- (void)desc->GetValueRange(value_range);
- for (const auto &pair : value_range) {
- desc_str += "{";
- desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
- desc_str += "},";
- }
- desc_str += "]";
- }
- graphStatus FindSubgraphDataAndNetoutput(const ComputeGraphPtr &sub_graph, NodePtr &netoutput, const ConstNodePtr &node,
- std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
- auto sub_nodes = sub_graph->GetDirectNode();
- for (size_t i = sub_nodes.size(); i > 0; --i) {
- auto sub_node = sub_nodes.at(i - 1);
- if (sub_node->GetType() == NETOUTPUT) {
- netoutput = sub_node;
- }
- if (sub_node->GetType() == DATA) {
- if (sub_node->GetOpDesc() == nullptr) {
- return GRAPH_FAILED;
- }
-
- int ref_i;
- if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
- REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
- return GRAPH_FAILED;
- }
- if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
- REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!",
- sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
- GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!",
- sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
- return GRAPH_FAILED;
- }
- ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
- }
- }
- return GRAPH_SUCCESS;
- }
- } // namespace
-
- Status InferBasePass::Run(NodePtr &node) {
- GE_CHECK_NOTNULL(node);
- GE_CHECK_NOTNULL(node->GetOpDesc());
-
- bool need_infer = NeedInfer(node);
- if (!need_infer) {
- GELOGD("Node %s does not need to infer.", node->GetName().c_str());
- return SUCCESS;
- }
-
- std::set<NodePtr> changed_nodes;
- auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes);
- if (ret != GRAPH_SUCCESS) {
- (void)AnalyzeFailedInfo(node);
- return GE_GRAPH_INFERSHAPE_FAILED;
- }
-
- /*
- * we will use changed nodes to do repass for control_ops.
- * AddChangedNodesImmediateRepass(changed_nodes);
- */
- auto status = DoRepassForLoopNode(node);
- if (status != SUCCESS) {
- GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "repass failed. node: %s", node->GetName().c_str());
- return GE_GRAPH_INFERSHAPE_FAILED;
- }
- return SUCCESS;
- }
-
- bool InferBasePass::NeedInfer(const NodePtr &node) { return true; }
- void InferBasePass::AnalyzeFailedInfo(const NodePtr &node) { /* Analyze and select failed info*/ }
- Status InferBasePass::DoRepassForLoopNode(NodePtr &node) { return SUCCESS; }
- graphStatus InferBasePass::UpdatePeerInputs(NodePtr &node) { return GRAPH_SUCCESS; }
- void InferBasePass::AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes) {
- for (const auto &node_ele : changed_nodes) {
- AddImmediateRePassNode(node_ele);
- }
- }
-
- graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
- auto ret = GRAPH_SUCCESS;
- bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
- auto opdesc = node->GetOpDesc();
- // some op can not infershape twice such as aipp
- bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified");
- if (need_update_input) {
- ret = UpdateCurOpInputDesc(node);
- if (ret != GRAPH_SUCCESS) {
- REPORT_CALL_ERROR("E19999", "update op input_desc failed! ret:%d, node:%s", ret, node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Update][OpInputDesc] failed! ret:%d", ret);
- return ret;
- }
- }
-
- bool contain_subgraph = ContainsSubgraph(node);
- if (contain_subgraph && before_subgraph) {
- ret = UpdateTensorDescToSubgraphData(node, changed_nodes);
- if (ret != GRAPH_SUCCESS) {
- return ret;
- }
- }
- ret = Infer(node);
- if (ret != GRAPH_SUCCESS) {
- return ret;
- }
- if (contain_subgraph && !before_subgraph) {
- ret = UpdateTensorDescToParentNode(node, changed_nodes);
- if (ret != GRAPH_SUCCESS) {
- return ret;
- }
- }
-
- ret = UpdatePeerInputs(node);
- return ret;
- }
-
- graphStatus InferBasePass::UpdateCurOpInputDesc(const NodePtr &node_ptr) {
- for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) {
- auto in_idx = in_anchor->GetIdx();
- auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
- if (peer_out_data_anchor == nullptr) {
- continue;
- }
- auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
- if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
- continue;
- }
- int peer_out_idx = peer_out_data_anchor->GetIdx();
- auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
-
- // check shape and dtype continuity. do not stop process
- auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx));
- if (in_desc == nullptr) {
- continue;
- }
- auto in_shape = in_desc->MutableShape().GetDims();
- auto in_dtype = in_desc->GetDataType();
- auto peer_out_shape = peer_out_desc->MutableShape().GetDims();
- auto peer_out_dtype = peer_out_desc->GetDataType();
- if (peer_out_dtype != in_dtype) {
- GELOGW(
- "current node [%s] [%d]\'th in_dtype is [%s].peer output node [%s] [%d]\'th "
- "output_dtype is [%s].The two dtype should be same! Please check graph and fix it",
- node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(),
- peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str());
- } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) {
- string in_shape_str = Serial(in_shape);
- string peer_out_shape_str = Serial(peer_out_shape);
- GELOGW(
- "current node [%s] [%d]\'th in_shape is [%s].peer output node [%s] [%d]\'th "
- "output_shape is [%s].The two shape should be same! Please check graph and fix it",
- node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx,
- peer_out_shape_str.c_str());
- }
- // refresh current node input desc
- bool output_changed = false;
- (void)UpdateInputDescAttr(peer_out_desc, in_desc, output_changed);
- }
- return GRAPH_SUCCESS;
- }
-
- graphStatus InferBasePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
- changed = false;
- return GRAPH_SUCCESS;
- }
-
- bool InferBasePass::ContainsSubgraph(const NodePtr &node) {
- auto op_desc = node->GetOpDesc();
- auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
- if (sub_graph_names.empty()) {
- return false;
- }
-
- auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
- if (root_graph == nullptr) {
- return false;
- }
- for (const auto &name : sub_graph_names) {
- if (name.empty()) {
- continue;
- }
- auto sub_graph = root_graph->GetSubgraph(name);
- if (sub_graph != nullptr) {
- return true;
- }
- }
- return false;
- }
-
- std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) {
- std::vector<ComputeGraphPtr> cur_node_subgraph;
- auto op_desc = node->GetOpDesc();
- auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
- if (sub_graph_names.empty()) {
- return cur_node_subgraph;
- }
-
- auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
- for (const auto &name : sub_graph_names) {
- if (name.empty()) {
- GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
- continue;
- }
- auto sub_graph = root_graph->GetSubgraph(name);
- if (sub_graph == nullptr) {
- REPORT_INNER_ERROR("E19999", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
- GE_LOGE("[Get][Graph] can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
- continue;
- }
- cur_node_subgraph.emplace_back(sub_graph);
- }
- return cur_node_subgraph;
- }
-
- graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) {
- // if infer again, update output of while into subgraph data node
- auto op_desc = node->GetOpDesc();
- for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
- for (const auto &node_sub : sub_graph->GetDirectNode()) {
- if (node_sub->GetType() != DATA) {
- continue;
- }
- auto name = sub_graph->GetName();
- int ref_i;
- auto data_opdesc = node_sub->GetOpDesc();
- if (data_opdesc == nullptr) {
- REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
- node->GetName().c_str());
- GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
- node->GetName().c_str());
- return GRAPH_FAILED;
- }
- if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
- REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
- name.c_str(), node->GetName().c_str());
- GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
- node->GetName().c_str());
- return GRAPH_FAILED;
- }
- if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
- continue;
- }
- auto input_desc = op_desc->MutableInputDesc(ref_i);
- if (input_desc == nullptr) {
- REPORT_INNER_ERROR("E19999",
- "The ref index(%d) on the data %s on the sub graph %s "
- "parent node %s are incompatible, inputs num %u",
- ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
- node->GetAllInDataAnchorsSize());
- GE_LOGE(
- "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
- "parent node %s are incompatible, inputs num %u",
- ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize());
- return GRAPH_FAILED;
- }
- GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
- node->GetName().c_str());
-
- // if need infer again, refresh subgraph input with output
- bool is_infer_again = false;
- AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, is_infer_again);
- if (is_infer_again) {
- input_desc = op_desc->MutableOutputDesc(ref_i);
- if (input_desc == nullptr) {
- REPORT_INNER_ERROR("E19999",
- "The ref index(%d) on the data %s on the subgraph %s "
- "parent node %s are incompatible, outputs num %u.",
- ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
- node->GetAllOutDataAnchorsSize());
- GELOGE(PARAM_INVALID,
- "[Call][MutableOutputDesc] The ref index(%d) on the data %s on the subgraph %s "
- "parent node %s are incompatible, outputs num %u.",
- ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
- node->GetAllOutDataAnchorsSize());
- }
- GELOGD("Update input desc of data %s on the sub graph %s of node %s,output idx: %d from [%s] to [%s]",
- node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), ref_i,
- data_opdesc->GetInputDescPtr(0)->GetShape().ToString().c_str(),
- input_desc->GetShape().ToString().c_str());
- }
-
- auto data_input_desc = data_opdesc->MutableInputDesc(0);
- auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
- if (ret != GRAPH_SUCCESS) {
- REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
- node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
- GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(),
- name.c_str(), node->GetName().c_str());
- return ret;
- }
- bool input_changed = TensorDescChanged(input_desc, data_input_desc);
-
- auto data_output_desc = data_opdesc->MutableOutputDesc(0);
- ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
- if (ret != GRAPH_SUCCESS) {
- REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
- node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
- GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
- node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
- return ret;
- }
- bool output_changed = TensorDescChanged(input_desc, data_output_desc);
-
- if (input_changed || output_changed) {
- changed_nodes.insert(node_sub);
- }
- }
- }
- return GRAPH_SUCCESS;
- }
-
- graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) {
- std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
- std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
-
- for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
- auto name = sub_graph->GetName();
- NodePtr netoutput = nullptr;
- auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
- if (ret != GRAPH_SUCCESS) {
- return ret;
- }
- if (netoutput == nullptr) {
- REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(),
- node->GetName().c_str());
- GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(),
- node->GetName().c_str());
- return GRAPH_FAILED;
- }
- auto netoutput_opdesc = netoutput->GetOpDesc();
- if (netoutput_opdesc == nullptr) {
- REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
- name.c_str(), node->GetName().c_str());
- GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
- node->GetName().c_str());
- return GRAPH_FAILED;
- }
- for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
- auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
- if (edge_desc == nullptr) {
- REPORT_INNER_ERROR("E19999",
- "Invalid NetOutput node on sub graph %s, parent node %s, "
- "can not find input tensor %d",
- name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
- GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
- name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
- return GRAPH_FAILED;
- }
- GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(),
- edge_desc->GetShape().GetDimNum());
- int ref_i;
- if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
- // if there is no ref index on the TensorDesc, it means the output data will be ignored outer.
- continue;
- }
- GELOGI("Parent node index of edge desc is %d", ref_i);
- if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
- return GRAPH_FAILED;
- }
- ref_out_tensors[ref_i].emplace_back(*edge_desc);
- }
- }
-
- if (node->GetType() == WHILE) {
- return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes);
- }
- return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes);
- }
-
- graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node,
- std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
- std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
- std::set<NodePtr> &changed_nodes) {
- GELOGD("Enter update parent node shape for class while op process");
- if (ref_data_tensors.size() != ref_out_tensors.size()) {
- REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!",
- node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(),
- ref_out_tensors.size());
- GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!",
- node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
- return GRAPH_FAILED;
- }
- for (size_t i = 0; i < ref_data_tensors.size(); i++) {
- if (ref_out_tensors[i].size() != 1) {
- REPORT_INNER_ERROR("E19999", "while op, every output should only find one output tensor in all graph!");
- GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!");
- return GRAPH_FAILED;
- }
- }
- bool need_infer_again = false;
- // check input and output
- for (size_t i = 0; i < ref_out_tensors.size(); i++) {
- if (ref_out_tensors[i].empty()) {
- continue;
- }
- auto ref_out_tensor = ref_out_tensors[i].at(0);
- auto out_shape = ref_out_tensor.MutableShape();
- vector<std::pair<int64_t, int64_t>> data_shape_range;
- // ref_i's data and output tensor shape should be same
- for (auto &tensor : ref_data_tensors[i]) {
- if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
- REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output",
- node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.",
- node->GetName().c_str());
- return GRAPH_FAILED;
- }
- auto data_shape = tensor.MutableShape();
- // input is dynamic, here use dim_num
- if (data_shape.GetDims() != out_shape.GetDims()) {
- GELOGI("After infer, While %s %zu output shape [%s] is not match with input shape [%s].Need infer again.",
- node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str());
- if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
- ref_out_tensor.SetUnknownDimNumShape();
- } else {
- for (size_t j = 0; j < data_shape.GetDimNum(); ++j) {
- if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
- if (data_shape.GetDim(j) != UNKNOWN_DIM) {
- // if input data is fix shape, output is different, need_infer_again
- need_infer_again = true;
- }
- data_shape.SetDim(j, UNKNOWN_DIM);
- }
- // set shape rang of while, if dim is unknown ,set shape range as {1,-1}
- if (data_shape.GetDim(j) == UNKNOWN_DIM) {
- data_shape_range.emplace_back(std::make_pair(1, UNKNOWN_DIM));
- } else {
- data_shape_range.emplace_back(std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)));
- }
- }
- ref_out_tensor.SetShape(data_shape);
- ref_out_tensor.SetShapeRange(data_shape_range);
- }
- }
- }
-
- auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
- (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
- bool output_changed = TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
- if (output_changed) {
- changed_nodes.insert(node);
- }
- }
- AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_infer_again);
- return GRAPH_SUCCESS;
- }
-
- graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node,
- std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
- std::set<NodePtr> &changed_nodes) {
- // check sub_graph shape. Get max for update.
- for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
- if (ref_out_tensors[i].empty()) {
- continue;
- }
-
- int64_t max_size = 0;
- size_t max_shape_index = 0;
- auto &ref_out_tensor = ref_out_tensors[i].at(0);
- for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
- auto &tensor = ref_out_tensors[i].at(j);
- if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
- REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output",
- node->GetName().c_str());
- GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output",
- node->GetName().c_str());
- return GRAPH_FAILED;
- }
-
- auto shape = tensor.MutableShape();
- int64_t size = 1;
- for (auto dim : shape.GetDims()) {
- if (dim != 0 && INT64_MAX / dim < size) {
- REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(),
- node->GetName().c_str());
- GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
- return PARAM_INVALID;
- }
- size *= dim;
- }
-
- if (size > max_size) {
- max_size = size;
- max_shape_index = j;
- }
- }
-
- auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
- (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
- bool output_changed =
- TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc);
- if (output_changed) {
- changed_nodes.insert(node);
- }
- }
-
- return GRAPH_SUCCESS;
- }
-
- graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node,
- std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
- std::set<NodePtr> &changed_nodes) {
- GELOGD("Enter update parent node shape for class branch op process");
- if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
- return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes);
- }
-
- // check sub_graph shape.If not same ,do unknown shape process
- for (size_t i = 0; i < ref_out_tensors.size(); i++) {
- if (ref_out_tensors[i].empty()) {
- continue;
- }
- auto ref_out_tensor = ref_out_tensors[i].at(0);
- ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
- for (auto &tensor : ref_out_tensors[i]) {
- if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
- REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s",
- node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str());
- GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str());
- return GRAPH_FAILED;
- }
- auto shape = tensor.MutableShape();
- if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
- GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
- shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
- ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
- break;
- }
- for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
- if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
- continue;
- }
- GELOGD("node is %s, i : %zu, j: %zu ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(),
- i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
- (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
- }
- }
-
- auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
- (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
- bool output_changed =
- TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
- if (output_changed) {
- changed_nodes.insert(node);
- }
- }
- return GRAPH_SUCCESS;
- }
-
- void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) {
- if (!IsLogEnable(GE, DLOG_DEBUG)) {
- return;
- }
- if (node == nullptr) {
- REPORT_INNER_ERROR("E19999", "param node is nullprt, check invalid");
- GELOGE(GRAPH_FAILED, "[Check][Param] node is null");
- return;
- }
- ge::OpDescPtr op_desc = node->GetOpDesc();
- GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid");
- GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return );
- std::stringstream ss;
- ss << "{";
- int32_t in_idx = 0;
- int32_t out_idx = 0;
- for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
- if (input_desc == nullptr) {
- in_idx++;
- continue;
- }
- if (in_idx > 0) {
- ss << " ";
- }
- ss << "input_" << in_idx << " "
- << "tensor: [";
- ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
- ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
- ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
- ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
- ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
- ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
- string range_str;
- SerialShapeRange(input_desc, range_str);
- ss << "(shape_range:" << range_str << "),";
- string value_range_str;
- SerialValueRange(input_desc, value_range_str);
- ss << "(value_range:" << value_range_str << ")]";
- in_idx++;
- }
- for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
- if (output_desc == nullptr) {
- out_idx++;
- continue;
- }
- ss << " ";
- ss << "output_" << out_idx << " "
- << "tensor: [";
- ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),";
- ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),";
- ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),";
- ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),";
- ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),";
- ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),";
- string range_str;
- SerialShapeRange(output_desc, range_str);
- ss << "(shape_range:" << range_str << "),";
- string value_range_str;
- SerialValueRange(output_desc, value_range_str);
- ss << "(value_range:" << value_range_str << ")]";
- out_idx++;
- }
- ss << "}";
- GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str());
- }
- } // namespace ge
|