From b0f6c5094f7bf3801b9fee98f3ec378dbe4f36e8 Mon Sep 17 00:00:00 2001 From: l00444296 Date: Thu, 10 Dec 2020 14:39:48 +0800 Subject: [PATCH] Feature: reset shape of dynamic single op --- ge/generator/ge_generator.cc | 53 +++++++++++++++++++++- .../passes/dynamic_single_op_reset_shape_pass.cc | 5 ++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index dc64aac1..614b69ef 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -47,6 +47,8 @@ const char *const kEngineNameDefault = "default"; const char *const kVectorEngine = "VectorEngine"; const char *const kAIcoreEngine = "AIcoreEngine"; const char *const kFileNameSuffix = "online"; +const int kDynamicDimSize = 1; +const int64_t kDynamicDimValue = -2; std::map engine_type_map{ {ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}}; @@ -231,6 +233,43 @@ static void GetOpsProtoPath(string &opsproto_path) { opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } +static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag) { + GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); + change_shape_flag = false; + for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { + auto input_desc = op_desc->MutableInputDesc(static_cast(i)); + GE_CHECK_NOTNULL(input_desc); + // pass scalar input desc + auto dims = input_desc->GetShape().GetDims(); + if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) { + change_shape_flag = true; + } + } + return SUCCESS; +} + +static void ResetInputShape(const vector &inputs, vector &inputs_dynamic) { + for (auto input : inputs) { + auto input_desc = input.GetTensorDesc(); + GeShape shape_ori = input_desc.GetShape(); + Format format_ori = input_desc.GetFormat(); + DataType type_ori = input_desc.GetDataType(); + + std::vector dynamic_shape_dims = {kDynamicDimValue}; + GeShape dynamic_shape(dynamic_shape_dims); + + ge::GeTensor inputTensor; + if (shape_ori.GetDims().size() == 0) { + ge::GeTensorDesc desc(shape_ori, format_ori, type_ori); + } else { + ge::GeTensorDesc desc(dynamic_shape, format_ori, type_ori); + } + + inputTensor.SetTensorDesc(desc); + inputs_dynamic.push_back(inputTensor); + } +} + class GeGenerator::Impl { public: Impl(OmgContext &omg_context) : omg_context_(omg_context) {} @@ -557,7 +596,9 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline) { - + if (is_offline) { + (void)AttrUtils::SetBool(data_op, ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, true); + } if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) { GELOGE(PARAM_INVALID, "input param is invalid when build single op!"); return PARAM_INVALID; @@ -634,7 +675,15 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in } GeModelPtr &ge_model = name_to_ge_model.begin()->second; GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str()); - GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs)); + + bool dynamic_flag = false; + if (CheckShapeReset(op_desc, dynamic_flag) == SUCCESS && dynamic_flag) { + vector inputs_dynamic; + ResetInputShape(inputs, inputs_dynamic); + GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs)); + } else { + GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs)); + } GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff)); return SUCCESS; } diff --git a/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc index 1d1d3add..e1384571 100644 --- a/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc +++ b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc @@ -71,6 +71,11 @@ Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) { for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { auto input_desc = op_desc->MutableInputDesc(static_cast(i)); GE_CHECK_NOTNULL(input_desc); + // pass scalar input desc + auto dims_ori = input_desc->GetShape().GetDims(); + if (dims_ori.size() == 0) { + continue; + } input_desc->SetShape(dynamic_shape); } GELOGD("Reset dynamic aicpu node [%s] shape success!", node->GetName().c_str());