diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 6e28a8e..0ea52ce 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param"; const char *const kFieldDim = "dim"; const char *const kFieldBiasTerm = "bias_term"; const char *const kDevNull = "/dev/null"; -const std::string kMessage = "message"; -const std::string kLayerParameter = "LayerParameter"; -const std::string kCloseBrace = "}"; -const std::string kOptional = "optional"; -const std::string kRepeated = "repeated"; -const std::string kRequired = "required"; -const std::string kCustom = "custom"; -const std::string kBuiltin = "built-in"; +const char *const kCustom = "custom"; +const char *const kBuiltin = "built-in"; std::vector kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, ge::parser::NETOUTPUT}; const std::set kCustomProtoLayerCommonField = {"name", "type"}; @@ -284,104 +278,104 @@ const set CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW "Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { - if (proto_message.input_size() > 0) { - GELOGI("This net exsit input."); + if (proto_message.input_size() <= 0) { + return SUCCESS; + } + GELOGI("This net exsit input."); + if (proto_message.input_dim_size() > 0) { + if (proto_message.input_shape_size() > 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E11001"); + GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); + return FAILED; + } - if (proto_message.input_dim_size() > 0) { - if (proto_message.input_shape_size() > 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E11001"); - GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); - return FAILED; - } + const int32_t input_dim_size = proto_message.input_dim_size(); + const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || + ((input_dim_size % proto_message.input_size()) != 0)); + if (is_input_invalid) { + ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, + {std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); + GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", + input_dim_size, proto_message.input_size()); + return FAILED; + } - const int32_t input_dim_size = proto_message.input_dim_size(); - const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || - ((input_dim_size % proto_message.input_size()) != 0)); - if (is_input_invalid) { - ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, - {std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); - GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", - input_dim_size, proto_message.input_size()); - return FAILED; - } + for (int i = 0; i < proto_message.input_size(); i++) { + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(proto_message.input(i)); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); - for (int i = 0; i < proto_message.input_size(); i++) { - domi::caffe::LayerParameter *layer = proto_message.add_layer(); - GE_CHECK_NOTNULL(layer); - layer->set_name(proto_message.input(i)); - layer->set_type(ge::parser::INPUT_TYPE); - layer->add_top(proto_message.input(i)); - - domi::caffe::InputParameter *input_param = layer->mutable_input_param(); - GE_CHECK_NOTNULL(input_param); - domi::caffe::BlobShape *shape = input_param->add_shape(); - GE_CHECK_NOTNULL(shape); - - for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { - // Can guarantee that it will not cross the border - shape->add_dim(static_cast(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); - } - input_data_flag = true; - } - } else if (proto_message.input_shape_size() > 0) { - if (proto_message.input_shape_size() != proto_message.input_size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, - {std::to_string(proto_message.input_shape_size()), - std::to_string(proto_message.input_size())}); - GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", - proto_message.input_shape_size(), proto_message.input_size()); - return FAILED; + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); + + for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { + // Can guarantee that it will not cross the border + shape->add_dim(static_cast(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); } + input_data_flag = true; + } + } else if (proto_message.input_shape_size() > 0) { + if (proto_message.input_shape_size() != proto_message.input_size()) { + ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, + {std::to_string(proto_message.input_shape_size()), + std::to_string(proto_message.input_size())}); + GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", + proto_message.input_shape_size(), proto_message.input_size()); + return FAILED; + } - for (int i = 0; i < proto_message.input_size(); i++) { - int dim_size = proto_message.input_shape(i).dim_size(); + for (int i = 0; i < proto_message.input_size(); i++) { + int dim_size = proto_message.input_shape(i).dim_size(); - domi::caffe::LayerParameter *layer = proto_message.add_layer(); - GE_CHECK_NOTNULL(layer); - layer->set_name(proto_message.input(i)); - layer->set_type(ge::parser::INPUT_TYPE); - layer->add_top(proto_message.input(i)); + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(proto_message.input(i)); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); - domi::caffe::InputParameter *input_param = layer->mutable_input_param(); - GE_CHECK_NOTNULL(input_param); - domi::caffe::BlobShape *shape = input_param->add_shape(); - GE_CHECK_NOTNULL(shape); + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); - for (int j = 0; j < dim_size; j++) { - // Can guarantee that it will not cross the border - shape->add_dim(static_cast(proto_message.input_shape(i).dim(j))); - } - input_data_flag = true; + for (int j = 0; j < dim_size; j++) { + // Can guarantee that it will not cross the border + shape->add_dim(static_cast(proto_message.input_shape(i).dim(j))); } - } else { - const ge::ParserContext &ctx = ge::GetParserContext(); - std::map> input_dims = ctx.input_dims; - for (int i = 0; i < proto_message.input_size(); i++) { - string name = proto_message.input(i); - if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input - REPORT_INPUT_ERROR("E11005", std::vector({"input"}), std::vector({name})); - GELOGE(FAILED, "[Find][Dim]Model has no input shape."); - return FAILED; - } - std::vector dims = input_dims.at(name); - size_t dim_size = dims.size(); - - domi::caffe::LayerParameter *layer = proto_message.add_layer(); - GE_CHECK_NOTNULL(layer); - layer->set_name(name); - layer->set_type(ge::parser::INPUT_TYPE); - layer->add_top(proto_message.input(i)); - - domi::caffe::InputParameter *input_param = layer->mutable_input_param(); - GE_CHECK_NOTNULL(input_param); - domi::caffe::BlobShape *shape = input_param->add_shape(); - GE_CHECK_NOTNULL(shape); - - for (size_t j = 0; j < dim_size; j++) { - shape->add_dim(dims.at(j)); - } - input_data_flag = true; + input_data_flag = true; + } + } else { + const ge::ParserContext &ctx = ge::GetParserContext(); + std::map> input_dims = ctx.input_dims; + for (int i = 0; i < proto_message.input_size(); i++) { + string name = proto_message.input(i); + if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input + REPORT_INPUT_ERROR("E11005", std::vector({"input"}), std::vector({name})); + GELOGE(FAILED, "[Find][Dim]Model has no input shape."); + return FAILED; + } + std::vector dims = input_dims.at(name); + size_t dim_size = dims.size(); + + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(name); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); + + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); + + for (size_t j = 0; j < dim_size; j++) { + shape->add_dim(dims.at(j)); } + input_data_flag = true; } } return SUCCESS; @@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons return FAILED; } - if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) { + if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) { delete message; GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); return FAILED; @@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google:: return SUCCESS; } -Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, +Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, vector &operators) const { - auto field_name = layer_descriptor->FindFieldByName(kFieldName); + auto field_name = layer_descriptor.FindFieldByName(kFieldName); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); - auto field_type = layer_descriptor->FindFieldByName(kFieldType); + auto field_type = layer_descriptor.FindFieldByName(kFieldType); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); for (auto &field : field_desc) { CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); // Only care about layers @@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor return FAILED; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); GELOGI("Total Layer num of model file is %d", field_size); for (int i = 0; i < field_size; ++i) { - const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); + const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); GE_CHECK_NOTNULL(layer_reflection); @@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co layer_name_map[layer.name()]++; // Set the name in proto and layer domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); - duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) + duplicate_name_layer->set_name(new_name); + layer.set_name(new_name);) // Insert the new operator name, the number of times of duplicate name is recorded as 1 layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); @@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap layer_name_map[layer.name()]++; // Set the name in proto and layer domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); - duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) + duplicate_name_layer->set_name(new_name); + layer.set_name(new_name);) // Insert the new operator name, the number of times of duplicate name is recorded as 1 layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); @@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con return FAILED; } - if (CheckLayersSize(message) != SUCCESS) { + if (CheckLayersSize(*message) != SUCCESS) { delete message; message = nullptr; return FAILED; } - if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) { + if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) { delete message; message = nullptr; REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); @@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con return SUCCESS; } -Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, +Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, ge::ComputeGraphPtr &graph) { - auto field_name = layer_descriptor->FindFieldByName(kFieldName); + auto field_name = layer_descriptor.FindFieldByName(kFieldName); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); - auto field_type = layer_descriptor->FindFieldByName(kFieldType); + auto field_type = layer_descriptor.FindFieldByName(kFieldType); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); NetParameter tmp_net; for (auto &field : field_desc) { @@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto return FAILED; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); GELOGI("Total Layer num of model file is %d", field_size); for (int i = 0; i < field_size; ++i) { - const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); + const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); LayerParameter *layer = tmp_net.add_layer(); - if (ConvertLayerProto(&layer_message, layer) != SUCCESS) { + if (ConvertLayerProto(layer_message, layer) != SUCCESS) { GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); return FAILED; } @@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto return SUCCESS; } -Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, google::protobuf::Message *layer) { - const google::protobuf::Reflection *layer_reflection = message->GetReflection(); + const google::protobuf::Reflection *layer_reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - layer_reflection->ListFields(*message, &field_desc); + layer_reflection->ListFields(message, &field_desc); for (auto &field : field_desc) { GE_CHECK_NOTNULL(field); - if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) { + if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) { GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); return FAILED; } @@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me return SUCCESS; } -Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, +Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection, + const google::protobuf::Message &message, + const google::protobuf::FieldDescriptor &field, google::protobuf::Message *layer) const { - GELOGD("Start to parse field: %s.", field->name().c_str()); + GELOGD("Start to parse field: %s.", field.name().c_str()); domi::caffe::LayerParameter *layer_proto = PtrToPtr(layer); - string filed_name = field->name(); -#define CASE_FIELD_NAME(kName, method) \ + string filed_name = field.name(); +#define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \ if (filed_name == kField##kName) { \ - string value = reflection->GetString(*message, field); \ + string value = reflection.GetString(inner_message, field_ptr); \ GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ layer_proto->set_##method(value); \ return SUCCESS; \ } - CASE_FIELD_NAME(Name, name); - CASE_FIELD_NAME(Type, type); + CASE_FIELD_NAME(Name, name, message, &field); + CASE_FIELD_NAME(Type, type, message, &field); #undef CASE_FIELD_NAME -#define CASE_FIELD_NAME_REPEATED(kName, method) \ - if (filed_name == kField##kName) { \ - int field_size = reflection->FieldSize(*message, field); \ - for (int i = 0; i < field_size; ++i) { \ - auto value = reflection->GetRepeatedString(*message, field, i); \ - layer_proto->add_##method(value); \ - } \ - return SUCCESS; \ - } - CASE_FIELD_NAME_REPEATED(Bottom, bottom); - CASE_FIELD_NAME_REPEATED(Top, top); +#define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \ + if (filed_name == kField##kName) { \ + int field_size = reflection.FieldSize(inner_message, field_ptr); \ + for (int i = 0; i < field_size; ++i) { \ + auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \ + layer_proto->add_##method(value); \ + } \ + return SUCCESS; \ + } + CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field); + CASE_FIELD_NAME_REPEATED(Top, top, message, &field); #undef CASE_FIELD_NAME_REPEATED if (filed_name == kFieldBlobs) { - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection.FieldSize(message, &field); for (int i = 0; i < field_size; ++i) { domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); - const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); - if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { - GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); + const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i); + if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) { + GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str()); return FAILED; } } return SUCCESS; } if (filed_name == kFieldConvParam) { - const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); + const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); - ConvertConvParamProto(&sub_message, conv_param); + ConvertConvParamProto(sub_message, conv_param); } if (filed_name == kFieldInnerPro) { - const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); + const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); - ConvertInnerProdcutProto(&sub_message, inner_product); + ConvertInnerProdcutProto(sub_message, inner_product); } return SUCCESS; } -Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message, google::protobuf::Message *blobs) const { - const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); + const google::protobuf::Reflection *blobs_reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - blobs_reflection->ListFields(*message, &field_desc); + blobs_reflection->ListFields(message, &field_desc); domi::caffe::BlobProto *blobs_proto = PtrToPtr(blobs); for (auto &field : field_desc) { GE_CHECK_NOTNULL(field); string feild_name = field->name(); -#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \ - if (feild_name == #kName) { \ - int field_size = blobs_reflection->FieldSize(*message, field); \ - for (int i = 0; i < field_size; ++i) { \ - valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \ - blobs_proto->add_##name(value); \ - } \ - continue; \ - } - CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data); - CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff); - CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data); - CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff); - CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data); - CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data); +#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \ + if (feild_name == #kName) { \ + int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \ + for (int i = 0; i < field_size; ++i) { \ + valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \ + blobs_proto->add_##name(value); \ + } \ + continue; \ + } + CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field); + CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field); #undef CASE_BLOBS_FIELD_NAME_REPEATED -#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \ - if (feild_name == #kName) { \ - valuetype value = blobs_reflection->Get##method(*message, field); \ - blobs_proto->set_##name(value); \ - continue; \ - } - CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data); - CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num); - CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels); - CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height); - CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width); +#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \ + if (feild_name == #kName) { \ + valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \ + blobs_proto->set_##name(value); \ + continue; \ + } + CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field); + CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field); + CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field); + CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field); + CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field); #undef CASE_BLOBS_FIELD_NAME if (feild_name == kFieldShape) { - const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field); + const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field); domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); - ConvertBlobShapeProto(&sub_message, blob_shape); + ConvertBlobShapeProto(sub_message, blob_shape); } } return SUCCESS; } -Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); domi::caffe::BlobShape *shape_proto = PtrToPtr(dest_message); @@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message if (field->name() != kFieldDim) { continue; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); for (int i = 0; i < field_size; ++i) { - int64_t value = reflection->GetRepeatedInt64(*message, field, i); + int64_t value = reflection->GetRepeatedInt64(message, field, i); shape_proto->add_dim(value); } } return SUCCESS; } -Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); domi::caffe::ConvolutionParameter *conv_param_proto = PtrToPtr(dest_message); @@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message if (field->name() != kFieldBiasTerm) { continue; } - bool value = reflection->GetBool(*message, field); + bool value = reflection->GetBool(message, field); conv_param_proto->set_bias_term(value); } return SUCCESS; } -Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message, +Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); domi::caffe::InnerProductParameter *inner_product_proto = PtrToPtr(dest_message); @@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess if (field->name() != kFieldBiasTerm) { continue; } - bool value = reflection->GetBool(*message, field); + bool value = reflection->GetBool(message, field); inner_product_proto->set_bias_term(value); } return SUCCESS; } -Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const { - const google::protobuf::Reflection *reflection = message->GetReflection(); +Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const { + const google::protobuf::Reflection *reflection = message.GetReflection(); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); vector field_desc; - reflection->ListFields(*message, &field_desc); + reflection->ListFields(message, &field_desc); int num_layer = 0; int num_layers = 0; @@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess return FAILED; } - int field_size = reflection->FieldSize(*message, field); + int field_size = reflection->FieldSize(message, field); if (field->name() == kLayerName) { num_layer = field_size; } else { diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h index 542aca6..08fed80 100644 --- a/parser/caffe/caffe_parser.h +++ b/parser/caffe/caffe_parser.h @@ -212,8 +212,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { * @return SUCCESS parse layer successfully * @return FAILED parse layer failed */ - Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, std::vector &operators) const; + Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, std::vector &operators) const; /* * @ingroup domi_omg @@ -386,33 +386,33 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, const string &fusion_proto_name, ge::ComputeGraphPtr &graph); - Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, - const google::protobuf::Message *message, + Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, + const google::protobuf::Message &message, ge::ComputeGraphPtr &graph); Status ConvertLayerParameter(const google::protobuf::Message *layer_message, ge::ComputeGraphPtr &graph); - Status CheckLayersSize(const google::protobuf::Message *message) const; + Status CheckLayersSize(const google::protobuf::Message &message) const; - Status ConvertLayerProto(const google::protobuf::Message *message, + Status ConvertLayerProto(const google::protobuf::Message &message, google::protobuf::Message *layer); - Status ParseLayerField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, + Status ParseLayerField(const google::protobuf::Reflection &reflection, + const google::protobuf::Message &message, + const google::protobuf::FieldDescriptor &field, google::protobuf::Message *layer) const; - Status ConvertBlobsProto(const google::protobuf::Message *message, + Status ConvertBlobsProto(const google::protobuf::Message &message, google::protobuf::Message *blobs) const; - Status ConvertBlobShapeProto(const google::protobuf::Message *message, + Status ConvertBlobShapeProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const; - Status ConvertInnerProdcutProto(const google::protobuf::Message *message, + Status ConvertInnerProdcutProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const; - Status ConvertConvParamProto(const google::protobuf::Message *message, + Status ConvertConvParamProto(const google::protobuf::Message &message, google::protobuf::Message *dest_message) const; /** * @ingroup domi_omg diff --git a/parser/onnx/onnx_constant_parser.h b/parser/onnx/onnx_constant_parser.h index 0178787..628e832 100644 --- a/parser/onnx/onnx_constant_parser.h +++ b/parser/onnx/onnx_constant_parser.h @@ -69,16 +69,14 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { break; } #define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ - case dt_type: \ - { \ + case dt_type: { \ unique_ptr addr_trans(new(std::nothrow) value_type[count]()); \ GE_CHECK_NOTNULL(addr_trans); \ for (int32_t i = 0; i < (count); i++) { \ *(addr_trans.get() + i) = static_cast(*((addr).get() + i)); \ } \ (tensor).SetData(reinterpret_cast(addr_trans.get()), (count) * sizeof(value_type)); \ - break; \ - } \ + break; } \ CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) @@ -89,7 +87,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { #undef CASE_SET_DATA default: { - tensor.SetData(reinterpret_cast(addr.get()), count * sizeof(T)); + tensor.SetData(PtrToPtr(addr.get()), count * sizeof(T)); break; } } diff --git a/parser/onnx/onnx_file_constant_parser.cc b/parser/onnx/onnx_file_constant_parser.cc index f12539f..1b6717d 100644 --- a/parser/onnx/onnx_file_constant_parser.cc +++ b/parser/onnx/onnx_file_constant_parser.cc @@ -31,11 +31,11 @@ using GeTensorDesc = ge::GeTensorDesc; using namespace ge::parser; namespace { -const std::string kAttrShape = "shape"; -const std::string kAttrDataType = "dtype"; -const std::string kFileConstantPath = "file_constant_path"; -const std::string kLocation = "location"; -const std::string kOffset = "offset"; +const char *const kAttrShape = "shape"; +const char *const kAttrDataType = "dtype"; +const char *const kFileConstantPath = "file_constant_path"; +const char *const kLocation = "location"; +const char *const kOffset = "offset"; const int64_t kOffsetCoefficient = 4096; const char *const kFileConstant = "FileConstant"; } @@ -46,7 +46,7 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); ge::onnx::TensorProto tensor_proto; - if (GetTensorProto(node, tensor_proto) != SUCCESS) { + if (GetTensorProto(*node, tensor_proto) != SUCCESS) { REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); return FAILED; @@ -65,29 +65,29 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & return SUCCESS; } -Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto, - ge::onnx::TensorProto &tensor_proto) { - for (const auto &it : node_proto->attribute()) { +Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto, + ge::onnx::TensorProto &tensor_proto) const { + for (const auto &it : node_proto.attribute()) { if (it.name() != ge::kAttrNameValue) { continue; } tensor_proto = it.t(); return SUCCESS; } - REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str()); - GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str()); + REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str()); + GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str()); return FAILED; } -void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { +void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { std::vector tmp_shape; for (int i = 0; i < tensor_proto.dims_size(); i++) { tmp_shape.push_back(tensor_proto.dims(i)); } - op_def.SetAttr(kAttrShape.c_str(), tmp_shape); + op_def.SetAttr(kAttrShape, tmp_shape); } -Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { +Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { int64_t data_type = tensor_proto.data_type(); ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); if (type >= ge::DataType::DT_UNDEFINED) { @@ -96,11 +96,11 @@ Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor return FAILED; } - op_def.SetAttr(kAttrDataType.c_str(), type); + op_def.SetAttr(kAttrDataType, type); return SUCCESS; } -Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { +Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { ge::NamedAttrs attrs; for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); @@ -116,12 +116,12 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); return FAILED; } - op_def.SetAttr(kFileConstantPath.c_str(), attrs); + op_def.SetAttr(kFileConstantPath, attrs); return SUCCESS; } Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, - ge::NamedAttrs &attrs) { + ge::NamedAttrs &attrs) const { if (string_proto.key() == kLocation) { AttrUtils::SetStr(attrs, kLocation, string_proto.value()); } else { @@ -134,7 +134,7 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt return FAILED; } if (string_proto.key() == kOffset) { - if (std::numeric_limits::max() / kOffsetCoefficient < value) { + if (value > (std::numeric_limits::max() / kOffsetCoefficient)) { REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); return FAILED; diff --git a/parser/onnx/onnx_file_constant_parser.h b/parser/onnx/onnx_file_constant_parser.h index e46ce0f..0c523c5 100644 --- a/parser/onnx/onnx_file_constant_parser.h +++ b/parser/onnx/onnx_file_constant_parser.h @@ -26,11 +26,11 @@ class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { Status ParseParams(const Message *op_src, ge::Operator &op_def) override; private: - Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); - Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); - void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); - Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto); - Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs); + Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; + Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; + void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; + Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const; + Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const; }; } // namespace ge diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 75472a2..a3ba4ca 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -232,7 +232,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); if (post_func == nullptr) { GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); - if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { + if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || + parse_func_v2 == nullptr) { GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); return SUCCESS; } @@ -522,9 +523,9 @@ Status OnnxModelParser::SetOperatorInputs() { auto src_op = output_op_iter->second; int dst_index = input_node_index.second; int src_index = out_node_index.second; - GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index, - ParserUtils::GetOperatorName(src_op).c_str(), dst_index, - ParserUtils::GetOperatorName(dst_op).c_str()); + GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", + src_index, ParserUtils::GetOperatorName(src_op).c_str(), + dst_index, ParserUtils::GetOperatorName(dst_op).c_str()); auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); GE_CHECK_NOTNULL(dst_op_desc); auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); @@ -689,7 +690,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve return PARAM_INVALID; } input_ops.emplace_back(in_op->second); - GELOGI("Model assigned input node name: %s", ParserUtils::GetOperatorName(in_op->second).c_str()); + GELOGI("Model assigned input node name: %s", + ParserUtils::GetOperatorName(in_op->second).c_str()); } return SUCCESS; } @@ -717,7 +719,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vectorsecond, vector{static_cast(index)}); out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); - GELOGI("out node index %d, node:%s", index, node_name.c_str()); + GELOGI("Out node index %d, node:%s", index, node_name.c_str()); } } return SUCCESS; @@ -934,16 +936,13 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model cur_compute_graph->GetName().c_str()); return ret; } - } UpdateDataFormat(root_graph); return SUCCESS; } Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { - ClearMembers(); - GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), "Run ProtoType Pass Failed"); // 1. Get all inializer. @@ -1174,7 +1173,8 @@ Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_o default_out_nodes.emplace_back(output_node_info); output_tensor_names.emplace_back(tensor_name); GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", - output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str()); + output_node_info.first.c_str(), + output_node_info.second, tensor_name.c_str()); } return SUCCESS; } diff --git a/parser/onnx/onnx_util.h b/parser/onnx/onnx_util.h index 52fc2b3..cdb61b6 100644 --- a/parser/onnx/onnx_util.h +++ b/parser/onnx/onnx_util.h @@ -17,6 +17,8 @@ #ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ #define PARSER_ONNX_ONNX_UTIL_PARSER_H_ +#include +#include #include "external/graph/types.h" namespace OnnxDataType { @@ -59,4 +61,4 @@ class OnnxUtil { }; } // namespace ge -#endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_ +#endif // PARSER_ONNX_ONNX_UTIL_PARSER_H_ diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc index 3b3eaf1..bcf01d8 100644 --- a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ using parser::IF; namespace { const std::map kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; const int kIfNodeAttrSize = 2; +const char *kIf = "If"; } // namespace domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, @@ -33,7 +34,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), parent_node->op_type().c_str()); - auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); + auto ret = ParseIfNodeSubgraphs(*parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); if (ret != SUCCESS) { GELOGE(ret, "[Parse][Node] Parse if node failed."); REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); @@ -44,19 +45,19 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( } domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( - ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, + ge::onnx::NodeProto &parent_node, std::vector &onnx_graphs, std::map &name_to_onnx_graph, const std::string &parent_graph_name) const { - if (parent_node->attribute_size() != kIfNodeAttrSize) { - GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); - REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); + if (parent_node.attribute_size() != kIfNodeAttrSize) { + GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); + REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); return FAILED; } - GELOGD("node attribute size:%d.", parent_node->attribute_size()); + GELOGD("node attribute size:%d.", parent_node.attribute_size()); std::set all_inputs; // for onnx graph, the first attribute may be else branch and the second attribute may be then branch - for (int i = 0; i < parent_node->attribute_size(); i++) { - ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); + for (int i = 0; i < parent_node.attribute_size(); i++) { + ge::onnx::AttributeProto *attribute = parent_node.mutable_attribute(i); GE_CHECK_NOTNULL(attribute); std::string attr_name = attribute->name(); auto itr = kAttrNameToIndex.find(attr_name); @@ -68,7 +69,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( return FAILED; } std::string unique_subgraph_name; - std::string node_name = parent_node->name(); + std::string node_name = parent_node.name(); if (!parent_graph_name.empty()) { node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); } @@ -90,7 +91,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( AddInputNodeForGraph(all_inputs, *onnx_graph); } - AddInputForParentNode(all_inputs, *parent_node); + AddInputForParentNode(all_inputs, parent_node); return SUCCESS; } @@ -135,5 +136,5 @@ void IfSubgraphAdapter::AddInputForParentNode(const std::set &all_i parent_node.add_input(input_name); } } -REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter); +REGISTER_SUBGRAPH_ADAPTER_CREATOR(kIf, IfSubgraphAdapter); } // namespace ge diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.h b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h index eb6f492..936b9fd 100644 --- a/parser/onnx/subgraph_adapter/if_subgraph_adapter.h +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include #include #include "subgraph_adapter.h" +#include "parser/onnx/onnx_util.h" namespace ge { class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { @@ -30,7 +31,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { const std::string &parent_graph_name = "") override; private: - domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, + domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto &parent_node, std::vector &onnx_graphs, std::map &name_to_onnx_graph, const std::string &parent_graph_name) const; domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set &all_inputs) const; diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter.h b/parser/onnx/subgraph_adapter/subgraph_adapter.h index ad9eb1a..be84ee4 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter.h +++ b/parser/onnx/subgraph_adapter/subgraph_adapter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,8 +36,6 @@ #include "proto/onnx/ge_onnx.pb.h" #include "external/register/register_error_codes.h" #include "framework/omg/parser/parser_types.h" -#include "parser/onnx/onnx_util.h" - namespace ge { class PARSER_FUNC_VISIBILITY SubgraphAdapter { public: diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc index cb45db0..b8bf478 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc +++ b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h index 26f18c3..debe623 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h +++ b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,7 +62,6 @@ protected: * @brief SubgraphAdapter creation function * @return Created SubgraphAdapter */ - // typedef shared_ptr (*CREATOR_FUN)(void); using CREATOR_FUN = std::function(void)>; /** @@ -105,7 +104,7 @@ public: * @param [in] op_type Op type * @param [in] clazz SubgraphAdapter implementation class */ -#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ +#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ std::shared_ptr Creator_##op_type##_Subgraph_Adapter() { \ std::shared_ptr ptr(new (std::nothrow) clazz()); \ if (ptr == nullptr) { \ diff --git a/parser/stub/gen_stubapi.py b/parser/stub/gen_stubapi.py index 6aa1d18..95a02ee 100644 --- a/parser/stub/gen_stubapi.py +++ b/parser/stub/gen_stubapi.py @@ -167,6 +167,33 @@ class H2CC(object): del self.stack_template del self.func_list_exist + @staticmethod + def implement_function(func): + function_def = '' + function_def += '{\n' + + all_items = func.split() + start = 0 + return_type = all_items[start] + if return_type == "const": + start += 1 + return_type = all_items[start] + if return_type.startswith(('std::map', 'std::set', 'std::vector')): + return_type = "std::map" + if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): + return_type = "Ptr" + if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): + return_type += "&" + if RETURN_STATEMENTS.__contains__(return_type): + function_def += RETURN_STATEMENTS[return_type] + else: + logging.warning("Unhandled return type[%s]", return_type) + + function_def += '\n' + function_def += '}\n' + function_def += '\n' + return function_def + def just_skip(self): # skip blank line or comment if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( @@ -263,6 +290,7 @@ class H2CC(object): logging.info('Added %s functions', len(self.func_list_exist)) logging.info('Successfully converted,please see ' + self.output_file) + def handle_func1(self, line): """ :param line: @@ -461,12 +489,6 @@ class H2CC(object): logging.info("func_name[%s]", func_name) return line, func_name - def write_func_content(self, content, func_name, need_generate): - if not (func_name in self.func_list_exist) and need_generate: - self.output_fd.write(content) - self.func_list_exist.append(func_name) - logging.info('add func:[%s]', func_name) - def gen_comment(self, start_i): comment_line = '' # Function comments are on top of function declarations, copy them over @@ -488,32 +510,11 @@ class H2CC(object): break return comment_line - @staticmethod - def implement_function(func): - function_def = '' - function_def += '{\n' - - all_items = func.split() - start = 0 - return_type = all_items[start] - if return_type == "const": - start += 1 - return_type = all_items[start] - if return_type.startswith(('std::map', 'std::set', 'std::vector')): - return_type = "std::map" - if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): - return_type = "Ptr" - if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): - return_type += "&" - if RETURN_STATEMENTS.__contains__(return_type): - function_def += RETURN_STATEMENTS[return_type] - else: - logging.warning("Unhandled return type[%s]", return_type) - - function_def += '\n' - function_def += '}\n' - function_def += '\n' - return function_def + def write_func_content(self, content, func_name, need_generate): + if not (func_name in self.func_list_exist) and need_generate: + self.output_fd.write(content) + self.func_list_exist.append(func_name) + logging.info('add func:[%s]', func_name) def collect_header_files(path): diff --git a/tests/st/testcase/test_caffe_parser.cc b/tests/st/testcase/test_caffe_parser.cc index 42c9c8d..45fc70c 100644 --- a/tests/st/testcase/test_caffe_parser.cc +++ b/tests/st/testcase/test_caffe_parser.cc @@ -765,7 +765,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) layer->set_name("Abs"); layer->set_type("AbsVal"); - Status ret = weightParser.CheckLayersSize(layer); + Status ret = weightParser.CheckLayersSize(*layer); EXPECT_EQ(ret, FAILED); } @@ -809,8 +809,54 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) const google::protobuf::Message *proto = factory.GetPrototype(descriptor); const google::protobuf::Message *message = proto->New(); - Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); + Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); delete message; EXPECT_EQ(ret, SUCCESS); } + +TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_DimSize_0) +{ + CaffeModelParser modelParser; + domi::caffe::NetParameter net; + net.add_input("111"); + net.add_input_shape(); + bool input_data_flag = true; + Status ret = modelParser.ParseInput(net, input_data_flag); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_Err1) +{ + CaffeModelParser modelParser; + domi::caffe::NetParameter net; + net.add_input("111"); + net.add_input("222"); + net.add_input_shape(); + bool input_data_flag = true; + Status ret = modelParser.ParseInput(net, input_data_flag); + EXPECT_EQ(ret, FAILED); +} + +TEST_F(STestCaffeParser, CaffeModelParser_ParserLayerParameter_Succ) +{ + CaffeModelParser modelParser; + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/"; + const char *model_path = model_file.c_str(); + + std::string custom_proto = model_file; + std::string caffe_proto = model_file; + std::vector operators; + ge::OpDescPtr op_desc_src = std::make_shared("Data", "Input"); + ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); + operators.emplace_back(op_src); + + model_file = case_dir + "/origin_models/caffe_add.pbtxt"; + custom_proto = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; + model_path = model_file.c_str(); + std::string caffe_proto_path = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; + auto ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto_path, operators); + EXPECT_EQ(ret, SUCCESS); +} } // namespace ge diff --git a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc index 55fc538..5d1a5bd 100755 --- a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc +++ b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc @@ -739,6 +739,9 @@ TEST_F(UtestCaffeParser, CaffeModelParser_CustomProtoParse_test) Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators); EXPECT_EQ(ret, PARAM_INVALID); + ret = modelParser.CustomProtoParse("", custom_proto, caffe_proto, operators); + EXPECT_EQ(ret, FAILED); + model_file = case_dir + "/caffe_model/caffe_add.pbtxt"; custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto"; model_path = model_file.c_str(); @@ -890,7 +893,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) layer->set_name("Abs"); layer->set_type("AbsVal"); - Status ret = weightParser.CheckLayersSize(layer); + Status ret = weightParser.CheckLayersSize(*layer); EXPECT_EQ(ret, FAILED); } @@ -902,7 +905,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) layer->set_name("Abs"); layer->set_type("AbsVal"); - Status ret = weightParser.ConvertLayerProto(&net, &net); + Status ret = weightParser.ConvertLayerProto(net, &net); EXPECT_EQ(ret, SUCCESS); BlobProto* blob = layer->add_blobs(); @@ -911,16 +914,16 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) BlobShape* shap = blob->mutable_shape(); shap->add_dim(1); shap->add_dim(2); - ret = weightParser.ConvertBlobsProto(&net, &net); + ret = weightParser.ConvertBlobsProto(net, &net); EXPECT_EQ(ret, SUCCESS); - ret = weightParser.ConvertBlobShapeProto(&net, &net); + ret = weightParser.ConvertBlobShapeProto(net, &net); EXPECT_EQ(ret, SUCCESS); - ret = weightParser.ConvertConvParamProto(&net, &net); + ret = weightParser.ConvertConvParamProto(net, &net); EXPECT_EQ(ret, SUCCESS); - ret = weightParser.ConvertInnerProdcutProto(&net, &net); + ret = weightParser.ConvertInnerProdcutProto(net, &net); EXPECT_EQ(ret, SUCCESS); } @@ -1133,7 +1136,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) const google::protobuf::Message *proto = factory.GetPrototype(descriptor); const google::protobuf::Message *message = proto->New(); - Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); + Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); delete message; EXPECT_EQ(ret, SUCCESS); } @@ -1163,7 +1166,7 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test) google::protobuf::DynamicMessageFactory factory; const google::protobuf::Message *proto = factory.GetPrototype(descriptor); google::protobuf::Message *message = proto->New(); - Status ret = modelParser.ParseLayerParameter(descriptor, message, operators); + Status ret = modelParser.ParseLayerParameter(*descriptor, *message, operators); EXPECT_EQ(ret, SUCCESS); const domi::FrameworkType fmk_type = domi::TENSORFLOW; diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index 6fdee9e..c945d69 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -381,7 +381,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) OnnxFileConstantParser parser; ge::onnx::NodeProto input_node; ge::onnx::TensorProto tensor_proto; - Status ret = parser.GetTensorProto(&input_node, tensor_proto); + Status ret = parser.GetTensorProto(input_node, tensor_proto); EXPECT_EQ(ret, FAILED); ge::onnx::AttributeProto *attribute = input_node.add_attribute(); @@ -391,7 +391,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); *attribute_tensor = tensor_proto; - ret = parser.GetTensorProto(&input_node, tensor_proto); + ret = parser.GetTensorProto(input_node, tensor_proto); EXPECT_EQ(ret, SUCCESS); }