diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index c9dfac07..2064ff3f 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -315,6 +315,7 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { } auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); + int64_t index = 0; for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { GE_CHECK_NOTNULL(input_node); ge::OpDescPtr op = input_node->GetOpDesc(); @@ -328,10 +329,11 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); return GRAPH_FAILED; } - if (UpdateDataOpShapeRange(op, index_shape_range_map) != SUCCESS) { + if (UpdateDataOpShapeRange(op, index_shape_range_map, index) != SUCCESS) { GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str()); return GRAPH_FAILED; } + index++; } } diff --git a/ge/ir_build/option_utils.cc b/ge/ir_build/option_utils.cc index 1be996b2..a9755748 100755 --- a/ge/ir_build/option_utils.cc +++ b/ge/ir_build/option_utils.cc @@ -393,7 +393,7 @@ bool ParseSingleShapeRange(std::string &shape_range, vector>> &shape_range_map) { - GELOGD("Input shape range %s", shape_range.c_str()); + GELOGD("Input shape range %s.", shape_range.c_str()); vector shape_range_vec = StringUtils::Split(shape_range, ';'); const int DEFAULT_SHAPE_RANGE_PAIR_SIZE = 2; @@ -432,7 +432,7 @@ Status ParseInputShapeRange(const std::string &shape_range, */ Status ParseInputShapeRange(const std::string &shape_range, std::vector>> &range) { - GELOGD("Input shape range %s", shape_range.c_str()); + GELOGD("Input shape range %s.", shape_range.c_str()); if (shape_range.size() < 2) { REPORT_INPUT_ERROR("E10048", std::vector({"shape_range", "reason", "sample"}), @@ -857,7 +857,7 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, tensor_input->SetShapeRange(cur_shape_range); tensor_output->SetShape(origin_shape); tensor_output->SetShapeRange(cur_shape_range); - GELOGI("Update input [%s] shape range info", data_op_name.c_str()); + GELOGI("Update input [%s] shape range info success by name.", data_op_name.c_str()); } else { GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); } @@ -866,19 +866,14 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, } Status UpdateDataOpShapeRange(const OpDescPtr &op, - const vector>> &index_shape_range_map) { + const vector>> &index_shape_range_map, + int64_t index) { GE_CHECK_NOTNULL(op); if (index_shape_range_map.empty()) { GELOGI("Shape range index map of data op [%s] is empty.", op->GetName().c_str()); return SUCCESS; } - GeAttrValue::INT index = 0; - if (!AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) { - GELOGW("[%s] Get index from data attr failed.", op->GetName().c_str()); - return SUCCESS; - } - if ((index < 0) || (static_cast(index) >= index_shape_range_map.size())) { std::string situation = "data op index[" + std::to_string(index) + "]"; std::string reason = "it must less than user_input size[" + std::to_string(index_shape_range_map.size()) + "]"; @@ -910,7 +905,7 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, tensor_input->SetShapeRange(cur_shape_range); tensor_output->SetShape(origin_shape); tensor_output->SetShapeRange(cur_shape_range); - GELOGI("Update input [%s] shape range info success.", data_op_name.c_str()); + GELOGI("Update input [%s] shape range info success by index.", data_op_name.c_str()); return SUCCESS; } diff --git a/ge/ir_build/option_utils.h b/ge/ir_build/option_utils.h index 44504e35..9f09c98f 100644 --- a/ge/ir_build/option_utils.h +++ b/ge/ir_build/option_utils.h @@ -85,7 +85,8 @@ Status UpdateDataOpShape(const OpDescPtr &op, std::map>> &name_shape_range_map); Status UpdateDataOpShapeRange(const OpDescPtr &op, - const std::vector>> &index_shape_range_map); + const std::vector>> &index_shape_range_map, + int64_t index); Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range); } #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc index ec7b9488..ac57bbd7 100644 --- a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -79,7 +79,7 @@ TEST(UtestIrCommon, update_data_op_shape_range) { index_shape_range_map.push_back(range_pair_tmp); AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0); - Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map); + Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map, 0); EXPECT_EQ(ret, ge::SUCCESS); }