Browse Source

!647 clean code patch2

Merge pull request !647 from xujiuxu/ge_dev
pull/650/head
xujiuxu i-robot 2 years ago
parent
commit
62350d3f15
16 changed files with 361 additions and 316 deletions
  1. +194
    -198
      parser/caffe/caffe_parser.cc
  2. +13
    -13
      parser/caffe/caffe_parser.h
  3. +3
    -5
      parser/onnx/onnx_constant_parser.h
  4. +19
    -19
      parser/onnx/onnx_file_constant_parser.cc
  5. +5
    -5
      parser/onnx/onnx_file_constant_parser.h
  6. +10
    -10
      parser/onnx/onnx_parser.cc
  7. +3
    -1
      parser/onnx/onnx_util.h
  8. +13
    -12
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc
  9. +3
    -2
      parser/onnx/subgraph_adapter/if_subgraph_adapter.h
  10. +1
    -3
      parser/onnx/subgraph_adapter/subgraph_adapter.h
  11. +1
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc
  12. +2
    -3
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.h
  13. +33
    -32
      parser/stub/gen_stubapi.py
  14. +48
    -2
      tests/st/testcase/test_caffe_parser.cc
  15. +11
    -8
      tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc
  16. +2
    -2
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 194
- 198
parser/caffe/caffe_parser.cc View File

@@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param";
const char *const kFieldDim = "dim";
const char *const kFieldBiasTerm = "bias_term";
const char *const kDevNull = "/dev/null";
const std::string kMessage = "message";
const std::string kLayerParameter = "LayerParameter";
const std::string kCloseBrace = "}";
const std::string kOptional = "optional";
const std::string kRepeated = "repeated";
const std::string kRequired = "required";
const std::string kCustom = "custom";
const std::string kBuiltin = "built-in";
const char *const kCustom = "custom";
const char *const kBuiltin = "built-in";
std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT,
ge::parser::NETOUTPUT};
const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"};
@@ -284,104 +278,104 @@ const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW
"Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"};

Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const {
if (proto_message.input_size() > 0) {
GELOGI("This net exsit input.");
if (proto_message.input_size() <= 0) {
return SUCCESS;
}
GELOGI("This net exsit input.");
if (proto_message.input_dim_size() > 0) {
if (proto_message.input_shape_size() > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11001");
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!");
return FAILED;
}

if (proto_message.input_dim_size() > 0) {
if (proto_message.input_shape_size() > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11001");
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!");
return FAILED;
}
const int32_t input_dim_size = proto_message.input_dim_size();
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) ||
((input_dim_size % proto_message.input_size()) != 0));
if (is_input_invalid) {
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"},
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].",
input_dim_size, proto_message.input_size());
return FAILED;
}

const int32_t input_dim_size = proto_message.input_dim_size();
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) ||
((input_dim_size % proto_message.input_size()) != 0));
if (is_input_invalid) {
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"},
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].",
input_dim_size, proto_message.input_size());
return FAILED;
}
for (int i = 0; i < proto_message.input_size(); i++) {
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));

for (int i = 0; i < proto_message.input_size(); i++) {
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));

domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);

for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE)));
}
input_data_flag = true;
}
} else if (proto_message.input_shape_size() > 0) {
if (proto_message.input_shape_size() != proto_message.input_size()) {
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"},
{std::to_string(proto_message.input_shape_size()),
std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).",
proto_message.input_shape_size(), proto_message.input_size());
return FAILED;
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);

for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE)));
}
input_data_flag = true;
}
} else if (proto_message.input_shape_size() > 0) {
if (proto_message.input_shape_size() != proto_message.input_size()) {
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"},
{std::to_string(proto_message.input_shape_size()),
std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).",
proto_message.input_shape_size(), proto_message.input_size());
return FAILED;
}

for (int i = 0; i < proto_message.input_size(); i++) {
int dim_size = proto_message.input_shape(i).dim_size();
for (int i = 0; i < proto_message.input_size(); i++) {
int dim_size = proto_message.input_shape(i).dim_size();

domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));

domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);

