Browse Source

Feature: reset shape of dynamic single op

pull/565/head
l00444296 4 years ago
parent
commit
b0f6c5094f
2 changed files with 56 additions and 2 deletions
  1. +51
    -2
      ge/generator/ge_generator.cc
  2. +5
    -0
      ge/graph/passes/dynamic_single_op_reset_shape_pass.cc

+ 51
- 2
ge/generator/ge_generator.cc View File

@@ -47,6 +47,8 @@ const char *const kEngineNameDefault = "default";
const char *const kVectorEngine = "VectorEngine";
const char *const kAIcoreEngine = "AIcoreEngine";
const char *const kFileNameSuffix = "online";
const int kDynamicDimSize = 1;
const int64_t kDynamicDimValue = -2;

std::map<ge::OpEngineType, std::string> engine_type_map{
{ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}};
@@ -231,6 +233,43 @@ static void GetOpsProtoPath(string &opsproto_path) {
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
}

static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag) {
GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
change_shape_flag = false;
for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) {
auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
GE_CHECK_NOTNULL(input_desc);
// pass scalar input desc
auto dims = input_desc->GetShape().GetDims();
if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) {
change_shape_flag = true;
}
}
return SUCCESS;
}

static void ResetInputShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
for (auto input : inputs) {
auto input_desc = input.GetTensorDesc();
GeShape shape_ori = input_desc.GetShape();
Format format_ori = input_desc.GetFormat();
DataType type_ori = input_desc.GetDataType();

std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
GeShape dynamic_shape(dynamic_shape_dims);

ge::GeTensor inputTensor;
if (shape_ori.GetDims().size() == 0) {
ge::GeTensorDesc desc(shape_ori, format_ori, type_ori);
} else {
ge::GeTensorDesc desc(dynamic_shape, format_ori, type_ori);
}

inputTensor.SetTensorDesc(desc);
inputs_dynamic.push_back(inputTensor);
}
}

class GeGenerator::Impl {
public:
Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
@@ -557,7 +596,9 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor>
Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline) {

if (is_offline) {
(void)AttrUtils::SetBool(data_op, ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, true);
}
if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
return PARAM_INVALID;
@@ -634,7 +675,15 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
}
GeModelPtr &ge_model = name_to_ge_model.begin()->second;
GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));

bool dynamic_flag = false;
if (CheckShapeReset(op_desc, dynamic_flag) == SUCCESS && dynamic_flag) {
vector<GeTensor> inputs_dynamic;
ResetInputShape(inputs, inputs_dynamic);
GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs));
} else {
GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
}
GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
return SUCCESS;
}


+ 5
- 0
ge/graph/passes/dynamic_single_op_reset_shape_pass.cc View File

@@ -71,6 +71,11 @@ Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) {
for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) {
auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
GE_CHECK_NOTNULL(input_desc);
// pass scalar input desc
auto dims_ori = input_desc->GetShape().GetDims();
if (dims_ori.size() == 0) {
continue;
}
input_desc->SetShape(dynamic_shape);
}
GELOGD("Reset dynamic aicpu node [%s] shape success!", node->GetName().c_str());


Loading…
Cancel
Save