From e1eb148756b27dc87d836a60e73afdcc0d098c56 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 20 Mar 2021 18:03:08 +0800 Subject: [PATCH] Fix bug of const input index. --- ge/generator/ge_generator.cc | 28 +++++++++++++++++++++------- inc/framework/generator/ge_generator.h | 1 + 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index e2426682..2ff0c327 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -154,7 +154,7 @@ static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_ty } static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index, - bool attr) { + bool attr, int32_t &data_index) { GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID); @@ -197,9 +197,10 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const "[Add][InputDesc]fail for node:%s", data_op->GetName().c_str()); GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str()); - if (attr) { - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, index), return FAILED, + if (attr && !is_const) { + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index), return FAILED, "[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str()); + ++data_index; } ge::NodePtr arg_node = graph->AddNode(data_op); @@ -709,6 +710,17 @@ bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) { return true; } +void GeGenerator::RemoveConst(const vector &inputs, vector &outputs) { + for (auto input : inputs) { + GeTensorDesc input_desc = input.GetTensorDesc(); + bool is_const = false; + (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const); + if (!is_const) { + outputs.emplace_back(input); + } + } +} + Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs) { GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); @@ -773,7 +785,9 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in GELOGI("ATC parser success in single op build."); GeRootModelPtr ge_root_model = nullptr; - GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, ge_root_model)); + vector data_inputs; + RemoveConst(inputs, data_inputs); + GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model)); map op_attrs = op_desc_tmp->GetAllAttrs(); GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); @@ -850,25 +864,25 @@ Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector(graph_name); GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); - // 1. Add Node to ComputeGraph. NodePtr op_node = compute_graph->AddNode(op_desc); GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); // 2. Create InputData node. int32_t arg_index = 0; + int32_t data_index = 0; if (inputs.empty()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { continue; } - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index)); arg_index++; } } else { for (const auto &in_desc : inputs) { - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true)); + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index)); arg_index++; } } diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index 4b8caa95..505c7146 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -99,6 +99,7 @@ class GE_FUNC_VISIBILITY GeGenerator { const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline = true); bool CheckNoAicore(const ComputeGraphPtr &graph); + void RemoveConst(const vector &inputs, vector &outputs); Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); using GeRootModelPtr = std::shared_ptr;