for (int j = 0; j < dim_size; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j)));
}
input_data_flag = true;
for (int j = 0; j < dim_size; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j)));
}
} else {
const ge::ParserContext &ctx = ge::GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
for (int i = 0; i < proto_message.input_size(); i++) {
string name = proto_message.input(i);
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name}));
GELOGE(FAILED, "[Find][Dim]Model has no input shape.");
return FAILED;
}
std::vector<int64_t> dims = input_dims.at(name);
size_t dim_size = dims.size();
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(name);
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);
for (size_t j = 0; j < dim_size; j++) {
shape->add_dim(dims.at(j));
}
input_data_flag = true;
input_data_flag = true;
}
} else {
const ge::ParserContext &ctx = ge::GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
for (int i = 0; i < proto_message.input_size(); i++) {
string name = proto_message.input(i);
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name}));
GELOGE(FAILED, "[Find][Dim]Model has no input shape.");
return FAILED;
}
std::vector<int64_t> dims = input_dims.at(name);
size_t dim_size = dims.size();
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(name);
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);
for (size_t j = 0; j < dim_size; j++) {
shape->add_dim(dims.at(j));
}
input_data_flag = true;
}
}
return SUCCESS;
@@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons
return FAILED;
}

if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) {
if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) {
delete message;
GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path);
return FAILED;
@@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google::
return SUCCESS;
}

Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message,
vector<ge::Operator> &operators) const {
auto field_name = layer_descriptor->FindFieldByName(kFieldName);
auto field_name = layer_descriptor.FindFieldByName(kFieldName);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor");
auto field_type = layer_descriptor->FindFieldByName(kFieldType);
auto field_type = layer_descriptor.FindFieldByName(kFieldType);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor");

const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);
for (auto &field : field_desc) {
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message");
// Only care about layers
@@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor
return FAILED;
}

int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
GELOGI("Total Layer num of model file is %d", field_size);
for (int i = 0; i < field_size; ++i) {
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i);
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i);
const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message");
GE_CHECK_NOTNULL(layer_reflection);
@@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
layer_name_map[layer.name()]++;
// Set the name in proto and layer
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index);
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);)
duplicate_name_layer->set_name(new_name);
layer.set_name(new_name);)

// Insert the new operator name, the number of times of duplicate name is recorded as 1
layer_name_map.insert(std::make_pair(layer.name(), kNumOne));
@@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
layer_name_map[layer.name()]++;
// Set the name in proto and layer
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index);
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);)
duplicate_name_layer->set_name(new_name);
layer.set_name(new_name);)

// Insert the new operator name, the number of times of duplicate name is recorded as 1
layer_name_map.insert(std::make_pair(layer.name(), kNumOne));
@@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con
return FAILED;
}

if (CheckLayersSize(message) != SUCCESS) {
if (CheckLayersSize(*message) != SUCCESS) {
delete message;
message = nullptr;
return FAILED;
}

if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) {
if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) {
delete message;
message = nullptr;
REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path);
@@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con
return SUCCESS;
}

Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message,
ge::ComputeGraphPtr &graph) {
auto field_name = layer_descriptor->FindFieldByName(kFieldName);
auto field_name = layer_descriptor.FindFieldByName(kFieldName);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor");
auto field_type = layer_descriptor->FindFieldByName(kFieldType);
auto field_type = layer_descriptor.FindFieldByName(kFieldType);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor");

const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);

NetParameter tmp_net;
for (auto &field : field_desc) {
@@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto
return FAILED;
}

int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
GELOGI("Total Layer num of model file is %d", field_size);
for (int i = 0; i < field_size; ++i) {
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i);
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i);

