Browse Source

dts: ir get data index failed.

pull/1590/head
zhengyuanhua 4 years ago
parent
commit
f833ecdf6f
4 changed files with 12 additions and 14 deletions
  1. +3
    -1
      ge/ir_build/ge_ir_build.cc
  2. +6
    -11
      ge/ir_build/option_utils.cc
  3. +2
    -1
      ge/ir_build/option_utils.h
  4. +1
    -1
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc

+ 3
- 1
ge/ir_build/ge_ir_build.cc View File

@@ -315,6 +315,7 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
}
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
int64_t index = 0;
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(input_node);
ge::OpDescPtr op = input_node->GetOpDesc();
@@ -328,10 +329,11 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str());
return GRAPH_FAILED;
}
if (UpdateDataOpShapeRange(op, index_shape_range_map) != SUCCESS) {
if (UpdateDataOpShapeRange(op, index_shape_range_map, index) != SUCCESS) {
GELOGE(GRAPH_FAILED, "[Update][DataOpShapeRange] fail for op:%s.", op->GetName().c_str());
return GRAPH_FAILED;
}
index++;
}
}



+ 6
- 11
ge/ir_build/option_utils.cc View File

@@ -393,7 +393,7 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_
*/
Status ParseInputShapeRange(const std::string &shape_range,
std::map<string, std::vector<std::pair<int64_t, int64_t>>> &shape_range_map) {
GELOGD("Input shape range %s", shape_range.c_str());
GELOGD("Input shape range %s.", shape_range.c_str());

vector<string> shape_range_vec = StringUtils::Split(shape_range, ';');
const int DEFAULT_SHAPE_RANGE_PAIR_SIZE = 2;
@@ -432,7 +432,7 @@ Status ParseInputShapeRange(const std::string &shape_range,
*/
Status ParseInputShapeRange(const std::string &shape_range,
std::vector<std::vector<std::pair<int64_t, int64_t>>> &range) {
GELOGD("Input shape range %s", shape_range.c_str());
GELOGD("Input shape range %s.", shape_range.c_str());

if (shape_range.size() < 2) {
REPORT_INPUT_ERROR("E10048", std::vector<std::string>({"shape_range", "reason", "sample"}),
@@ -852,7 +852,7 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
tensor_input->SetShapeRange(cur_shape_range);
tensor_output->SetShape(origin_shape);
tensor_output->SetShapeRange(cur_shape_range);
GELOGI("Update input [%s] shape range info", data_op_name.c_str());
GELOGI("Update input [%s] shape range info success by name.", data_op_name.c_str());
} else {
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str());
}
@@ -861,19 +861,14 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
}

Status UpdateDataOpShapeRange(const OpDescPtr &op,
const vector<vector<pair<int64_t, int64_t>>> &index_shape_range_map) {
const vector<vector<pair<int64_t, int64_t>>> &index_shape_range_map,
int64_t index) {
GE_CHECK_NOTNULL(op);
if (index_shape_range_map.empty()) {
GELOGI("Shape range index map of data op [%s] is empty.", op->GetName().c_str());
return SUCCESS;
}

GeAttrValue::INT index = 0;
if (!AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) {
GELOGW("[%s] Get index from data attr failed.", op->GetName().c_str());
return SUCCESS;
}

if ((index < 0) || (static_cast<size_t>(index) >= index_shape_range_map.size())) {
std::string situation = "data op index[" + std::to_string(index) + "]";
std::string reason = "it must less than user_input size[" + std::to_string(index_shape_range_map.size()) + "]";
@@ -905,7 +900,7 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op,
tensor_input->SetShapeRange(cur_shape_range);
tensor_output->SetShape(origin_shape);
tensor_output->SetShapeRange(cur_shape_range);
GELOGI("Update input [%s] shape range info success.", data_op_name.c_str());
GELOGI("Update input [%s] shape range info success by index.", data_op_name.c_str());

return SUCCESS;
}


+ 2
- 1
ge/ir_build/option_utils.h View File

@@ -85,7 +85,8 @@ Status UpdateDataOpShape(const OpDescPtr &op, std::map<std::string, std::vector<
Status UpdateDataOpShapeRange(
const OpDescPtr &op, const std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> &name_shape_range_map);
Status UpdateDataOpShapeRange(const OpDescPtr &op,
const std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map);
const std::vector<std::vector<std::pair<int64_t, int64_t>>> &index_shape_range_map,
int64_t index);
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range);
}
#endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_

+ 1
- 1
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -79,7 +79,7 @@ TEST(UtestIrCommon, update_data_op_shape_range) {
index_shape_range_map.push_back(range_pair_tmp);

AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0);
Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map);
Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map, 0);
EXPECT_EQ(ret, ge::SUCCESS);
}



Loading…
Cancel
Save