diff --git a/metadef b/metadef index 5d062a3..f5c1b6d 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 5d062a35640733026457c91966a558769570b0f8 +Subproject commit f5c1b6d1b6b6e97d0cfcf2efd52ec8da12d32c86 diff --git a/parser/CMakeLists.txt b/parser/CMakeLists.txt index e203166..ce8a6e0 100644 --- a/parser/CMakeLists.txt +++ b/parser/CMakeLists.txt @@ -22,18 +22,18 @@ set(SRC_LIST "caffe/caffe_custom_parser_adapter.cc" "caffe/caffe_op_parser.cc" "tensorflow/scope/scope_pass_manager.cc" - "tensorflow/graph_functiondef.cc" - "tensorflow/graph_optimizer.cc" + "tensorflow/graph_to_function_def.cc" + "tensorflow/parser_graph_optimizer.cc" "tensorflow/iterator_fusion_pass.cc" - "common/op_def/arg_op.cc" - "common/op_def/constant_op.cc" - "common/op_def/fill_op.cc" - "common/op_def/frameworkop_op.cc" - "common/op_def/no_op_op.cc" - "common/op_def/ref_switch_op.cc" - "common/op_def/shape_n_op.cc" - "common/op_def/var_is_initialized_op_op.cc" - "common/op_def/variable_op.cc" + "common/op_def/arg_op_operator.cc" + "common/op_def/constant_operator.cc" + "common/op_def/fill_operator.cc" + "common/op_def/framework_op_operator.cc" + "common/op_def/no_op_operator.cc" + "common/op_def/ref_switch_operator.cc" + "common/op_def/shape_n_operator.cc" + "common/op_def/var_is_initialized_op_operator.cc" + "common/op_def/variable_operator.cc" ) ############ libfmk_parser.so ############ diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 6e28a8e..0ea52ce 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param"; const char *const kFieldDim = "dim"; const char *const kFieldBiasTerm = "bias_term"; const char *const kDevNull = "/dev/null"; -const std::string kMessage = "message"; -const std::string kLayerParameter = "LayerParameter"; -const std::string kCloseBrace = "}"; -const std::string kOptional = "optional"; -const std::string kRepeated = "repeated"; -const std::string kRequired = "required"; -const std::string kCustom = "custom"; -const std::string kBuiltin = "built-in"; +const char *const kCustom = "custom"; +const char *const kBuiltin = "built-in"; std::vector kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, ge::parser::NETOUTPUT}; const std::set kCustomProtoLayerCommonField = {"name", "type"}; @@ -284,104 +278,104 @@ const set CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW "Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { - if (proto_message.input_size() > 0) { - GELOGI("This net exsit input."); + if (proto_message.input_size() <= 0) { + return SUCCESS; + } + GELOGI("This net exsit input."); + if (proto_message.input_dim_size() > 0) { + if (proto_message.input_shape_size() > 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E11001"); + GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); + return FAILED; + } - if (proto_message.input_dim_size() > 0) { - if (proto_message.input_shape_size() > 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E11001"); - GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); - return FAILED; - } + const int32_t input_dim_size = proto_message.input_dim_size(); + const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || + ((input_dim_size % proto_message.input_size()) != 0)); + if (is_input_invalid) { + ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, + {std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); + GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", + input_dim_size, proto_message.input_size()); + return FAILED; + } - const int32_t input_dim_size = proto_message.input_dim_size(); - const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || - ((input_dim_size % proto_message.input_size()) != 0)); - if (is_input_invalid) { - ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, - {std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); - GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", - input_dim_size, proto_message.input_size()); - return FAILED; - } + for (int i = 0; i < proto_message.input_size(); i++) { + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(proto_message.input(i)); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); - for (int i = 0; i < proto_message.input_size(); i++) { - domi::caffe::LayerParameter *layer = proto_message.add_layer(); - GE_CHECK_NOTNULL(layer); - layer->set_name(proto_message.input(i)); - layer->set_type(ge::parser::INPUT_TYPE); - layer->add_top(proto_message.input(i)); - - domi::caffe::InputParameter *input_param = layer->mutable_input_param(); - GE_CHECK_NOTNULL(input_param); - domi::caffe::BlobShape *shape = input_param->add_shape(); - GE_CHECK_NOTNULL(shape); - - for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { - // Can guarantee that it will not cross the border - shape->add_dim(static_cast(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); - } - input_data_flag = true; - } - } else if (proto_message.input_shape_size() > 0) { - if (proto_message.input_shape_size() != proto_message.input_size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, - {std::to_string(proto_message.input_shape_size()), - std::to_string(proto_message.input_size())}); - GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", - proto_message.input_shape_size(), proto_message.input_size()); - return FAILED; + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); + + for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { + // Can guarantee that it will not cross the border + shape->add_dim(static_cast(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); } + input_data_flag = true; + } + } else if (proto_message.input_shape_size() > 0) { + if (proto_message.input_shape_size() != proto_message.input_size()) { + ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, + {std::to_string(proto_message.input_shape_size()), + std::to_string(proto_message.input_size())}); + GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", + proto_message.input_shape_size(), proto_message.input_size()); + return FAILED; + } - for (int i = 0; i < proto_message.input_size(); i++) { - int dim_size = proto_message.input_shape(i).dim_size(); + for (int i = 0; i < proto_message.input_size(); i++) { + int dim_size = proto_message.input_shape(i).dim_size(); - domi::caffe::LayerParameter *layer = proto_message.add_layer(); - GE_CHECK_NOTNULL(layer); - layer->set_name(proto_message.input(i)); - layer->set_type(ge::parser::INPUT_TYPE); - layer->add_top(proto_message.input(i)); + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(proto_message.input(i)); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); - domi::caffe::InputParameter *input_param = layer->mutable_input_param(); - GE_CHECK_NOTNULL(input_param); - domi::caffe::BlobShape *shape = input_param->add_shape(); - GE_CHECK_NOTNULL(shape); + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); - for (int j = 0; j < dim_size; j++) { - // Can guarantee that it will not cross the border - shape->add_dim(static_cast(proto_message.input_shape(i).dim(j))); - } - input_data_flag = true; + for (int j = 0; j < dim_size; j++) { + // Can guarantee that it will not cross the border + shape->add_dim(static_cast(proto_message.input_shape(i).dim(j))); } - } else { - const ge::ParserContext &ctx = ge::GetParserContext(); - std::map> input_dims = ctx.input_dims; - for (int i = 0; i < proto_message.input_size(); i++) { - string name = proto_message.input(i); - if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input - REPORT_INPUT_ERROR("E11005", std::vector({"input"}), std::vector({name})); - GELOGE(FAILED, "[Find][Dim]Model has no input shape."); - return FAILED; - } - std::vector dims = input_dims.at(name); - size_t dim_size = dims.size(); - - domi::caffe::LayerParameter *layer = proto_message.add_layer(); - GE_CHECK_NOTNULL(layer); - layer->set_name(name); - layer->set_type(ge::parser::INPUT_TYPE); - layer->add_top(proto_message.input(i)); - - domi::caffe::InputParameter *input_param = layer->mutable_input_param(); - GE_CHECK_NOTNULL(input_param); - domi::caffe::BlobShape *shape = input_param->add_shape(); - GE_CHECK_NOTNULL(shape); - - for (size_t j = 0; j < dim_size; j++) { - shape->add_dim(dims.at(j)); - } - input_data_flag = true; + input_data_flag = true; + } + } else { + const ge::ParserContext &ctx = ge::GetParserContext(); + std::map> input_dims = ctx.input_dims; + for (int i = 0; i < proto_message.input_size(); i++) { + string name = proto_message.input(i); + if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input + REPORT_INPUT_ERROR("E11005", std::vector({"input"}), std::vector({name})); + GELOGE(FAILED, "[Find][Dim]Model has no input shape."); + return FAILED; + } + std::vector dims = input_dims.at(name); + size_t dim_size = dims.size(); + + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(name); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); + + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); + + for (size_t j = 0; j < dim_size; j++) { + shape->add_dim(dims.at(j)); } + input_data_flag = true; } } return SUCCESS; @@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons return FAILED; } - if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) { + if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) { delete message; GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); return FAILED; @@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google:: return SUCCESS; } -Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, +Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, vector &operators) const { - auto field_name = layer_descriptor->FindFieldByName(kFieldName); + auto field_name = layer_descriptor.FindFieldByName(kFieldName); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); - auto field_type = layer_descriptor->FindFieldByName(kFieldType); + auto field_type = layer_descriptor.FindFieldByName(kFieldType); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); for (auto &field : field_desc) { CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); // Only care about layers @@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor return FAILED; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); GELOGI("Total Layer num of model file is %d", field_size); for (int i = 0; i < field_size; ++i) { - const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); + const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); GE_CHECK_NOTNULL(layer_reflection); @@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co layer_name_map[layer.name()]++; // Set the name in proto and layer domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); - duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) + duplicate_name_layer->set_name(new_name); + layer.set_name(new_name);) // Insert the new operator name, the number of times of duplicate name is recorded as 1 layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); @@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap layer_name_map[layer.name()]++; // Set the name in proto and layer domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); - duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) + duplicate_name_layer->set_name(new_name); + layer.set_name(new_name);) // Insert the new operator name, the number of times of duplicate name is recorded as 1 layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); @@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con return FAILED; } - if (CheckLayersSize(message) != SUCCESS) { + if (CheckLayersSize(*message) != SUCCESS) { delete message; message = nullptr; return FAILED; } - if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) { + if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) { delete message; message = nullptr; REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); @@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con return SUCCESS; } -Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, +Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, ge::ComputeGraphPtr &graph) { - auto field_name = layer_descriptor->FindFieldByName(kFieldName); + auto field_name = layer_descriptor.FindFieldByName(kFieldName); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); - auto field_type = layer_descriptor->FindFieldByName(kFieldType); + auto field_type = layer_descriptor.FindFieldByName(kFieldType); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); NetParameter tmp_net; for (auto &field : field_desc) { @@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto return FAILED; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); GELOGI("Total Layer num of model file is %d", field_size); for (int i = 0; i < field_size; ++i) { - const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); + const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); LayerParameter *layer = tmp_net.add_layer(); - if (ConvertLayerProto(&layer_message, layer) != SUCCESS) { + if (ConvertLayerProto(layer_message, layer) != SUCCESS) { GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); return FAILED; } @@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto return SUCCESS; } -Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, google::protobuf::Message *layer) { - const google::protobuf::Reflection *layer_reflection = message->GetReflection(); + const google::protobuf::Reflection *layer_reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - layer_reflection->ListFields(*message, &field_desc); + layer_reflection->ListFields(message, &field_desc); for (auto &field : field_desc) { GE_CHECK_NOTNULL(field); - if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) { + if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) { GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); return FAILED; } @@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me return SUCCESS; } -Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, +Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection, + const google::protobuf::Message &message, + const google::protobuf::FieldDescriptor &field, google::protobuf::Message *layer) const { - GELOGD("Start to parse field: %s.", field->name().c_str()); + GELOGD("Start to parse field: %s.", field.name().c_str()); domi::caffe::LayerParameter *layer_proto = PtrToPtr(layer); - string filed_name = field->name(); -#define CASE_FIELD_NAME(kName, method) \ + string filed_name = field.name(); +#define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \ if (filed_name == kField##kName) { \ - string value = reflection->GetString(*message, field); \ + string value = reflection.GetString(inner_message, field_ptr); \ GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ layer_proto->set_##method(value); \ return SUCCESS; \ } - CASE_FIELD_NAME(Name, name); - CASE_FIELD_NAME(Type, type); + CASE_FIELD_NAME(Name, name, message, &field); + CASE_FIELD_NAME(Type, type, message, &field); #undef CASE_FIELD_NAME -#define CASE_FIELD_NAME_REPEATED(kName, method) \ - if (filed_name == kField##kName) { \ - int field_size = reflection->FieldSize(*message, field); \ - for (int i = 0; i < field_size; ++i) { \ - auto value = reflection->GetRepeatedString(*message, field, i); \ - layer_proto->add_##method(value); \ - } \ - return SUCCESS; \ - } - CASE_FIELD_NAME_REPEATED(Bottom, bottom); - CASE_FIELD_NAME_REPEATED(Top, top); +#define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \ + if (filed_name == kField##kName) { \ + int field_size = reflection.FieldSize(inner_message, field_ptr); \ + for (int i = 0; i < field_size; ++i) { \ + auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \ + layer_proto->add_##method(value); \ + } \ + return SUCCESS; \ + } + CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field); + CASE_FIELD_NAME_REPEATED(Top, top, message, &field); #undef CASE_FIELD_NAME_REPEATED if (filed_name == kFieldBlobs) { - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection.FieldSize(message, &field); for (int i = 0; i < field_size; ++i) { domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); - const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); - if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { - GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); + const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i); + if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) { + GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str()); return FAILED; } } return SUCCESS; } if (filed_name == kFieldConvParam) { - const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); + const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); - ConvertConvParamProto(&sub_message, conv_param); + ConvertConvParamProto(sub_message, conv_param); } if (filed_name == kFieldInnerPro) { - const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); + const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); - ConvertInnerProdcutProto(&sub_message, inner_product); + ConvertInnerProdcutProto(sub_message, inner_product); } return SUCCESS; } -Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message, google::protobuf::Message *blobs) const { - const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); + const google::protobuf::Reflection *blobs_reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - blobs_reflection->ListFields(*message, &field_desc); + blobs_reflection->ListFields(message, &field_desc); domi::caffe::BlobProto *blobs_proto = PtrToPtr(blobs); for (auto &field : field_desc) { GE_CHECK_NOTNULL(field); string feild_name = field->name(); -#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \ - if (feild_name == #kName) { \ - int field_size = blobs_reflection->FieldSize(*message, field); \ - for (int i = 0; i < field_size; ++i) { \ - valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \ - blobs_proto->add_##name(value); \ - } \ - continue; \ - } - CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data); - CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff); - CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data); - CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff); - CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data); - CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data); +#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \ + if (feild_name == #kName) { \ + int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \ + for (int i = 0; i < field_size; ++i) { \ + valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \ + blobs_proto->add_##name(value); \ + } \ + continue; \ + } + CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field); #undef CASE_BLOBS_FIELD_NAME_REPEATED -#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \ - if (feild_name == #kName) { \ - valuetype value = blobs_reflection->Get##method(*message, field); \ - blobs_proto->set_##name(value); \ - continue; \ - } - CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data); - CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num); - CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels); - CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height); - CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width); +#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \ + if (feild_name == #kName) { \ + valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \ + blobs_proto->set_##name(value); \ + continue; \ + } + CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field); + CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field); + CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field); + CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field); + CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field); #undef CASE_BLOBS_FIELD_NAME if (feild_name == kFieldShape) { - const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field); + const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field); domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); - ConvertBlobShapeProto(&sub_message, blob_shape); + ConvertBlobShapeProto(sub_message, blob_shape); } } return SUCCESS; } -Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); domi::caffe::BlobShape *shape_proto = PtrToPtr(dest_message); @@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message if (field->name() != kFieldDim) { continue; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); for (int i = 0; i < field_size; ++i) { - int64_t value = reflection->GetRepeatedInt64(*message, field, i); + int64_t value = reflection->GetRepeatedInt64(message, field, i); shape_proto->add_dim(value); } } return SUCCESS; } -Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); domi::caffe::ConvolutionParameter *conv_param_proto = PtrToPtr(dest_message); @@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message if (field->name() != kFieldBiasTerm) { continue; } - bool value = reflection->GetBool(*message, field); + bool value = reflection->GetBool(message, field); conv_param_proto->set_bias_term(value); } return SUCCESS; } -Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); domi::caffe::InnerProductParameter *inner_product_proto = PtrToPtr(dest_message); @@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess if (field->name() != kFieldBiasTerm) { continue; } - bool value = reflection->GetBool(*message, field); + bool value = reflection->GetBool(message, field); inner_product_proto->set_bias_term(value); } return SUCCESS; } -Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); +Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const { + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); int num_layer = 0; int num_layers = 0; @@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess return FAILED; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); if (field->name() == kLayerName) { num_layer = field_size; } else { diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h index 542aca6..08fed80 100644 --- a/parser/caffe/caffe_parser.h +++ b/parser/caffe/caffe_parser.h @@ -212,8 +212,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { * @return SUCCESS parse layer successfully * @return FAILED parse layer failed */ - Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, std::vector &operators) const; + Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, std::vector &operators) const; /* * @ingroup domi_omg @@ -386,33 +386,33 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, const string &fusion_proto_name, ge::ComputeGraphPtr &graph); - Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, + Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, ge::ComputeGraphPtr &graph); Status ConvertLayerParameter(const google::protobuf::Message *layer_message, ge::ComputeGraphPtr &graph); - Status CheckLayersSize(const google::protobuf::Message *message) const; + Status CheckLayersSize(const google::protobuf::Message &message) const; - Status ConvertLayerProto(const google::protobuf::Message *message, + Status ConvertLayerProto(const google::protobuf::Message &message, google::protobuf::Message *layer); - Status ParseLayerField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, + Status ParseLayerField(const google::protobuf::Reflection &reflection, + const google::protobuf::Message &message, + const google::protobuf::FieldDescriptor &field, google::protobuf::Message *layer) const; - Status ConvertBlobsProto(const google::protobuf::Message *message, + Status ConvertBlobsProto(const google::protobuf::Message &message, google::protobuf::Message *blobs) const; - Status ConvertBlobShapeProto(const google::protobuf::Message *message, + Status ConvertBlobShapeProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const; - Status ConvertInnerProdcutProto(const google::protobuf::Message *message, + Status ConvertInnerProdcutProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const; - Status ConvertConvParamProto(const google::protobuf::Message *message, + Status ConvertConvParamProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const; /** * @ingroup domi_omg diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index f29904e..8a4b261 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -431,6 +431,41 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra return SUCCESS; } +domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, + const std::string &input_data_names) const { + std::vector input_names = StringUtils::Split(input_data_names, ','); + std::unordered_map name_to_index; + for (auto &input_name : input_names) { + if (!name_to_index.emplace(input_name, name_to_index.size()).second) { + GELOGE(PARAM_INVALID, "[Check][Param] Duplicate input name[%s].", input_name.c_str()); + return FAILED; + } + } + + for (const NodePtr &node : graph->GetDirectNode()) { + if (node->GetType() != ge::parser::DATA) { + continue; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto iter = name_to_index.find(node->GetName()); + if (iter== name_to_index.cend()) { + GELOGE(PARAM_INVALID, "[Check][Param] Input name[%s] is not in input_data_names", + node->GetName().c_str()); + return FAILED; + } + GELOGI("[SetSpecifyIndexAttr] set node(%s) index attr, index is %ld", + op_desc->GetName().c_str(), iter->second); + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, iter->second)) { + REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s", + ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); + GELOGE(FAILED, "set attr %s failed for node:%s", ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str()); + return FAILED; + } + } + return SUCCESS; +} + void AclGrphParseUtil::CreateOutputNodesInfo(std::vector> &output_nodes_info, std::vector &output_nodes_name) const { output_nodes_name.clear(); @@ -670,6 +705,16 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, return PARAM_INVALID; } + string input_data_names; + GetAclParams(parser_params, ge::ir_option::INPUT_DATA_NAMES, input_data_names); + if (!input_data_names.empty()) { + if (SetSpecifyIndexAttrByInputNames(compute_graph, input_data_names) != SUCCESS) { + GELOGE(FAILED, "[Invoke][SetIndexAttr] set index attr failed, graph:%s", + compute_graph->GetName().c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; } diff --git a/parser/common/acl_graph_parser_util.h b/parser/common/acl_graph_parser_util.h index 8af1d27..4ff649f 100644 --- a/parser/common/acl_graph_parser_util.h +++ b/parser/common/acl_graph_parser_util.h @@ -61,6 +61,7 @@ class AclGrphParseUtil { size_t index, OpDescPtr &op_desc); domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, const string &is_input_adjust_hw_layout) const; + domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const; domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, std::vector> &output_nodes_info) const; }; diff --git a/parser/common/convert/message2operator.cc b/parser/common/convert/message2operator.cc index d58762f..d235ac2 100644 --- a/parser/common/convert/message2operator.cc +++ b/parser/common/convert/message2operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/convert/message2operator.h b/parser/common/convert/message2operator.h index b8610e9..b247112 100644 --- a/parser/common/convert/message2operator.h +++ b/parser/common/convert/message2operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/convert/pb2json.cc b/parser/common/convert/pb2json.cc index 66097b8..0b3dd2b 100644 --- a/parser/common/convert/pb2json.cc +++ b/parser/common/convert/pb2json.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -82,7 +82,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr switch (field->type()) { case ProtobufFieldDescriptor::TYPE_MESSAGE: { const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); - if (0UL != tmp_message.ByteSizeLong()) { + if (tmp_message.ByteSizeLong() != 0UL) { Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1); } break; @@ -122,7 +122,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr case ProtobufFieldDescriptor::TYPE_FLOAT: char str[kSignificantDigits]; - if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){ + if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { json[field->name()] = str; } else { json[field->name()] = reflection->GetFloat(message, field); @@ -155,10 +155,8 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { } string result = ""; for (char temp_value : type_bytes) { - uint8_t *value = 0; - value = reinterpret_cast(&temp_value); char str[kSignificantDigits]; - if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){ + if (sprintf_s(str, kSignificantDigits, "%c", temp_value) == -1) { GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); continue; } diff --git a/parser/common/convert/pb2json.h b/parser/common/convert/pb2json.h index 28e796d..9e48d06 100644 --- a/parser/common/convert/pb2json.h +++ b/parser/common/convert/pb2json.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/arg_op.cc b/parser/common/op_def/arg_op_operator.cc similarity index 89% rename from parser/common/op_def/arg_op.cc rename to parser/common/op_def/arg_op_operator.cc index 2eb0ff6..5e06525 100644 --- a/parser/common/op_def/arg_op.cc +++ b/parser/common/op_def/arg_op_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "parser/common/op_def/arg_op.h" +#include "parser/common/op_def/arg_op_operator.h" #include #include "framework/common/fmk_types.h" diff --git a/parser/common/op_def/arg_op.h b/parser/common/op_def/arg_op_operator.h similarity index 89% rename from parser/common/op_def/arg_op.h rename to parser/common/op_def/arg_op_operator.h index 7258422..3746471 100644 --- a/parser/common/op_def/arg_op.h +++ b/parser/common/op_def/arg_op_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/constant_op.cc b/parser/common/op_def/constant_operator.cc similarity index 91% rename from parser/common/op_def/constant_op.cc rename to parser/common/op_def/constant_operator.cc index ce9d249..db17752 100644 --- a/parser/common/op_def/constant_op.cc +++ b/parser/common/op_def/constant_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "common/op_def/constant_op.h" +#include "common/op_def/constant_operator.h" #include #include diff --git a/parser/common/op_def/constant_op.h b/parser/common/op_def/constant_operator.h similarity index 93% rename from parser/common/op_def/constant_op.h rename to parser/common/op_def/constant_operator.h index e329ac1..6f809b3 100644 --- a/parser/common/op_def/constant_op.h +++ b/parser/common/op_def/constant_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/fill_op.cc b/parser/common/op_def/fill_operator.cc similarity index 92% rename from parser/common/op_def/fill_op.cc rename to parser/common/op_def/fill_operator.cc index 2228d26..9b2ee5d 100644 --- a/parser/common/op_def/fill_op.cc +++ b/parser/common/op_def/fill_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "common/op_def/fill_op.h" +#include "common/op_def/fill_operator.h" #include "framework/common/fmk_types.h" namespace ge { diff --git a/parser/common/op_def/fill_op.h b/parser/common/op_def/fill_operator.h similarity index 93% rename from parser/common/op_def/fill_op.h rename to parser/common/op_def/fill_operator.h index 4040b49..b556067 100644 --- a/parser/common/op_def/fill_op.h +++ b/parser/common/op_def/fill_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/frameworkop_op.cc b/parser/common/op_def/framework_op_operator.cc similarity index 95% rename from parser/common/op_def/frameworkop_op.cc rename to parser/common/op_def/framework_op_operator.cc index 7810147..305fac0 100644 --- a/parser/common/op_def/frameworkop_op.cc +++ b/parser/common/op_def/framework_op_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "common/op_def/frameworkop_op.h" +#include "common/op_def/framework_op_operator.h" #include #include "framework/common/fmk_types.h" diff --git a/parser/common/op_def/frameworkop_op.h b/parser/common/op_def/framework_op_operator.h similarity index 94% rename from parser/common/op_def/frameworkop_op.h rename to parser/common/op_def/framework_op_operator.h index 9a1a54c..eafc1a5 100644 --- a/parser/common/op_def/frameworkop_op.h +++ b/parser/common/op_def/framework_op_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/ir_pb_converter.cc b/parser/common/op_def/ir_pb_converter.cc index 9dd7f64..87c56ac 100644 --- a/parser/common/op_def/ir_pb_converter.cc +++ b/parser/common/op_def/ir_pb_converter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/ir_pb_converter.h b/parser/common/op_def/ir_pb_converter.h index bbc8e03..fea9dd2 100644 --- a/parser/common/op_def/ir_pb_converter.h +++ b/parser/common/op_def/ir_pb_converter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/no_op_op.cc b/parser/common/op_def/no_op_operator.cc similarity index 87% rename from parser/common/op_def/no_op_op.cc rename to parser/common/op_def/no_op_operator.cc index d9706ef..d630beb 100644 --- a/parser/common/op_def/no_op_op.cc +++ b/parser/common/op_def/no_op_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ // AUTO GEN PLEASE DO NOT MODIFY IT -#include "common/op_def/no_op_op.h" +#include "common/op_def/no_op_operator.h" #include namespace ge { diff --git a/parser/common/op_def/no_op_op.h b/parser/common/op_def/no_op_operator.h similarity index 90% rename from parser/common/op_def/no_op_op.h rename to parser/common/op_def/no_op_operator.h index f3bfea8..56338c2 100644 --- a/parser/common/op_def/no_op_op.h +++ b/parser/common/op_def/no_op_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ #ifndef DOMI_OP_NO_OP_OP_H_ #define DOMI_OP_NO_OP_OP_H_ #include "parser/common/op_def/operator.h" -#include "framework/omg/parser/parser_types.h" namespace ge { class NoOpOperator : public ParserOperator { diff --git a/parser/common/op_def/operator.cc b/parser/common/op_def/operator.cc index 8db5a75..15296e9 100644 --- a/parser/common/op_def/operator.cc +++ b/parser/common/op_def/operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/operator.h b/parser/common/op_def/operator.h index a97580f..28fce69 100644 --- a/parser/common/op_def/operator.h +++ b/parser/common/op_def/operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/ref_switch_op.cc b/parser/common/op_def/ref_switch_operator.cc similarity index 89% rename from parser/common/op_def/ref_switch_op.cc rename to parser/common/op_def/ref_switch_operator.cc index 5331676..cc02f81 100644 --- a/parser/common/op_def/ref_switch_op.cc +++ b/parser/common/op_def/ref_switch_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ // AUTO GEN PLEASE DO NOT MODIFY IT -#include "common/op_def/ref_switch_op.h" +#include "common/op_def/ref_switch_operator.h" namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOperator() : ParserOperator("RefSwitch") {} diff --git a/parser/common/op_def/ref_switch_op.h b/parser/common/op_def/ref_switch_operator.h similarity index 93% rename from parser/common/op_def/ref_switch_op.h rename to parser/common/op_def/ref_switch_operator.h index 6a09bea..becbf28 100644 --- a/parser/common/op_def/ref_switch_op.h +++ b/parser/common/op_def/ref_switch_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/shape_n_op.cc b/parser/common/op_def/shape_n_operator.cc similarity index 93% rename from parser/common/op_def/shape_n_op.cc rename to parser/common/op_def/shape_n_operator.cc index d5d64dc..df2a7d9 100644 --- a/parser/common/op_def/shape_n_op.cc +++ b/parser/common/op_def/shape_n_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ // AUTO GEN PLEASE DO NOT MODIFY IT -#include "common/op_def/shape_n_op.h" +#include "common/op_def/shape_n_operator.h" #include "graph/debug/ge_attr_define.h" #include "framework/omg/parser/parser_types.h" diff --git a/parser/common/op_def/shape_n_op.h b/parser/common/op_def/shape_n_operator.h similarity index 94% rename from parser/common/op_def/shape_n_op.h rename to parser/common/op_def/shape_n_operator.h index e60c70c..ce9aa2e 100644 --- a/parser/common/op_def/shape_n_op.h +++ b/parser/common/op_def/shape_n_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/var_is_initialized_op_op.cc b/parser/common/op_def/var_is_initialized_op_operator.cc similarity index 89% rename from parser/common/op_def/var_is_initialized_op_op.cc rename to parser/common/op_def/var_is_initialized_op_operator.cc index ad3013d..81f7f4e 100644 --- a/parser/common/op_def/var_is_initialized_op_op.cc +++ b/parser/common/op_def/var_is_initialized_op_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ // AUTO GEN PLEASE DO NOT MODIFY IT -#include "common/op_def/var_is_initialized_op_op.h" +#include "common/op_def/var_is_initialized_op_operator.h" #include #include diff --git a/parser/common/op_def/var_is_initialized_op_op.h b/parser/common/op_def/var_is_initialized_op_operator.h similarity index 93% rename from parser/common/op_def/var_is_initialized_op_op.h rename to parser/common/op_def/var_is_initialized_op_operator.h index 1c8ac5e..a0c2a32 100644 --- a/parser/common/op_def/var_is_initialized_op_op.h +++ b/parser/common/op_def/var_is_initialized_op_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/common/op_def/variable_op.cc b/parser/common/op_def/variable_operator.cc similarity index 92% rename from parser/common/op_def/variable_op.cc rename to parser/common/op_def/variable_operator.cc index 5be9600..8d200ec 100644 --- a/parser/common/op_def/variable_op.cc +++ b/parser/common/op_def/variable_operator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "parser/common/op_def/variable_op.h" +#include "parser/common/op_def/variable_operator.h" #include "graph/debug/ge_attr_define.h" diff --git a/parser/common/op_def/variable_op.h b/parser/common/op_def/variable_operator.h similarity index 94% rename from parser/common/op_def/variable_op.h rename to parser/common/op_def/variable_operator.h index 3326657..76d60a1 100644 --- a/parser/common/op_def/variable_op.h +++ b/parser/common/op_def/variable_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/module.mk b/parser/module.mk index 678179c..524b6a9 100644 --- a/parser/module.mk +++ b/parser/module.mk @@ -92,18 +92,18 @@ PARSER_SCOPE_SRC_FILES := \ tensorflow/scope/scope_pass_manager.cc \ FMK_COMMON_SRC_FILES := \ - tensorflow/graph_functiondef.cc \ - tensorflow/graph_optimizer.cc \ + tensorflow/graph_to_function_def.cc \ + tensorflow/parser_graph_optimizer.cc \ tensorflow/iterator_fusion_pass.cc \ - common/op_def/arg_op.cc \ - common/op_def/constant_op.cc \ - common/op_def/fill_op.cc \ - common/op_def/frameworkop_op.cc \ - common/op_def/no_op_op.cc \ - common/op_def/ref_switch_op.cc \ - common/op_def/shape_n_op.cc \ - common/op_def/var_is_initialized_op_op.cc \ - common/op_def/variable_op.cc \ + common/op_def/arg_op_operator.cc \ + common/op_def/constant_operator.cc \ + common/op_def/fill_operator.cc \ + common/op_def/framework_op_operator.cc \ + common/op_def/no_op_operator.cc \ + common/op_def/ref_switch_operator.cc \ + common/op_def/shape_n_operator.cc \ + common/op_def/var_is_initialized_op_operator.cc \ + common/op_def/variable_operator.cc \ LOCAL_SRC_FILES := $(PARSER_TENSORFLOW_SRC_FILES) LOCAL_SRC_FILES += $(PARSER_SCOPE_SRC_FILES) diff --git a/parser/onnx/onnx_constant_parser.h b/parser/onnx/onnx_constant_parser.h index 0178787..628e832 100644 --- a/parser/onnx/onnx_constant_parser.h +++ b/parser/onnx/onnx_constant_parser.h @@ -69,16 +69,14 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { break; } #define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ - case dt_type: \ - { \ + case dt_type: { \ unique_ptr addr_trans(new(std::nothrow) value_type[count]()); \ GE_CHECK_NOTNULL(addr_trans); \ for (int32_t i = 0; i < (count); i++) { \ *(addr_trans.get() + i) = static_cast(*((addr).get() + i)); \ } \ (tensor).SetData(reinterpret_cast(addr_trans.get()), (count) * sizeof(value_type)); \ - break; \ - } \ + break; } \ CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) @@ -89,7 +87,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { #undef CASE_SET_DATA default: { - tensor.SetData(reinterpret_cast(addr.get()), count * sizeof(T)); + tensor.SetData(PtrToPtr(addr.get()), count * sizeof(T)); break; } } diff --git a/parser/onnx/onnx_file_constant_parser.cc b/parser/onnx/onnx_file_constant_parser.cc index f12539f..1b6717d 100644 --- a/parser/onnx/onnx_file_constant_parser.cc +++ b/parser/onnx/onnx_file_constant_parser.cc @@ -31,11 +31,11 @@ using GeTensorDesc = ge::GeTensorDesc; using namespace ge::parser; namespace { -const std::string kAttrShape = "shape"; -const std::string kAttrDataType = "dtype"; -const std::string kFileConstantPath = "file_constant_path"; -const std::string kLocation = "location"; -const std::string kOffset = "offset"; +const char *const kAttrShape = "shape"; +const char *const kAttrDataType = "dtype"; +const char *const kFileConstantPath = "file_constant_path"; +const char *const kLocation = "location"; +const char *const kOffset = "offset"; const int64_t kOffsetCoefficient = 4096; const char *const kFileConstant = "FileConstant"; } @@ -46,7 +46,7 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); ge::onnx::TensorProto tensor_proto; - if (GetTensorProto(node, tensor_proto) != SUCCESS) { + if (GetTensorProto(*node, tensor_proto) != SUCCESS) { REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); return FAILED; @@ -65,29 +65,29 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & return SUCCESS; } -Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto, - ge::onnx::TensorProto &tensor_proto) { - for (const auto &it : node_proto->attribute()) { +Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto, + ge::onnx::TensorProto &tensor_proto) const { + for (const auto &it : node_proto.attribute()) { if (it.name() != ge::kAttrNameValue) { continue; } tensor_proto = it.t(); return SUCCESS; } - REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str()); - GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str()); + REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str()); + GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str()); return FAILED; } -void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { +void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { std::vector tmp_shape; for (int i = 0; i < tensor_proto.dims_size(); i++) { tmp_shape.push_back(tensor_proto.dims(i)); } - op_def.SetAttr(kAttrShape.c_str(), tmp_shape); + op_def.SetAttr(kAttrShape, tmp_shape); } -Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { +Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { int64_t data_type = tensor_proto.data_type(); ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); if (type >= ge::DataType::DT_UNDEFINED) { @@ -96,11 +96,11 @@ Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor return FAILED; } - op_def.SetAttr(kAttrDataType.c_str(), type); + op_def.SetAttr(kAttrDataType, type); return SUCCESS; } -Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { +Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { ge::NamedAttrs attrs; for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); @@ -116,12 +116,12 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); return FAILED; } - op_def.SetAttr(kFileConstantPath.c_str(), attrs); + op_def.SetAttr(kFileConstantPath, attrs); return SUCCESS; } Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, - ge::NamedAttrs &attrs) { + ge::NamedAttrs &attrs) const { if (string_proto.key() == kLocation) { AttrUtils::SetStr(attrs, kLocation, string_proto.value()); } else { @@ -134,7 +134,7 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt return FAILED; } if (string_proto.key() == kOffset) { - if (std::numeric_limits::max() / kOffsetCoefficient < value) { + if (value > (std::numeric_limits::max() / kOffsetCoefficient)) { REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); return FAILED; diff --git a/parser/onnx/onnx_file_constant_parser.h b/parser/onnx/onnx_file_constant_parser.h index e46ce0f..0c523c5 100644 --- a/parser/onnx/onnx_file_constant_parser.h +++ b/parser/onnx/onnx_file_constant_parser.h @@ -26,11 +26,11 @@ class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { Status ParseParams(const Message *op_src, ge::Operator &op_def) override; private: - Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); - Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); - void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); - Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto); - Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs); + Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; + Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; + void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; + Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const; + Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const; }; } // namespace ge diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 75472a2..a3ba4ca 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -232,7 +232,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); if (post_func == nullptr) { GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); - if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { + if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || + parse_func_v2 == nullptr) { GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); return SUCCESS; } @@ -522,9 +523,9 @@ Status OnnxModelParser::SetOperatorInputs() { auto src_op = output_op_iter->second; int dst_index = input_node_index.second; int src_index = out_node_index.second; - GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index, - ParserUtils::GetOperatorName(src_op).c_str(), dst_index, - ParserUtils::GetOperatorName(dst_op).c_str()); + GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", + src_index, ParserUtils::GetOperatorName(src_op).c_str(), + dst_index, ParserUtils::GetOperatorName(dst_op).c_str()); auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); GE_CHECK_NOTNULL(dst_op_desc); auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); @@ -689,7 +690,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve return PARAM_INVALID; } input_ops.emplace_back(in_op->second); - GELOGI("Model assigned input node name: %s", ParserUtils::GetOperatorName(in_op->second).c_str()); + GELOGI("Model assigned input node name: %s", + ParserUtils::GetOperatorName(in_op->second).c_str()); } return SUCCESS; } @@ -717,7 +719,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vectorsecond, vector{static_cast(index)}); out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); - GELOGI("out node index %d, node:%s", index, node_name.c_str()); + GELOGI("Out node index %d, node:%s", index, node_name.c_str()); } } return SUCCESS; @@ -934,16 +936,13 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model cur_compute_graph->GetName().c_str()); return ret; } - } UpdateDataFormat(root_graph); return SUCCESS; } Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { - ClearMembers(); - GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), "Run ProtoType Pass Failed"); // 1. Get all inializer. @@ -1174,7 +1173,8 @@ Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_o default_out_nodes.emplace_back(output_node_info); output_tensor_names.emplace_back(tensor_name); GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", - output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str()); + output_node_info.first.c_str(), + output_node_info.second, tensor_name.c_str()); } return SUCCESS; } diff --git a/parser/onnx/onnx_util.h b/parser/onnx/onnx_util.h index 52fc2b3..cdb61b6 100644 --- a/parser/onnx/onnx_util.h +++ b/parser/onnx/onnx_util.h @@ -17,6 +17,8 @@ #ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ #define PARSER_ONNX_ONNX_UTIL_PARSER_H_ +#include +#include #include "external/graph/types.h" namespace OnnxDataType { @@ -59,4 +61,4 @@ class OnnxUtil { }; } // namespace ge -#endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_ +#endif // PARSER_ONNX_ONNX_UTIL_PARSER_H_ diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc index 3b3eaf1..bcf01d8 100644 --- a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ using parser::IF; namespace { const std::map kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; const int kIfNodeAttrSize = 2; +const char *kIf = "If"; } // namespace domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, @@ -33,7 +34,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), parent_node->op_type().c_str()); - auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); + auto ret = ParseIfNodeSubgraphs(*parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); if (ret != SUCCESS) { GELOGE(ret, "[Parse][Node] Parse if node failed."); REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); @@ -44,19 +45,19 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( } domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( - ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, + ge::onnx::NodeProto &parent_node, std::vector &onnx_graphs, std::map &name_to_onnx_graph, const std::string &parent_graph_name) const { - if (parent_node->attribute_size() != kIfNodeAttrSize) { - GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); - REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); + if (parent_node.attribute_size() != kIfNodeAttrSize) { + GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); + REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); return FAILED; } - GELOGD("node attribute size:%d.", parent_node->attribute_size()); + GELOGD("node attribute size:%d.", parent_node.attribute_size()); std::set all_inputs; // for onnx graph, the first attribute may be else branch and the second attribute may be then branch - for (int i = 0; i < parent_node->attribute_size(); i++) { - ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); + for (int i = 0; i < parent_node.attribute_size(); i++) { + ge::onnx::AttributeProto *attribute = parent_node.mutable_attribute(i); GE_CHECK_NOTNULL(attribute); std::string attr_name = attribute->name(); auto itr = kAttrNameToIndex.find(attr_name); @@ -68,7 +69,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( return FAILED; } std::string unique_subgraph_name; - std::string node_name = parent_node->name(); + std::string node_name = parent_node.name(); if (!parent_graph_name.empty()) { node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); } @@ -90,7 +91,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( AddInputNodeForGraph(all_inputs, *onnx_graph); } - AddInputForParentNode(all_inputs, *parent_node); + AddInputForParentNode(all_inputs, parent_node); return SUCCESS; } @@ -135,5 +136,5 @@ void IfSubgraphAdapter::AddInputForParentNode(const std::set &all_i parent_node.add_input(input_name); } } -REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter); +REGISTER_SUBGRAPH_ADAPTER_CREATOR(kIf, IfSubgraphAdapter); } // namespace ge diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.h b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h index eb6f492..936b9fd 100644 --- a/parser/onnx/subgraph_adapter/if_subgraph_adapter.h +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include #include #include "subgraph_adapter.h" +#include "parser/onnx/onnx_util.h" namespace ge { class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { @@ -30,7 +31,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { const std::string &parent_graph_name = "") override; private: - domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, + domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto &parent_node, std::vector &onnx_graphs, std::map &name_to_onnx_graph, const std::string &parent_graph_name) const; domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set &all_inputs) const; diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter.h b/parser/onnx/subgraph_adapter/subgraph_adapter.h index ad9eb1a..be84ee4 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter.h +++ b/parser/onnx/subgraph_adapter/subgraph_adapter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,8 +36,6 @@ #include "proto/onnx/ge_onnx.pb.h" #include "external/register/register_error_codes.h" #include "framework/omg/parser/parser_types.h" -#include "parser/onnx/onnx_util.h" - namespace ge { class PARSER_FUNC_VISIBILITY SubgraphAdapter { public: diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc index cb45db0..b8bf478 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc +++ b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h index 26f18c3..debe623 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h +++ b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,7 +62,6 @@ protected: * @brief SubgraphAdapter creation function * @return Created SubgraphAdapter */ - // typedef shared_ptr (*CREATOR_FUN)(void); using CREATOR_FUN = std::function(void)>; /** @@ -105,7 +104,7 @@ public: * @param [in] op_type Op type * @param [in] clazz SubgraphAdapter implementation class */ -#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ +#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ std::shared_ptr Creator_##op_type##_Subgraph_Adapter() { \ std::shared_ptr ptr(new (std::nothrow) clazz()); \ if (ptr == nullptr) { \ diff --git a/parser/stub/gen_stubapi.py b/parser/stub/gen_stubapi.py index 6aa1d18..95a02ee 100644 --- a/parser/stub/gen_stubapi.py +++ b/parser/stub/gen_stubapi.py @@ -167,6 +167,33 @@ class H2CC(object): del self.stack_template del self.func_list_exist + @staticmethod + def implement_function(func): + function_def = '' + function_def += '{\n' + + all_items = func.split() + start = 0 + return_type = all_items[start] + if return_type == "const": + start += 1 + return_type = all_items[start] + if return_type.startswith(('std::map', 'std::set', 'std::vector')): + return_type = "std::map" + if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): + return_type = "Ptr" + if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): + return_type += "&" + if RETURN_STATEMENTS.__contains__(return_type): + function_def += RETURN_STATEMENTS[return_type] + else: + logging.warning("Unhandled return type[%s]", return_type) + + function_def += '\n' + function_def += '}\n' + function_def += '\n' + return function_def + def just_skip(self): # skip blank line or comment if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( @@ -263,6 +290,7 @@ class H2CC(object): logging.info('Added %s functions', len(self.func_list_exist)) logging.info('Successfully converted,please see ' + self.output_file) + def handle_func1(self, line): """ :param line: @@ -461,12 +489,6 @@ class H2CC(object): logging.info("func_name[%s]", func_name) return line, func_name - def write_func_content(self, content, func_name, need_generate): - if not (func_name in self.func_list_exist) and need_generate: - self.output_fd.write(content) - self.func_list_exist.append(func_name) - logging.info('add func:[%s]', func_name) - def gen_comment(self, start_i): comment_line = '' # Function comments are on top of function declarations, copy them over @@ -488,32 +510,11 @@ class H2CC(object): break return comment_line - @staticmethod - def implement_function(func): - function_def = '' - function_def += '{\n' - - all_items = func.split() - start = 0 - return_type = all_items[start] - if return_type == "const": - start += 1 - return_type = all_items[start] - if return_type.startswith(('std::map', 'std::set', 'std::vector')): - return_type = "std::map" - if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): - return_type = "Ptr" - if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): - return_type += "&" - if RETURN_STATEMENTS.__contains__(return_type): - function_def += RETURN_STATEMENTS[return_type] - else: - logging.warning("Unhandled return type[%s]", return_type) - - function_def += '\n' - function_def += '}\n' - function_def += '\n' - return function_def + def write_func_content(self, content, func_name, need_generate): + if not (func_name in self.func_list_exist) and need_generate: + self.output_fd.write(content) + self.func_list_exist.append(func_name) + logging.info('add func:[%s]', func_name) def collect_header_files(path): diff --git a/parser/tensorflow/graph_functiondef.cc b/parser/tensorflow/graph_to_function_def.cc similarity index 99% rename from parser/tensorflow/graph_functiondef.cc rename to parser/tensorflow/graph_to_function_def.cc index 983d8a6..9629cfc 100644 --- a/parser/tensorflow/graph_functiondef.cc +++ b/parser/tensorflow/graph_to_function_def.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd +* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph_functiondef.h" +#include "graph_to_function_def.h" #include #include "common/fmk_error_codes.h" #include "graph/debug/ge_attr_define.h" diff --git a/parser/tensorflow/graph_functiondef.h b/parser/tensorflow/graph_to_function_def.h similarity index 97% rename from parser/tensorflow/graph_functiondef.h rename to parser/tensorflow/graph_to_function_def.h index ae27885..081a022 100644 --- a/parser/tensorflow/graph_functiondef.h +++ b/parser/tensorflow/graph_to_function_def.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/tensorflow/iterator_fusion_pass.cc b/parser/tensorflow/iterator_fusion_pass.cc index 14fcf9a..2a7f2a8 100644 --- a/parser/tensorflow/iterator_fusion_pass.cc +++ b/parser/tensorflow/iterator_fusion_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include "framework/omg/parser/parser_types.h" #include "common/util.h" -#include "graph_optimizer.h" +#include "parser_graph_optimizer.h" #include "framework/common/ge_inner_error_codes.h" namespace ge { diff --git a/parser/tensorflow/iterator_fusion_pass.h b/parser/tensorflow/iterator_fusion_pass.h index 268613f..0756bb9 100644 --- a/parser/tensorflow/iterator_fusion_pass.h +++ b/parser/tensorflow/iterator_fusion_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/parser_graph_optimizer.cc similarity index 99% rename from parser/tensorflow/graph_optimizer.cc rename to parser/tensorflow/parser_graph_optimizer.cc index e9c8799..9046292 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/parser_graph_optimizer.cc @@ -14,14 +14,14 @@ * limitations under the License. */ -#include "graph_optimizer.h" +#include "parser_graph_optimizer.h" #include "graph/op_types.h" #include "common/types_map.h" #include "common/util.h" #include "framework/omg/parser/parser_inner_ctx.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" -#include "graph_functiondef.h" +#include "graph_to_function_def.h" #include "parser/common/acl_graph_parser_util.h" #include "register/op_registry.h" @@ -188,7 +188,10 @@ Status CollectNodeFuncs(vector &nodes, FunctionDefLibrary *library) Status ParserGraphOptimizer::UpdateGraph(vector &nodes) { ComputeGraphPtr sub_graph = nullptr; - GE_MAKE_SHARED(sub_graph = std::make_shared("subGraph"), sub_graph = nullptr; return PARAM_INVALID); + GE_MAKE_SHARED( + sub_graph = std::make_shared("subGraph"), + sub_graph = nullptr; + return PARAM_INVALID); unordered_map node_map; vector input_anchors; diff --git a/parser/tensorflow/graph_optimizer.h b/parser/tensorflow/parser_graph_optimizer.h similarity index 97% rename from parser/tensorflow/graph_optimizer.h rename to parser/tensorflow/parser_graph_optimizer.h index 728230e..6fee699 100644 --- a/parser/tensorflow/graph_optimizer.h +++ b/parser/tensorflow/parser_graph_optimizer.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/tensorflow/scope/scope_pass_manager.cc b/parser/tensorflow/scope/scope_pass_manager.cc index b4a3a65..d715e20 100644 --- a/parser/tensorflow/scope/scope_pass_manager.cc +++ b/parser/tensorflow/scope/scope_pass_manager.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/tensorflow/scope/scope_pass_manager.h b/parser/tensorflow/scope/scope_pass_manager.h index d661003..2f24b28 100644 --- a/parser/tensorflow/scope/scope_pass_manager.h +++ b/parser/tensorflow/scope/scope_pass_manager.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/tensorflow/tensorflow_arg_parser.cc b/parser/tensorflow/tensorflow_arg_parser.cc index afa6097..985eb15 100644 --- a/parser/tensorflow/tensorflow_arg_parser.cc +++ b/parser/tensorflow/tensorflow_arg_parser.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "parser/common/op_def/arg_op.h" +#include "parser/common/op_def/arg_op_operator.h" #include "framework/common/debug/ge_log.h" #include "framework/omg/parser/parser_inner_ctx.h" #include "graph/compute_graph.h" diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc index 3fc99de..89a2ee8 100644 --- a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc +++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h index 9b4cf52..f7bc8f5 100644 --- a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h +++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/tensorflow/tensorflow_constant_parser.cc b/parser/tensorflow/tensorflow_constant_parser.cc index 22f1647..9812b37 100644 --- a/parser/tensorflow/tensorflow_constant_parser.cc +++ b/parser/tensorflow/tensorflow_constant_parser.cc @@ -19,7 +19,7 @@ #include #include #include "parser/common/acl_graph_parser_util.h" -#include "parser/common/op_def/constant_op.h" +#include "parser/common/op_def/constant_operator.h" #include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/util.h" #include "framework/common/debug/ge_log.h" diff --git a/parser/tensorflow/tensorflow_constant_parser.h b/parser/tensorflow/tensorflow_constant_parser.h index 557db3d..5d4df36 100644 --- a/parser/tensorflow/tensorflow_constant_parser.h +++ b/parser/tensorflow/tensorflow_constant_parser.h @@ -17,7 +17,7 @@ #ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ #define GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ -#include "common/op_def/constant_op.h" +#include "common/op_def/constant_operator.h" #include "parser/common/data_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h" diff --git a/parser/tensorflow/tensorflow_fill_parser.cc b/parser/tensorflow/tensorflow_fill_parser.cc index 886dcdb..fab2eb9 100644 --- a/parser/tensorflow/tensorflow_fill_parser.cc +++ b/parser/tensorflow/tensorflow_fill_parser.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "parser/common/op_def/fill_op.h" +#include "parser/common/op_def/fill_operator.h" #include "parser/tensorflow/tensorflow_parser_register.h" #include "framework/omg/parser/parser_types.h" diff --git a/parser/tensorflow/tensorflow_frameworkop_parser.cc b/parser/tensorflow/tensorflow_frameworkop_parser.cc index 2d03c7b..cb9ccff 100644 --- a/parser/tensorflow/tensorflow_frameworkop_parser.cc +++ b/parser/tensorflow/tensorflow_frameworkop_parser.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "parser/common/op_def/frameworkop_op.h" +#include "parser/common/op_def/framework_op_operator.h" #include "framework/common/debug/ge_log.h" #include "parser/common/op_parser_factory.h" #include "framework/omg/parser/parser_types.h" diff --git a/parser/tensorflow/tensorflow_no_op_parser.cc b/parser/tensorflow/tensorflow_no_op_parser.cc index 4d43f9d..efd193a 100644 --- a/parser/tensorflow/tensorflow_no_op_parser.cc +++ b/parser/tensorflow/tensorflow_no_op_parser.cc @@ -18,7 +18,7 @@ #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "parser/common/op_def/ir_pb_converter.h" -#include "parser/common/op_def/no_op_op.h" +#include "parser/common/op_def/no_op_operator.h" #include "parser/common/op_parser_factory.h" using domi::TENSORFLOW; diff --git a/parser/tensorflow/tensorflow_ref_switch_parser.cc b/parser/tensorflow/tensorflow_ref_switch_parser.cc index aadd966..e7672eb 100644 --- a/parser/tensorflow/tensorflow_ref_switch_parser.cc +++ b/parser/tensorflow/tensorflow_ref_switch_parser.cc @@ -17,7 +17,7 @@ #include "parser/tensorflow/tensorflow_ref_switch_parser.h" #include "framework/common/debug/ge_log.h" #include "parser/common/op_def/ir_pb_converter.h" -#include "parser/common/op_def/ref_switch_op.h" +#include "parser/common/op_def/ref_switch_operator.h" #include "parser/common/op_parser_factory.h" #include "parser/common/util.h" diff --git a/parser/tensorflow/tensorflow_ref_switch_parser.h b/parser/tensorflow/tensorflow_ref_switch_parser.h index 52c059d..d048032 100644 --- a/parser/tensorflow/tensorflow_ref_switch_parser.h +++ b/parser/tensorflow/tensorflow_ref_switch_parser.h @@ -17,7 +17,7 @@ #ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_ #define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_ -#include "common/op_def/ref_switch_op.h" +#include "common/op_def/ref_switch_operator.h" #include "parser/tensorflow/tensorflow_op_parser.h" namespace ge { diff --git a/parser/tensorflow/tensorflow_shape_n_parser.cc b/parser/tensorflow/tensorflow_shape_n_parser.cc index 3cb7fd1..16f1d1d 100644 --- a/parser/tensorflow/tensorflow_shape_n_parser.cc +++ b/parser/tensorflow/tensorflow_shape_n_parser.cc @@ -18,7 +18,7 @@ #include "parser/common/op_def/ir_pb_converter.h" #include "framework/common/debug/ge_log.h" #include "parser/common/op_parser_factory.h" -#include "parser/common/op_def/shape_n_op.h" +#include "parser/common/op_def/shape_n_operator.h" #include "parser/common/util.h" using domi::TENSORFLOW; diff --git a/parser/tensorflow/tensorflow_shape_n_parser.h b/parser/tensorflow/tensorflow_shape_n_parser.h index 9f92a93..475d191 100644 --- a/parser/tensorflow/tensorflow_shape_n_parser.h +++ b/parser/tensorflow/tensorflow_shape_n_parser.h @@ -17,7 +17,7 @@ #ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ #define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ -#include "common/op_def/shape_n_op.h" +#include "common/op_def/shape_n_operator.h" #include "parser/tensorflow/tensorflow_op_parser.h" namespace ge { diff --git a/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc b/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc index 577bbca..b0fec3a 100644 --- a/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc +++ b/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc @@ -15,7 +15,7 @@ */ #include "framework/common/debug/ge_log.h" -#include "parser/common/op_def/var_is_initialized_op_op.h" +#include "parser/common/op_def/var_is_initialized_op_operator.h" #include "parser/common/op_parser_factory.h" #include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_parser_register.h" diff --git a/parser/tensorflow/tensorflow_variable_v2_parser.cc b/parser/tensorflow/tensorflow_variable_v2_parser.cc index 2b14280..ac850c7 100644 --- a/parser/tensorflow/tensorflow_variable_v2_parser.cc +++ b/parser/tensorflow/tensorflow_variable_v2_parser.cc @@ -21,7 +21,7 @@ #include "graph/op_desc.h" #include "graph/utils/attr_utils.h" #include "graph/utils/tensor_utils.h" -#include "parser/common/op_def/variable_op.h" +#include "parser/common/op_def/variable_operator.h" #include "parser/common/op_parser_factory.h" #include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_parser_register.h" diff --git a/tests/st/CMakeLists.txt b/tests/st/CMakeLists.txt index 7498987..580669b 100644 --- a/tests/st/CMakeLists.txt +++ b/tests/st/CMakeLists.txt @@ -249,17 +249,17 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/common/convert/message2operator.cc" "${PARSER_DIR}/parser/common/data_op_parser.cc" "${PARSER_DIR}/parser/common/model_saver.cc" - "${PARSER_DIR}/parser/common/op_def/arg_op.cc" - "${PARSER_DIR}/parser/common/op_def/constant_op.cc" - "${PARSER_DIR}/parser/common/op_def/fill_op.cc" - "${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc" + "${PARSER_DIR}/parser/common/op_def/arg_op_operator.cc" + "${PARSER_DIR}/parser/common/op_def/constant_operator.cc" + "${PARSER_DIR}/parser/common/op_def/fill_operator.cc" + "${PARSER_DIR}/parser/common/op_def/framework_op_operator.cc" "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" - "${PARSER_DIR}/parser/common/op_def/no_op_op.cc" + "${PARSER_DIR}/parser/common/op_def/no_op_operator.cc" "${PARSER_DIR}/parser/common/op_def/operator.cc" - "${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc" - "${PARSER_DIR}/parser/common/op_def/shape_n_op.cc" - "${PARSER_DIR}/parser/common/op_def/variable_op.cc" - "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc" + "${PARSER_DIR}/parser/common/op_def/ref_switch_operator.cc" + "${PARSER_DIR}/parser/common/op_def/shape_n_operator.cc" + "${PARSER_DIR}/parser/common/op_def/variable_operator.cc" + "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_operator.cc" "${PARSER_DIR}/parser/common/op_map.cc" "${PARSER_DIR}/parser/common/op_parser_factory.cc" "${PARSER_DIR}/parser/common/parser_api.cc" @@ -284,8 +284,8 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/onnx/onnx_util.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" - "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" - "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" + "${PARSER_DIR}/parser/tensorflow/graph_to_function_def.cc" + "${PARSER_DIR}/parser/tensorflow/parser_graph_optimizer.cc" "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" diff --git a/tests/st/testcase/test_caffe_parser.cc b/tests/st/testcase/test_caffe_parser.cc index 42c9c8d..45fc70c 100644 --- a/tests/st/testcase/test_caffe_parser.cc +++ b/tests/st/testcase/test_caffe_parser.cc @@ -765,7 +765,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) layer->set_name("Abs"); layer->set_type("AbsVal"); - Status ret = weightParser.CheckLayersSize(layer); + Status ret = weightParser.CheckLayersSize(*layer); EXPECT_EQ(ret, FAILED); } @@ -809,8 +809,54 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) const google::protobuf::Message *proto = factory.GetPrototype(descriptor); const google::protobuf::Message *message = proto->New(); - Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); + Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); delete message; EXPECT_EQ(ret, SUCCESS); } + +TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_DimSize_0) +{ + CaffeModelParser modelParser; + domi::caffe::NetParameter net; + net.add_input("111"); + net.add_input_shape(); + bool input_data_flag = true; + Status ret = modelParser.ParseInput(net, input_data_flag); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_Err1) +{ + CaffeModelParser modelParser; + domi::caffe::NetParameter net; + net.add_input("111"); + net.add_input("222"); + net.add_input_shape(); + bool input_data_flag = true; + Status ret = modelParser.ParseInput(net, input_data_flag); + EXPECT_EQ(ret, FAILED); +} + +TEST_F(STestCaffeParser, CaffeModelParser_ParserLayerParameter_Succ) +{ + CaffeModelParser modelParser; + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/"; + const char *model_path = model_file.c_str(); + + std::string custom_proto = model_file; + std::string caffe_proto = model_file; + std::vector operators; + ge::OpDescPtr op_desc_src = std::make_shared("Data", "Input"); + ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); + operators.emplace_back(op_src); + + model_file = case_dir + "/origin_models/caffe_add.pbtxt"; + custom_proto = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; + model_path = model_file.c_str(); + std::string caffe_proto_path = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; + auto ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto_path, operators); + EXPECT_EQ(ret, SUCCESS); +} } // namespace ge diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 5b41752..b5f1908 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -34,17 +34,17 @@ #include "external/parser/tensorflow_parser.h" #include "parser/tensorflow/tensorflow_constant_parser.h" #include "common/types.h" -#include "parser/common/op_def/variable_op.h" +#include "parser/common/op_def/variable_operator.h" #include "parser/tensorflow/tensorflow_ref_switch_parser.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h" -#include "parser/common/op_def/arg_op.h" +#include "parser/common/op_def/arg_op_operator.h" #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_reshape_parser.h" #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_squeeze_parser.h" -#include "parser/tensorflow/graph_functiondef.h" -#include "parser/tensorflow/graph_optimizer.h" +#include "parser/tensorflow/graph_to_function_def.h" +#include "parser/tensorflow/parser_graph_optimizer.h" #include "cce/dnn_base_def.hpp" #include "parser/tensorflow/scope/scope_pass_manager.h" #include "parser/tensorflow/tensorflow_util.h" @@ -52,10 +52,10 @@ #include "parser/tensorflow/tensorflow_enter_parser.h" #include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/tuple.h" -#include "common/op_def/frameworkop_op.h" -#include "common/op_def/shape_n_op.h" -#include "common/op_def/var_is_initialized_op_op.h" -#include "common/op_def/fill_op.h" +#include "common/op_def/framework_op_operator.h" +#include "common/op_def/shape_n_operator.h" +#include "common/op_def/var_is_initialized_op_operator.h" +#include "common/op_def/fill_operator.h" #include "common/convert/pb2json.h" #include "common/convert/message2operator.h" #include "parser/common/proto_file_parser.h" @@ -70,7 +70,7 @@ #include "parser/common/prototype_pass_manager.h" #include "parser/common/register_tbe.h" #include "parser/common/pass_manager.h" -#include "parser/tensorflow/graph_optimizer.h" +#include "parser/tensorflow/parser_graph_optimizer.h" #include "metadef/inc/register/scope/scope_pass_registry_impl.h" #include "register/scope/scope_fusion_pass_register.h" #undef protected @@ -678,6 +678,7 @@ namespace { if ((_name== "S") || (_name == "K")) { int index = 0; + ge::AttrUtils::SetInt(opDef, "T", 1); ge::AttrUtils::SetInt(opDef, "arg_index", index); ge::AttrUtils::SetInt(opDef, "ret_index", index); @@ -1029,7 +1030,9 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { ParserOperator unused("Add"); case_dir = case_dir.substr(0, case_dir.find_last_of("/")); std::string model_file = case_dir + "/origin_models/tf_add.pb"; - std::map parser_params; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, + }; ge::Graph graph; auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); ASSERT_EQ(ret, SUCCESS); @@ -1043,6 +1046,21 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) { EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); } +TEST_F(STestTensorflowParser, tensorflow_parser_failed_for_input_data_names_error) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + ParserOperator unused("Add"); + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/tf_add.pb"; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_3")}, + }; + ge::Graph graph; + auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, ge::GRAPH_FAILED); +} + TEST_F(STestTensorflowParser, tensorflow_model_Failed) { ge::Graph graph; std::string caseDir = __FILE__; diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 0253c0f..28453d9 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -250,17 +250,17 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/common/convert/message2operator.cc" "${PARSER_DIR}/parser/common/data_op_parser.cc" "${PARSER_DIR}/parser/common/model_saver.cc" - "${PARSER_DIR}/parser/common/op_def/arg_op.cc" - "${PARSER_DIR}/parser/common/op_def/constant_op.cc" - "${PARSER_DIR}/parser/common/op_def/fill_op.cc" - "${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc" + "${PARSER_DIR}/parser/common/op_def/arg_op_operator.cc" + "${PARSER_DIR}/parser/common/op_def/constant_operator.cc" + "${PARSER_DIR}/parser/common/op_def/fill_operator.cc" + "${PARSER_DIR}/parser/common/op_def/framework_op_operator.cc" "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" - "${PARSER_DIR}/parser/common/op_def/no_op_op.cc" + "${PARSER_DIR}/parser/common/op_def/no_op_operator.cc" "${PARSER_DIR}/parser/common/op_def/operator.cc" - "${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc" - "${PARSER_DIR}/parser/common/op_def/shape_n_op.cc" - "${PARSER_DIR}/parser/common/op_def/variable_op.cc" - "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc" + "${PARSER_DIR}/parser/common/op_def/ref_switch_operator.cc" + "${PARSER_DIR}/parser/common/op_def/shape_n_operator.cc" + "${PARSER_DIR}/parser/common/op_def/variable_operator.cc" + "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_operator.cc" "${PARSER_DIR}/parser/common/op_map.cc" "${PARSER_DIR}/parser/common/op_parser_factory.cc" "${PARSER_DIR}/parser/common/parser_api.cc" @@ -285,8 +285,8 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/onnx/onnx_util.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" - "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" - "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" + "${PARSER_DIR}/parser/tensorflow/graph_to_function_def.cc" + "${PARSER_DIR}/parser/tensorflow/parser_graph_optimizer.cc" "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" diff --git a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc index 55fc538..5d1a5bd 100755 --- a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc +++ b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc @@ -739,6 +739,9 @@ TEST_F(UtestCaffeParser, CaffeModelParser_CustomProtoParse_test) Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators); EXPECT_EQ(ret, PARAM_INVALID); + ret = modelParser.CustomProtoParse("", custom_proto, caffe_proto, operators); + EXPECT_EQ(ret, FAILED); + model_file = case_dir + "/caffe_model/caffe_add.pbtxt"; custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto"; model_path = model_file.c_str(); @@ -890,7 +893,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) layer->set_name("Abs"); layer->set_type("AbsVal"); - Status ret = weightParser.CheckLayersSize(layer); + Status ret = weightParser.CheckLayersSize(*layer); EXPECT_EQ(ret, FAILED); } @@ -902,7 +905,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) layer->set_name("Abs"); layer->set_type("AbsVal"); - Status ret = weightParser.ConvertLayerProto(&net, &net); + Status ret = weightParser.ConvertLayerProto(net, &net); EXPECT_EQ(ret, SUCCESS); BlobProto* blob = layer->add_blobs(); @@ -911,16 +914,16 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) BlobShape* shap = blob->mutable_shape(); shap->add_dim(1); shap->add_dim(2); - ret = weightParser.ConvertBlobsProto(&net, &net); + ret = weightParser.ConvertBlobsProto(net, &net); EXPECT_EQ(ret, SUCCESS); - ret = weightParser.ConvertBlobShapeProto(&net, &net); + ret = weightParser.ConvertBlobShapeProto(net, &net); EXPECT_EQ(ret, SUCCESS); - ret = weightParser.ConvertConvParamProto(&net, &net); + ret = weightParser.ConvertConvParamProto(net, &net); EXPECT_EQ(ret, SUCCESS); - ret = weightParser.ConvertInnerProdcutProto(&net, &net); + ret = weightParser.ConvertInnerProdcutProto(net, &net); EXPECT_EQ(ret, SUCCESS); } @@ -1133,7 +1136,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) const google::protobuf::Message *proto = factory.GetPrototype(descriptor); const google::protobuf::Message *message = proto->New(); - Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); + Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); delete message; EXPECT_EQ(ret, SUCCESS); } @@ -1163,7 +1166,7 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test) google::protobuf::DynamicMessageFactory factory; const google::protobuf::Message *proto = factory.GetPrototype(descriptor); google::protobuf::Message *message = proto->New(); - Status ret = modelParser.ParseLayerParameter(descriptor, message, operators); + Status ret = modelParser.ParseLayerParameter(*descriptor, *message, operators); EXPECT_EQ(ret, SUCCESS); const domi::FrameworkType fmk_type = domi::TENSORFLOW; diff --git a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc index 5139002..9667978 100644 --- a/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc +++ b/tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc @@ -7,7 +7,7 @@ #include "tensorflow/iterator_fusion_pass.h" #include "parser/common/acl_graph_parser_util.h" #define private public -#include "tensorflow/graph_optimizer.h" +#include "tensorflow/parser_graph_optimizer.h" #undef private namespace ge { class UtestGraphOptimizer : public testing::Test { diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index 6fdee9e..c945d69 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -381,7 +381,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) OnnxFileConstantParser parser; ge::onnx::NodeProto input_node; ge::onnx::TensorProto tensor_proto; - Status ret = parser.GetTensorProto(&input_node, tensor_proto); + Status ret = parser.GetTensorProto(input_node, tensor_proto); EXPECT_EQ(ret, FAILED); ge::onnx::AttributeProto *attribute = input_node.add_attribute(); @@ -391,7 +391,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); *attribute_tensor = tensor_proto; - ret = parser.GetTensorProto(&input_node, tensor_proto); + ret = parser.GetTensorProto(input_node, tensor_proto); EXPECT_EQ(ret, SUCCESS); } diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 00678d8..c852aad 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -38,17 +38,17 @@ #include "tests/depends/ops_stub/ops_stub.h" #include "parser/tensorflow/tensorflow_constant_parser.h" #include "common/types.h" -#include "parser/common/op_def/variable_op.h" +#include "parser/common/op_def/variable_operator.h" #include "parser/tensorflow/tensorflow_ref_switch_parser.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h" -#include "parser/common/op_def/arg_op.h" +#include "parser/common/op_def/arg_op_operator.h" #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_reshape_parser.h" #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_squeeze_parser.h" -#include "parser/tensorflow/graph_functiondef.h" -#include "parser/tensorflow/graph_optimizer.h" +#include "parser/tensorflow/graph_to_function_def.h" +#include "parser/tensorflow/parser_graph_optimizer.h" #include "cce/dnn_base_def.hpp" #include "parser/tensorflow/scope/scope_pass_manager.h" #include "parser/tensorflow/tensorflow_util.h" @@ -56,10 +56,10 @@ #include "parser/tensorflow/tensorflow_enter_parser.h" #include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/tuple.h" -#include "common/op_def/frameworkop_op.h" -#include "common/op_def/shape_n_op.h" -#include "common/op_def/var_is_initialized_op_op.h" -#include "common/op_def/fill_op.h" +#include "common/op_def/framework_op_operator.h" +#include "common/op_def/shape_n_operator.h" +#include "common/op_def/var_is_initialized_op_operator.h" +#include "common/op_def/fill_operator.h" #include "common/convert/pb2json.h" #include "common/convert/message2operator.h" #include "parser/common/proto_file_parser.h" @@ -73,7 +73,7 @@ #include "parser/common/prototype_pass_manager.h" #include "parser/common/register_tbe.h" #include "parser/common/pass_manager.h" -#include "parser/tensorflow/graph_optimizer.h" +#include "parser/tensorflow/parser_graph_optimizer.h" #include "metadef/inc/register/scope/scope_pass_registry_impl.h" #include "register/scope/scope_fusion_pass_register.h" #include "common/op_map.h" @@ -1032,7 +1032,9 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { ParserOperator unused("Add"); case_dir = case_dir.substr(0, case_dir.find_last_of("/")); std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; - std::map parser_params; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")}, + }; ge::Graph graph; auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); ASSERT_EQ(ret, SUCCESS); @@ -1046,6 +1048,21 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); } +TEST_F(UtestTensorflowParser, tensorflow_parser_input_data_names_failed) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + ParserOperator unused("Add"); + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; + std::map parser_params = { + {ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_2")}, + }; + ge::Graph graph; + auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, ge::GRAPH_FAILED); +} + TEST_F(UtestTensorflowParser, tensorflow_model_Failed) { ge::Graph graph; std::string caseDir = __FILE__;