LayerParameter *layer = tmp_net.add_layer();
if (ConvertLayerProto(&layer_message, layer) != SUCCESS) {
if (ConvertLayerProto(layer_message, layer) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed.");
return FAILED;
}
@@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto
return SUCCESS;
}

Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message,
google::protobuf::Message *layer) {
const google::protobuf::Reflection *layer_reflection = message->GetReflection();
const google::protobuf::Reflection *layer_reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
layer_reflection->ListFields(*message, &field_desc);
layer_reflection->ListFields(message, &field_desc);

for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) {
if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str());
return FAILED;
}
@@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me
return SUCCESS;
}

Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection,
const google::protobuf::Message &message,
const google::protobuf::FieldDescriptor &field,
google::protobuf::Message *layer) const {
GELOGD("Start to parse field: %s.", field->name().c_str());
GELOGD("Start to parse field: %s.", field.name().c_str());
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer);
string filed_name = field->name();
#define CASE_FIELD_NAME(kName, method) \
string filed_name = field.name();
#define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \
if (filed_name == kField##kName) { \
string value = reflection->GetString(*message, field); \
string value = reflection.GetString(inner_message, field_ptr); \
GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \
layer_proto->set_##method(value); \
return SUCCESS; \
}
CASE_FIELD_NAME(Name, name);
CASE_FIELD_NAME(Type, type);
CASE_FIELD_NAME(Name, name, message, &field);
CASE_FIELD_NAME(Type, type, message, &field);
#undef CASE_FIELD_NAME
#define CASE_FIELD_NAME_REPEATED(kName, method) \
if (filed_name == kField##kName) { \
int field_size = reflection->FieldSize(*message, field); \
for (int i = 0; i < field_size; ++i) { \
auto value = reflection->GetRepeatedString(*message, field, i); \
layer_proto->add_##method(value); \
} \
return SUCCESS; \
}
CASE_FIELD_NAME_REPEATED(Bottom, bottom);
CASE_FIELD_NAME_REPEATED(Top, top);
#define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \
if (filed_name == kField##kName) { \
int field_size = reflection.FieldSize(inner_message, field_ptr); \
for (int i = 0; i < field_size; ++i) { \
auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \
layer_proto->add_##method(value); \
} \
return SUCCESS; \
}
CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field);
CASE_FIELD_NAME_REPEATED(Top, top, message, &field);
#undef CASE_FIELD_NAME_REPEATED
if (filed_name == kFieldBlobs) {
int field_size = reflection->FieldSize(*message, field);
int field_size = reflection.FieldSize(message, &field);
for (int i = 0; i < field_size; ++i) {
domi::caffe::BlobProto *item_message = layer_proto->add_blobs();
const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i);
if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str());
const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i);
if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str());
return FAILED;
}
}
return SUCCESS;
}
if (filed_name == kFieldConvParam) {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field);
ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param();
ConvertConvParamProto(&sub_message, conv_param);
ConvertConvParamProto(sub_message, conv_param);
}
if (filed_name == kFieldInnerPro) {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field);
InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param();
ConvertInnerProdcutProto(&sub_message, inner_product);
ConvertInnerProdcutProto(sub_message, inner_product);
}
return SUCCESS;
}

Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message,
google::protobuf::Message *blobs) const {
const google::protobuf::Reflection *blobs_reflection = message->GetReflection();
const google::protobuf::Reflection *blobs_reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
blobs_reflection->ListFields(*message, &field_desc);
blobs_reflection->ListFields(message, &field_desc);
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs);

for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
string feild_name = field->name();
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \
if (feild_name == #kName) { \
int field_size = blobs_reflection->FieldSize(*message, field); \
for (int i = 0; i < field_size; ++i) { \
valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \
blobs_proto->add_##name(value); \
} \
continue; \
}
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data);
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff);
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data);
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff);
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data);
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data);
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \
if (feild_name == #kName) { \
int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \
for (int i = 0; i < field_size; ++i) { \
valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \
blobs_proto->add_##name(value); \
} \
continue; \
}
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field);
#undef CASE_BLOBS_FIELD_NAME_REPEATED
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \
if (feild_name == #kName) { \
valuetype value = blobs_reflection->Get##method(*message, field); \
blobs_proto->set_##name(value); \
continue; \
}
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data);
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num);
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels);
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height);
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width);
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \
if (feild_name == #kName) { \
valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \
blobs_proto->set_##name(value); \
continue; \
}
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field);
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field);
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field);
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field);
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field);
#undef CASE_BLOBS_FIELD_NAME
if (feild_name == kFieldShape) {
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field);
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field);
domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape();
ConvertBlobShapeProto(&sub_message, blob_shape);
ConvertBlobShapeProto(sub_message, blob_shape);
}
}
return SUCCESS;
}

Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);

domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message);

@@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message
if (field->name() != kFieldDim) {
continue;
}
int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
for (int i = 0; i < field_size; ++i) {
int64_t value = reflection->GetRepeatedInt64(*message, field, i);
int64_t value = reflection->GetRepeatedInt64(message, field, i);
shape_proto->add_dim(value);
}
}
return SUCCESS;
}

Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);

domi::caffe::ConvolutionParameter *conv_param_proto =
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message);
@@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message
if (field->name() != kFieldBiasTerm) {
continue;
}
bool value = reflection->GetBool(*message, field);
bool value = reflection->GetBool(message, field);
conv_param_proto->set_bias_term(value);
}
return SUCCESS;
}

Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);

domi::caffe::InnerProductParameter *inner_product_proto =
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message);
@@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess
if (field->name() != kFieldBiasTerm) {
continue;
}
bool value = reflection->GetBool(*message, field);
bool value = reflection->GetBool(message, field);
inner_product_proto->set_bias_term(value);
}
return SUCCESS;
}

Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const {
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);

int num_layer = 0;
int num_layers = 0;
@@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess
return FAILED;
}

int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
if (field->name() == kLayerName) {
num_layer = field_size;
} else {


+ 13
- 13
parser/caffe/caffe_parser.h View File

@@ -212,8 +212,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
* @return SUCCESS parse layer successfully
* @return FAILED parse layer failed
*/
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message, std::vector<ge::Operator> &operators) const;
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message, std::vector<ge::Operator> &operators) const;

/*
* @ingroup domi_omg
@@ -386,33 +386,33 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {
Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path,
const string &fusion_proto_name, ge::ComputeGraphPtr &graph);

Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message,
ge::ComputeGraphPtr &graph);

Status ConvertLayerParameter(const google::protobuf::Message *layer_message,
ge::ComputeGraphPtr &graph);

Status CheckLayersSize(const google::protobuf::Message *message) const;
Status CheckLayersSize(const google::protobuf::Message &message) const;

Status ConvertLayerProto(const google::protobuf::Message *message,
Status ConvertLayerProto(const google::protobuf::Message &message,
google::protobuf::Message *layer);

Status ParseLayerField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
Status ParseLayerField(const google::protobuf::Reflection &reflection,
const google::protobuf::Message &message,
const google::protobuf::FieldDescriptor &field,
google::protobuf::Message *layer) const;

Status ConvertBlobsProto(const google::protobuf::Message *message,
Status ConvertBlobsProto(const google::protobuf::Message &message,
google::protobuf::Message *blobs) const;

Status ConvertBlobShapeProto(const google::protobuf::Message *message,
Status ConvertBlobShapeProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const;

Status ConvertInnerProdcutProto(const google::protobuf::Message *message,
Status ConvertInnerProdcutProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const;

Status ConvertConvParamProto(const google::protobuf::Message *message,
Status ConvertConvParamProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const;
/**
* @ingroup domi_omg


+ 3
- 5
parser/onnx/onnx_constant_parser.h View File

@@ -69,16 +69,14 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
break;
}
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \
case dt_type: \
{ \
case dt_type: { \
unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \
GE_CHECK_NOTNULL(addr_trans); \
for (int32_t i = 0; i < (count); i++) { \
*(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \
} \
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \
break; \
} \
break; } \

CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor)
CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor)
@@ -89,7 +87,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
#undef CASE_SET_DATA
default:
{
tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(T));
tensor.SetData(PtrToPtr<T, uint8_t>(addr.get()), count * sizeof(T));
break;
}
}


+ 19
- 19
parser/onnx/onnx_file_constant_parser.cc View File

@@ -31,11 +31,11 @@ using GeTensorDesc = ge::GeTensorDesc;
using namespace ge::parser;

namespace {
const std::string kAttrShape = "shape";
const std::string kAttrDataType = "dtype";
const std::string kFileConstantPath = "file_constant_path";
const std::string kLocation = "location";
const std::string kOffset = "offset";
const char *const kAttrShape = "shape";
const char *const kAttrDataType = "dtype";
const char *const kFileConstantPath = "file_constant_path";
const char *const kLocation = "location";
const char *const kOffset = "offset";
const int64_t kOffsetCoefficient = 4096;
const char *const kFileConstant = "FileConstant";
}
@@ -46,7 +46,7 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str());

ge::onnx::TensorProto tensor_proto;
if (GetTensorProto(node, tensor_proto) != SUCCESS) {
if (GetTensorProto(*node, tensor_proto) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str());
return FAILED;
@@ -65,29 +65,29 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &
return SUCCESS;
}

Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto,
ge::onnx::TensorProto &tensor_proto) {
for (const auto &it : node_proto->attribute()) {
Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto,
ge::onnx::TensorProto &tensor_proto) const {
for (const auto &it : node_proto.attribute()) {
if (it.name() != ge::kAttrNameValue) {
continue;
}
tensor_proto = it.t();
return SUCCESS;
}
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str());
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str());
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str());
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str());
return FAILED;
}

void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const {
std::vector<int64_t> tmp_shape;
for (int i = 0; i < tensor_proto.dims_size(); i++) {
tmp_shape.push_back(tensor_proto.dims(i));
}
op_def.SetAttr(kAttrShape.c_str(), tmp_shape);
op_def.SetAttr(kAttrShape, tmp_shape);
}

Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const {
int64_t data_type = tensor_proto.data_type();
ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type);
if (type >= ge::DataType::DT_UNDEFINED) {
@@ -96,11 +96,11 @@ Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor
return FAILED;
}

op_def.SetAttr(kAttrDataType.c_str(), type);
op_def.SetAttr(kAttrDataType, type);
return SUCCESS;
}

Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const {
ge::NamedAttrs attrs;
for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) {
const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i);
@@ -116,12 +116,12 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str());
return FAILED;
}
op_def.SetAttr(kFileConstantPath.c_str(), attrs);
op_def.SetAttr(kFileConstantPath, attrs);
return SUCCESS;
}

Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto,
ge::NamedAttrs &attrs) {
ge::NamedAttrs &attrs) const {
if (string_proto.key() == kLocation) {
AttrUtils::SetStr(attrs, kLocation, string_proto.value());
} else {
@@ -134,7 +134,7 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt
return FAILED;
}
if (string_proto.key() == kOffset) {
if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) {
if (value > (std::numeric_limits<int64_t>::max() / kOffsetCoefficient)) {
REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
return FAILED;


+ 5
- 5
parser/onnx/onnx_file_constant_parser.h View File

@@ -26,11 +26,11 @@ class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser {
Status ParseParams(const Message *op_src, ge::Operator &op_def) override;

private:
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto);
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs);
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const;
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const;
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const;
Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const;
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const;
};
} // namespace ge



+ 10
- 10
parser/onnx/onnx_parser.cc View File

@@ -232,7 +232,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type);
if (post_func == nullptr) {
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str());
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) {
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS ||
parse_func_v2 == nullptr) {
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str());
return SUCCESS;
}
@@ -522,9 +523,9 @@ Status OnnxModelParser::SetOperatorInputs() {
auto src_op = output_op_iter->second;
int dst_index = input_node_index.second;
int src_index = out_node_index.second;
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index,
ParserUtils::GetOperatorName(src_op).c_str(), dst_index,
ParserUtils::GetOperatorName(dst_op).c_str());
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.",
src_index, ParserUtils::GetOperatorName(src_op).c_str(),
dst_index, ParserUtils::GetOperatorName(dst_op).c_str());
auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op);
GE_CHECK_NOTNULL(dst_op_desc);
auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op);
@@ -689,7 +690,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve
return PARAM_INVALID;
}
input_ops.emplace_back(in_op->second);
GELOGI("Model assigned input node name: %s", ParserUtils::GetOperatorName(in_op->second).c_str());
GELOGI("Model assigned input node name: %s",
ParserUtils::GetOperatorName(in_op->second).c_str());
}
return SUCCESS;
}
@@ -717,7 +719,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec
int index = node_name_index.second;
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)});
out_tensor_to_nodes[output_name] = std::make_pair(node_name, index);
GELOGI("out node index %d, node:%s", index, node_name.c_str());
GELOGI("Out node index %d, node:%s", index, node_name.c_str());
}
}
return SUCCESS;
@@ -934,16 +936,13 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
cur_compute_graph->GetName().c_str());
return ret;
}

}
UpdateDataFormat(root_graph);
return SUCCESS;
}

Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {

ClearMembers();

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX),
"Run ProtoType Pass Failed");
// 1. Get all inializer.
@@ -1174,7 +1173,8 @@ Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_o
default_out_nodes.emplace_back(output_node_info);
output_tensor_names.emplace_back(tensor_name);
GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].",
output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str());
output_node_info.first.c_str(),
output_node_info.second, tensor_name.c_str());
}
return SUCCESS;
}


+ 3
- 1
parser/onnx/onnx_util.h View File

@@ -17,6 +17,8 @@
#ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_
#define PARSER_ONNX_ONNX_UTIL_PARSER_H_

#include <string>
#include <cstdint>
#include "external/graph/types.h"

namespace OnnxDataType {
@@ -59,4 +61,4 @@ class OnnxUtil {
};
} // namespace ge

#endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_
#endif // PARSER_ONNX_ONNX_UTIL_PARSER_H_

+ 13
- 12
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ using parser::IF;
namespace {
const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}};
const int kIfNodeAttrSize = 2;
const char *kIf = "If";
} // namespace
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
@@ -33,7 +34,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(),
parent_node->op_type().c_str());

auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name);
auto ret = ParseIfNodeSubgraphs(*parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Node] Parse if node failed.");
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str());
@@ -44,19 +45,19 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
}

domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const {
if (parent_node->attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
if (parent_node.attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size());
return FAILED;
}

GELOGD("node attribute size:%d.", parent_node->attribute_size());
GELOGD("node attribute size:%d.", parent_node.attribute_size());
std::set<std::string> all_inputs;
// for onnx graph, the first attribute may be else branch and the second attribute may be then branch
for (int i = 0; i < parent_node->attribute_size(); i++) {
ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i);
for (int i = 0; i < parent_node.attribute_size(); i++) {
ge::onnx::AttributeProto *attribute = parent_node.mutable_attribute(i);
GE_CHECK_NOTNULL(attribute);
std::string attr_name = attribute->name();
auto itr = kAttrNameToIndex.find(attr_name);
@@ -68,7 +69,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
return FAILED;
}
std::string unique_subgraph_name;
std::string node_name = parent_node->name();
std::string node_name = parent_node.name();
if (!parent_graph_name.empty()) {
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name);
}
@@ -90,7 +91,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
AddInputNodeForGraph(all_inputs, *onnx_graph);
}

AddInputForParentNode(all_inputs, *parent_node);
AddInputForParentNode(all_inputs, parent_node);
return SUCCESS;
}

@@ -135,5 +136,5 @@ void IfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_i
parent_node.add_input(input_name);
}
}
REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter);
REGISTER_SUBGRAPH_ADAPTER_CREATOR(kIf, IfSubgraphAdapter);
} // namespace ge

+ 3
- 2
parser/onnx/subgraph_adapter/if_subgraph_adapter.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
#include <set>
#include <string>
#include "subgraph_adapter.h"
#include "parser/onnx/onnx_util.h"

namespace ge {
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
@@ -30,7 +31,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
const std::string &parent_graph_name = "") override;

private:
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name) const;
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const;


+ 1
- 3
parser/onnx/subgraph_adapter/subgraph_adapter.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,8 +36,6 @@
#include "proto/onnx/ge_onnx.pb.h"
#include "external/register/register_error_codes.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/onnx/onnx_util.h"

namespace ge {
class PARSER_FUNC_VISIBILITY SubgraphAdapter {
public:


+ 1
- 1
parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.


+ 2
- 3
parser/onnx/subgraph_adapter/subgraph_adapter_factory.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -62,7 +62,6 @@ protected:
* @brief SubgraphAdapter creation function
* @return Created SubgraphAdapter
*/
// typedef shared_ptr<SubgraphAdapter> (*CREATOR_FUN)(void);
using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>;

/**
@@ -105,7 +104,7 @@ public:
* @param [in] op_type Op type
* @param [in] clazz SubgraphAdapter implementation class
*/
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \
std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \
std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \
if (ptr == nullptr) { \


+ 33
- 32
parser/stub/gen_stubapi.py View File

@@ -167,6 +167,33 @@ class H2CC(object):
del self.stack_template
del self.func_list_exist

@staticmethod
def implement_function(func):
function_def = ''
function_def += '{\n'

all_items = func.split()
start = 0
return_type = all_items[start]
if return_type == "const":
start += 1
return_type = all_items[start]
if return_type.startswith(('std::map', 'std::set', 'std::vector')):
return_type = "std::map"
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
return_type = "Ptr"
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
return_type += "&"
if RETURN_STATEMENTS.__contains__(return_type):
function_def += RETURN_STATEMENTS[return_type]
else:
logging.warning("Unhandled return type[%s]", return_type)

function_def += '\n'
function_def += '}\n'
function_def += '\n'
return function_def

def just_skip(self):
# skip blank line or comment
if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
@@ -263,6 +290,7 @@ class H2CC(object):
logging.info('Added %s functions', len(self.func_list_exist))
logging.info('Successfully converted,please see ' + self.output_file)


def handle_func1(self, line):
"""
:param line:
@@ -461,12 +489,6 @@ class H2CC(object):
logging.info("func_name[%s]", func_name)
return line, func_name

def write_func_content(self, content, func_name, need_generate):
if not (func_name in self.func_list_exist) and need_generate:
self.output_fd.write(content)
self.func_list_exist.append(func_name)
logging.info('add func:[%s]', func_name)

def gen_comment(self, start_i):
comment_line = ''
# Function comments are on top of function declarations, copy them over
@@ -488,32 +510,11 @@ class H2CC(object):
break
return comment_line

@staticmethod
def implement_function(func):
function_def = ''
function_def += '{\n'

all_items = func.split()
start = 0
return_type = all_items[start]
if return_type == "const":
start += 1
return_type = all_items[start]
if return_type.startswith(('std::map', 'std::set', 'std::vector')):
return_type = "std::map"
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
return_type = "Ptr"
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
return_type += "&"
if RETURN_STATEMENTS.__contains__(return_type):
function_def += RETURN_STATEMENTS[return_type]
else:
logging.warning("Unhandled return type[%s]", return_type)

function_def += '\n'
function_def += '}\n'
function_def += '\n'
return function_def
def write_func_content(self, content, func_name, need_generate):
if not (func_name in self.func_list_exist) and need_generate:
self.output_fd.write(content)
self.func_list_exist.append(func_name)
logging.info('add func:[%s]', func_name)


def collect_header_files(path):


+ 48
- 2
tests/st/testcase/test_caffe_parser.cc View File

@@ -765,7 +765,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CheckLayersSize_test)
layer->set_name("Abs");
layer->set_type("AbsVal");

Status ret = weightParser.CheckLayersSize(layer);
Status ret = weightParser.CheckLayersSize(*layer);
EXPECT_EQ(ret, FAILED);
}

@@ -809,8 +809,54 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test)
const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
const google::protobuf::Message *message = proto->New();

Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph);
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph);
delete message;
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_DimSize_0)
{
CaffeModelParser modelParser;
domi::caffe::NetParameter net;
net.add_input("111");
net.add_input_shape();
bool input_data_flag = true;
Status ret = modelParser.ParseInput(net, input_data_flag);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_Err1)
{
CaffeModelParser modelParser;
domi::caffe::NetParameter net;
net.add_input("111");
net.add_input("222");
net.add_input_shape();
bool input_data_flag = true;
Status ret = modelParser.ParseInput(net, input_data_flag);
EXPECT_EQ(ret, FAILED);
}

TEST_F(STestCaffeParser, CaffeModelParser_ParserLayerParameter_Succ)
{
CaffeModelParser modelParser;
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/";
const char *model_path = model_file.c_str();

std::string custom_proto = model_file;
std::string caffe_proto = model_file;
std::vector<ge::Operator> operators;
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Data", "Input");
ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);
operators.emplace_back(op_src);

model_file = case_dir + "/origin_models/caffe_add.pbtxt";
custom_proto = case_dir + "/../../../metadef/proto/caffe/caffe.proto";
model_path = model_file.c_str();
std::string caffe_proto_path = case_dir + "/../../../metadef/proto/caffe/caffe.proto";
auto ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto_path, operators);
EXPECT_EQ(ret, SUCCESS);
}
} // namespace ge

+ 11
- 8
tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc View File

@@ -739,6 +739,9 @@ TEST_F(UtestCaffeParser, CaffeModelParser_CustomProtoParse_test)
Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators);
EXPECT_EQ(ret, PARAM_INVALID);

ret = modelParser.CustomProtoParse("", custom_proto, caffe_proto, operators);
EXPECT_EQ(ret, FAILED);

model_file = case_dir + "/caffe_model/caffe_add.pbtxt";
custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto";
model_path = model_file.c_str();
@@ -890,7 +893,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CheckLayersSize_test)
layer->set_name("Abs");
layer->set_type("AbsVal");

Status ret = weightParser.CheckLayersSize(layer);
Status ret = weightParser.CheckLayersSize(*layer);
EXPECT_EQ(ret, FAILED);
}

@@ -902,7 +905,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test)
layer->set_name("Abs");
layer->set_type("AbsVal");

Status ret = weightParser.ConvertLayerProto(&net, &net);
Status ret = weightParser.ConvertLayerProto(net, &net);
EXPECT_EQ(ret, SUCCESS);

BlobProto* blob = layer->add_blobs();
@@ -911,16 +914,16 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test)
BlobShape* shap = blob->mutable_shape();
shap->add_dim(1);
shap->add_dim(2);
ret = weightParser.ConvertBlobsProto(&net, &net);
ret = weightParser.ConvertBlobsProto(net, &net);
EXPECT_EQ(ret, SUCCESS);

ret = weightParser.ConvertBlobShapeProto(&net, &net);
ret = weightParser.ConvertBlobShapeProto(net, &net);
EXPECT_EQ(ret, SUCCESS);

ret = weightParser.ConvertConvParamProto(&net, &net);
ret = weightParser.ConvertConvParamProto(net, &net);
EXPECT_EQ(ret, SUCCESS);

ret = weightParser.ConvertInnerProdcutProto(&net, &net);
ret = weightParser.ConvertInnerProdcutProto(net, &net);
EXPECT_EQ(ret, SUCCESS);
}

@@ -1133,7 +1136,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test)
const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
const google::protobuf::Message *message = proto->New();

Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph);
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph);
delete message;
EXPECT_EQ(ret, SUCCESS);
}
@@ -1163,7 +1166,7 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test)
google::protobuf::DynamicMessageFactory factory;
const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
google::protobuf::Message *message = proto->New();
Status ret = modelParser.ParseLayerParameter(descriptor, message, operators);
Status ret = modelParser.ParseLayerParameter(*descriptor, *message, operators);
EXPECT_EQ(ret, SUCCESS);

const domi::FrameworkType fmk_type = domi::TENSORFLOW;


+ 2
- 2
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -381,7 +381,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto)
OnnxFileConstantParser parser;
ge::onnx::NodeProto input_node;
ge::onnx::TensorProto tensor_proto;
Status ret = parser.GetTensorProto(&input_node, tensor_proto);
Status ret = parser.GetTensorProto(input_node, tensor_proto);
EXPECT_EQ(ret, FAILED);

ge::onnx::AttributeProto *attribute = input_node.add_attribute();
@@ -391,7 +391,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto)

ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
*attribute_tensor = tensor_proto;
ret = parser.GetTensorProto(&input_node, tensor_proto);
ret = parser.GetTensorProto(input_node, tensor_proto);
EXPECT_EQ(ret, SUCCESS);
}



Loading…
Cancel
Save