Browse Source

!653 sync ge_dev to master 20220902

Merge pull request !653 from 王笑天/ge_dev
pull/656/MERGE
王涛 Gitee 2 years ago
parent
commit
9b83ed77f8
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
74 changed files with 576 additions and 450 deletions
  1. +1
    -1
      metadef
  2. +11
    -11
      parser/CMakeLists.txt
  3. +194
    -198
      parser/caffe/caffe_parser.cc
  4. +13
    -13
      parser/caffe/caffe_parser.h
  5. +45
    -0
      parser/common/acl_graph_parser_util.cc
  6. +1
    -0
      parser/common/acl_graph_parser_util.h
  7. +1
    -1
      parser/common/convert/message2operator.cc
  8. +1
    -1
      parser/common/convert/message2operator.h
  9. +4
    -6
      parser/common/convert/pb2json.cc
  10. +1
    -1
      parser/common/convert/pb2json.h
  11. +2
    -2
      parser/common/op_def/arg_op_operator.cc
  12. +1
    -1
      parser/common/op_def/arg_op_operator.h
  13. +2
    -2
      parser/common/op_def/constant_operator.cc
  14. +1
    -1
      parser/common/op_def/constant_operator.h
  15. +2
    -2
      parser/common/op_def/fill_operator.cc
  16. +1
    -1
      parser/common/op_def/fill_operator.h
  17. +2
    -2
      parser/common/op_def/framework_op_operator.cc
  18. +1
    -1
      parser/common/op_def/framework_op_operator.h
  19. +1
    -1
      parser/common/op_def/ir_pb_converter.cc
  20. +1
    -1
      parser/common/op_def/ir_pb_converter.h
  21. +2
    -2
      parser/common/op_def/no_op_operator.cc
  22. +1
    -2
      parser/common/op_def/no_op_operator.h
  23. +1
    -1
      parser/common/op_def/operator.cc
  24. +1
    -1
      parser/common/op_def/operator.h
  25. +2
    -2
      parser/common/op_def/ref_switch_operator.cc
  26. +1
    -1
      parser/common/op_def/ref_switch_operator.h
  27. +2
    -2
      parser/common/op_def/shape_n_operator.cc
  28. +1
    -1
      parser/common/op_def/shape_n_operator.h
  29. +2
    -2
      parser/common/op_def/var_is_initialized_op_operator.cc
  30. +1
    -1
      parser/common/op_def/var_is_initialized_op_operator.h
  31. +2
    -2
      parser/common/op_def/variable_operator.cc
  32. +1
    -1
      parser/common/op_def/variable_operator.h
  33. +11
    -11
      parser/module.mk
  34. +3
    -5
      parser/onnx/onnx_constant_parser.h
  35. +19
    -19
      parser/onnx/onnx_file_constant_parser.cc
  36. +5
    -5
      parser/onnx/onnx_file_constant_parser.h
  37. +10
    -10
      parser/onnx/onnx_parser.cc
  38. +3
    -1
      parser/onnx/onnx_util.h
  39. +13
    -12
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc
  40. +3
    -2
      parser/onnx/subgraph_adapter/if_subgraph_adapter.h
  41. +1
    -3
      parser/onnx/subgraph_adapter/subgraph_adapter.h
  42. +1
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc
  43. +2
    -3
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.h
  44. +33
    -32
      parser/stub/gen_stubapi.py
  45. +2
    -2
      parser/tensorflow/graph_to_function_def.cc
  46. +1
    -1
      parser/tensorflow/graph_to_function_def.h
  47. +2
    -2
      parser/tensorflow/iterator_fusion_pass.cc
  48. +1
    -1
      parser/tensorflow/iterator_fusion_pass.h
  49. +6
    -3
      parser/tensorflow/parser_graph_optimizer.cc
  50. +1
    -1
      parser/tensorflow/parser_graph_optimizer.h
  51. +1
    -1
      parser/tensorflow/scope/scope_pass_manager.cc
  52. +1
    -1
      parser/tensorflow/scope/scope_pass_manager.h
  53. +2
    -2
      parser/tensorflow/tensorflow_arg_parser.cc
  54. +1
    -1
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
  55. +1
    -1
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h
  56. +1
    -1
      parser/tensorflow/tensorflow_constant_parser.cc
  57. +1
    -1
      parser/tensorflow/tensorflow_constant_parser.h
  58. +1
    -1
      parser/tensorflow/tensorflow_fill_parser.cc
  59. +1
    -1
      parser/tensorflow/tensorflow_frameworkop_parser.cc
  60. +1
    -1
      parser/tensorflow/tensorflow_no_op_parser.cc
  61. +1
    -1
      parser/tensorflow/tensorflow_ref_switch_parser.cc
  62. +1
    -1
      parser/tensorflow/tensorflow_ref_switch_parser.h
  63. +1
    -1
      parser/tensorflow/tensorflow_shape_n_parser.cc
  64. +1
    -1
      parser/tensorflow/tensorflow_shape_n_parser.h
  65. +1
    -1
      parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc
  66. +1
    -1
      parser/tensorflow/tensorflow_variable_v2_parser.cc
  67. +11
    -11
      tests/st/CMakeLists.txt
  68. +48
    -2
      tests/st/testcase/test_caffe_parser.cc
  69. +28
    -10
      tests/st/testcase/test_tensorflow_parser.cc
  70. +11
    -11
      tests/ut/parser/CMakeLists.txt
  71. +11
    -8
      tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc
  72. +1
    -1
      tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc
  73. +2
    -2
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc
  74. +27
    -10
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 5d062a35640733026457c91966a558769570b0f8
Subproject commit f5c1b6d1b6b6e97d0cfcf2efd52ec8da12d32c86

+ 11
- 11
parser/CMakeLists.txt View File

@@ -22,18 +22,18 @@ set(SRC_LIST
"caffe/caffe_custom_parser_adapter.cc" "caffe/caffe_custom_parser_adapter.cc"
"caffe/caffe_op_parser.cc" "caffe/caffe_op_parser.cc"
"tensorflow/scope/scope_pass_manager.cc" "tensorflow/scope/scope_pass_manager.cc"
"tensorflow/graph_functiondef.cc"
"tensorflow/graph_optimizer.cc"
"tensorflow/graph_to_function_def.cc"
"tensorflow/parser_graph_optimizer.cc"
"tensorflow/iterator_fusion_pass.cc" "tensorflow/iterator_fusion_pass.cc"
"common/op_def/arg_op.cc"
"common/op_def/constant_op.cc"
"common/op_def/fill_op.cc"
"common/op_def/frameworkop_op.cc"
"common/op_def/no_op_op.cc"
"common/op_def/ref_switch_op.cc"
"common/op_def/shape_n_op.cc"
"common/op_def/var_is_initialized_op_op.cc"
"common/op_def/variable_op.cc"
"common/op_def/arg_op_operator.cc"
"common/op_def/constant_operator.cc"
"common/op_def/fill_operator.cc"
"common/op_def/framework_op_operator.cc"
"common/op_def/no_op_operator.cc"
"common/op_def/ref_switch_operator.cc"
"common/op_def/shape_n_operator.cc"
"common/op_def/var_is_initialized_op_operator.cc"
"common/op_def/variable_operator.cc"
) )


############ libfmk_parser.so ############ ############ libfmk_parser.so ############


+ 194
- 198
parser/caffe/caffe_parser.cc View File

@@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param";
const char *const kFieldDim = "dim"; const char *const kFieldDim = "dim";
const char *const kFieldBiasTerm = "bias_term"; const char *const kFieldBiasTerm = "bias_term";
const char *const kDevNull = "/dev/null"; const char *const kDevNull = "/dev/null";
const std::string kMessage = "message";
const std::string kLayerParameter = "LayerParameter";
const std::string kCloseBrace = "}";
const std::string kOptional = "optional";
const std::string kRepeated = "repeated";
const std::string kRequired = "required";
const std::string kCustom = "custom";
const std::string kBuiltin = "built-in";
const char *const kCustom = "custom";
const char *const kBuiltin = "built-in";
std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT,
ge::parser::NETOUTPUT}; ge::parser::NETOUTPUT};
const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"};
@@ -284,104 +278,104 @@ const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW
"Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; "Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"};


Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const {
if (proto_message.input_size() > 0) {
GELOGI("This net exsit input.");
if (proto_message.input_size() <= 0) {
return SUCCESS;
}
GELOGI("This net exsit input.");
if (proto_message.input_dim_size() > 0) {
if (proto_message.input_shape_size() > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11001");
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!");
return FAILED;
}


if (proto_message.input_dim_size() > 0) {
if (proto_message.input_shape_size() > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11001");
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!");
return FAILED;
}
const int32_t input_dim_size = proto_message.input_dim_size();
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) ||
((input_dim_size % proto_message.input_size()) != 0));
if (is_input_invalid) {
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"},
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].",
input_dim_size, proto_message.input_size());
return FAILED;
}


const int32_t input_dim_size = proto_message.input_dim_size();
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) ||
((input_dim_size % proto_message.input_size()) != 0));
if (is_input_invalid) {
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"},
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].",
input_dim_size, proto_message.input_size());
return FAILED;
}
for (int i = 0; i < proto_message.input_size(); i++) {
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));


for (int i = 0; i < proto_message.input_size(); i++) {
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));

domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);

for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE)));
}
input_data_flag = true;
}
} else if (proto_message.input_shape_size() > 0) {
if (proto_message.input_shape_size() != proto_message.input_size()) {
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"},
{std::to_string(proto_message.input_shape_size()),
std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).",
proto_message.input_shape_size(), proto_message.input_size());
return FAILED;
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);

for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE)));
} }
input_data_flag = true;
}
} else if (proto_message.input_shape_size() > 0) {
if (proto_message.input_shape_size() != proto_message.input_size()) {
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"},
{std::to_string(proto_message.input_shape_size()),
std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).",
proto_message.input_shape_size(), proto_message.input_size());
return FAILED;
}


for (int i = 0; i < proto_message.input_size(); i++) {
int dim_size = proto_message.input_shape(i).dim_size();
for (int i = 0; i < proto_message.input_size(); i++) {
int dim_size = proto_message.input_shape(i).dim_size();


domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));


domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);


for (int j = 0; j < dim_size; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j)));
}
input_data_flag = true;
for (int j = 0; j < dim_size; j++) {
// Can guarantee that it will not cross the border
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j)));
} }
} else {
const ge::ParserContext &ctx = ge::GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
for (int i = 0; i < proto_message.input_size(); i++) {
string name = proto_message.input(i);
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name}));
GELOGE(FAILED, "[Find][Dim]Model has no input shape.");
return FAILED;
}
std::vector<int64_t> dims = input_dims.at(name);
size_t dim_size = dims.size();
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(name);
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);
for (size_t j = 0; j < dim_size; j++) {
shape->add_dim(dims.at(j));
}
input_data_flag = true;
input_data_flag = true;
}
} else {
const ge::ParserContext &ctx = ge::GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
for (int i = 0; i < proto_message.input_size(); i++) {
string name = proto_message.input(i);
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name}));
GELOGE(FAILED, "[Find][Dim]Model has no input shape.");
return FAILED;
}
std::vector<int64_t> dims = input_dims.at(name);
size_t dim_size = dims.size();
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(name);
layer->set_type(ge::parser::INPUT_TYPE);
layer->add_top(proto_message.input(i));
domi::caffe::InputParameter *input_param = layer->mutable_input_param();
GE_CHECK_NOTNULL(input_param);
domi::caffe::BlobShape *shape = input_param->add_shape();
GE_CHECK_NOTNULL(shape);
for (size_t j = 0; j < dim_size; j++) {
shape->add_dim(dims.at(j));
} }
input_data_flag = true;
} }
} }
return SUCCESS; return SUCCESS;
@@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons
return FAILED; return FAILED;
} }


if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) {
if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) {
delete message; delete message;
GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path);
return FAILED; return FAILED;
@@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google::
return SUCCESS; return SUCCESS;
} }


Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message,
vector<ge::Operator> &operators) const { vector<ge::Operator> &operators) const {
auto field_name = layer_descriptor->FindFieldByName(kFieldName);
auto field_name = layer_descriptor.FindFieldByName(kFieldName);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor");
auto field_type = layer_descriptor->FindFieldByName(kFieldType);
auto field_type = layer_descriptor.FindFieldByName(kFieldType);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor");


const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);
for (auto &field : field_desc) { for (auto &field : field_desc) {
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message");
// Only care about layers // Only care about layers
@@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor
return FAILED; return FAILED;
} }


int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
GELOGI("Total Layer num of model file is %d", field_size); GELOGI("Total Layer num of model file is %d", field_size);
for (int i = 0; i < field_size; ++i) { for (int i = 0; i < field_size; ++i) {
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i);
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i);
const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message");
GE_CHECK_NOTNULL(layer_reflection); GE_CHECK_NOTNULL(layer_reflection);
@@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
layer_name_map[layer.name()]++; layer_name_map[layer.name()]++;
// Set the name in proto and layer // Set the name in proto and layer
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index);
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);)
duplicate_name_layer->set_name(new_name);
layer.set_name(new_name);)


// Insert the new operator name, the number of times of duplicate name is recorded as 1 // Insert the new operator name, the number of times of duplicate name is recorded as 1
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); layer_name_map.insert(std::make_pair(layer.name(), kNumOne));
@@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
layer_name_map[layer.name()]++; layer_name_map[layer.name()]++;
// Set the name in proto and layer // Set the name in proto and layer
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index);
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);)
duplicate_name_layer->set_name(new_name);
layer.set_name(new_name);)


// Insert the new operator name, the number of times of duplicate name is recorded as 1 // Insert the new operator name, the number of times of duplicate name is recorded as 1
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); layer_name_map.insert(std::make_pair(layer.name(), kNumOne));
@@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con
return FAILED; return FAILED;
} }


if (CheckLayersSize(message) != SUCCESS) {
if (CheckLayersSize(*message) != SUCCESS) {
delete message; delete message;
message = nullptr; message = nullptr;
return FAILED; return FAILED;
} }


if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) {
if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) {
delete message; delete message;
message = nullptr; message = nullptr;
REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path);
@@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message,
ge::ComputeGraphPtr &graph) { ge::ComputeGraphPtr &graph) {
auto field_name = layer_descriptor->FindFieldByName(kFieldName);
auto field_name = layer_descriptor.FindFieldByName(kFieldName);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor");
auto field_type = layer_descriptor->FindFieldByName(kFieldType);
auto field_type = layer_descriptor.FindFieldByName(kFieldType);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor");


const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);


NetParameter tmp_net; NetParameter tmp_net;
for (auto &field : field_desc) { for (auto &field : field_desc) {
@@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto
return FAILED; return FAILED;
} }


int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
GELOGI("Total Layer num of model file is %d", field_size); GELOGI("Total Layer num of model file is %d", field_size);
for (int i = 0; i < field_size; ++i) { for (int i = 0; i < field_size; ++i) {
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i);
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i);


LayerParameter *layer = tmp_net.add_layer(); LayerParameter *layer = tmp_net.add_layer();
if (ConvertLayerProto(&layer_message, layer) != SUCCESS) {
if (ConvertLayerProto(layer_message, layer) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed.");
return FAILED; return FAILED;
} }
@@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message,
google::protobuf::Message *layer) { google::protobuf::Message *layer) {
const google::protobuf::Reflection *layer_reflection = message->GetReflection();
const google::protobuf::Reflection *layer_reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
layer_reflection->ListFields(*message, &field_desc);
layer_reflection->ListFields(message, &field_desc);


for (auto &field : field_desc) { for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field); GE_CHECK_NOTNULL(field);
if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) {
if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str());
return FAILED; return FAILED;
} }
@@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection,
const google::protobuf::Message &message,
const google::protobuf::FieldDescriptor &field,
google::protobuf::Message *layer) const { google::protobuf::Message *layer) const {
GELOGD("Start to parse field: %s.", field->name().c_str());
GELOGD("Start to parse field: %s.", field.name().c_str());
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer);
string filed_name = field->name();
#define CASE_FIELD_NAME(kName, method) \
string filed_name = field.name();
#define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \
if (filed_name == kField##kName) { \ if (filed_name == kField##kName) { \
string value = reflection->GetString(*message, field); \
string value = reflection.GetString(inner_message, field_ptr); \
GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \
layer_proto->set_##method(value); \ layer_proto->set_##method(value); \
return SUCCESS; \ return SUCCESS; \
} }
CASE_FIELD_NAME(Name, name);
CASE_FIELD_NAME(Type, type);
CASE_FIELD_NAME(Name, name, message, &field);
CASE_FIELD_NAME(Type, type, message, &field);
#undef CASE_FIELD_NAME #undef CASE_FIELD_NAME
#define CASE_FIELD_NAME_REPEATED(kName, method) \
if (filed_name == kField##kName) { \
int field_size = reflection->FieldSize(*message, field); \
for (int i = 0; i < field_size; ++i) { \
auto value = reflection->GetRepeatedString(*message, field, i); \
layer_proto->add_##method(value); \
} \
return SUCCESS; \
}
CASE_FIELD_NAME_REPEATED(Bottom, bottom);
CASE_FIELD_NAME_REPEATED(Top, top);
#define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \
if (filed_name == kField##kName) { \
int field_size = reflection.FieldSize(inner_message, field_ptr); \
for (int i = 0; i < field_size; ++i) { \
auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \
layer_proto->add_##method(value); \
} \
return SUCCESS; \
}
CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field);
CASE_FIELD_NAME_REPEATED(Top, top, message, &field);
#undef CASE_FIELD_NAME_REPEATED #undef CASE_FIELD_NAME_REPEATED
if (filed_name == kFieldBlobs) { if (filed_name == kFieldBlobs) {
int field_size = reflection->FieldSize(*message, field);
int field_size = reflection.FieldSize(message, &field);
for (int i = 0; i < field_size; ++i) { for (int i = 0; i < field_size; ++i) {
domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); domi::caffe::BlobProto *item_message = layer_proto->add_blobs();
const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i);
if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str());
const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i);
if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str());
return FAILED; return FAILED;
} }
} }
return SUCCESS; return SUCCESS;
} }
if (filed_name == kFieldConvParam) { if (filed_name == kFieldConvParam) {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field);
ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param();
ConvertConvParamProto(&sub_message, conv_param);
ConvertConvParamProto(sub_message, conv_param);
} }
if (filed_name == kFieldInnerPro) { if (filed_name == kFieldInnerPro) {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field);
InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param();
ConvertInnerProdcutProto(&sub_message, inner_product);
ConvertInnerProdcutProto(sub_message, inner_product);
} }
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message,
google::protobuf::Message *blobs) const { google::protobuf::Message *blobs) const {
const google::protobuf::Reflection *blobs_reflection = message->GetReflection();
const google::protobuf::Reflection *blobs_reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
blobs_reflection->ListFields(*message, &field_desc);
blobs_reflection->ListFields(message, &field_desc);
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs);


for (auto &field : field_desc) { for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field); GE_CHECK_NOTNULL(field);
string feild_name = field->name(); string feild_name = field->name();
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \
if (feild_name == #kName) { \
int field_size = blobs_reflection->FieldSize(*message, field); \
for (int i = 0; i < field_size; ++i) { \
valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \
blobs_proto->add_##name(value); \
} \
continue; \
}
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data);
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff);
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data);
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff);
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data);
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data);
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \
if (feild_name == #kName) { \
int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \
for (int i = 0; i < field_size; ++i) { \
valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \
blobs_proto->add_##name(value); \
} \
continue; \
}
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field);
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field);
#undef CASE_BLOBS_FIELD_NAME_REPEATED #undef CASE_BLOBS_FIELD_NAME_REPEATED
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \
if (feild_name == #kName) { \
valuetype value = blobs_reflection->Get##method(*message, field); \
blobs_proto->set_##name(value); \
continue; \
}
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data);
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num);
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels);
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height);
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width);
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \
if (feild_name == #kName) { \
valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \
blobs_proto->set_##name(value); \
continue; \
}
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field);
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field);
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field);
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field);
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field);
#undef CASE_BLOBS_FIELD_NAME #undef CASE_BLOBS_FIELD_NAME
if (feild_name == kFieldShape) { if (feild_name == kFieldShape) {
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field);
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field);
domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape();
ConvertBlobShapeProto(&sub_message, blob_shape);
ConvertBlobShapeProto(sub_message, blob_shape);
} }
} }
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const { google::protobuf::Message *dest_message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);


domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message);


@@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message
if (field->name() != kFieldDim) { if (field->name() != kFieldDim) {
continue; continue;
} }
int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
for (int i = 0; i < field_size; ++i) { for (int i = 0; i < field_size; ++i) {
int64_t value = reflection->GetRepeatedInt64(*message, field, i);
int64_t value = reflection->GetRepeatedInt64(message, field, i);
shape_proto->add_dim(value); shape_proto->add_dim(value);
} }
} }
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const { google::protobuf::Message *dest_message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);


domi::caffe::ConvolutionParameter *conv_param_proto = domi::caffe::ConvolutionParameter *conv_param_proto =
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message);
@@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message
if (field->name() != kFieldBiasTerm) { if (field->name() != kFieldBiasTerm) {
continue; continue;
} }
bool value = reflection->GetBool(*message, field);
bool value = reflection->GetBool(message, field);
conv_param_proto->set_bias_term(value); conv_param_proto->set_bias_term(value);
} }
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message,
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const { google::protobuf::Message *dest_message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);


domi::caffe::InnerProductParameter *inner_product_proto = domi::caffe::InnerProductParameter *inner_product_proto =
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message);
@@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess
if (field->name() != kFieldBiasTerm) { if (field->name() != kFieldBiasTerm) {
continue; continue;
} }
bool value = reflection->GetBool(*message, field);
bool value = reflection->GetBool(message, field);
inner_product_proto->set_bias_term(value); inner_product_proto->set_bias_term(value);
} }
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const {
const google::protobuf::Reflection *reflection = message->GetReflection();
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const {
const google::protobuf::Reflection *reflection = message.GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);
reflection->ListFields(message, &field_desc);


int num_layer = 0; int num_layer = 0;
int num_layers = 0; int num_layers = 0;
@@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess
return FAILED; return FAILED;
} }


int field_size = reflection->FieldSize(*message, field);
int field_size = reflection->FieldSize(message, field);
if (field->name() == kLayerName) { if (field->name() == kLayerName) {
num_layer = field_size; num_layer = field_size;
} else { } else {


+ 13
- 13
parser/caffe/caffe_parser.h View File

@@ -212,8 +212,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
* @return SUCCESS parse layer successfully * @return SUCCESS parse layer successfully
* @return FAILED parse layer failed * @return FAILED parse layer failed
*/ */
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message, std::vector<ge::Operator> &operators) const;
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message, std::vector<ge::Operator> &operators) const;


/* /*
* @ingroup domi_omg * @ingroup domi_omg
@@ -386,33 +386,33 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {
Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path,
const string &fusion_proto_name, ge::ComputeGraphPtr &graph); const string &fusion_proto_name, ge::ComputeGraphPtr &graph);


Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor,
const google::protobuf::Message &message,
ge::ComputeGraphPtr &graph); ge::ComputeGraphPtr &graph);


Status ConvertLayerParameter(const google::protobuf::Message *layer_message, Status ConvertLayerParameter(const google::protobuf::Message *layer_message,
ge::ComputeGraphPtr &graph); ge::ComputeGraphPtr &graph);


Status CheckLayersSize(const google::protobuf::Message *message) const;
Status CheckLayersSize(const google::protobuf::Message &message) const;


Status ConvertLayerProto(const google::protobuf::Message *message,
Status ConvertLayerProto(const google::protobuf::Message &message,
google::protobuf::Message *layer); google::protobuf::Message *layer);


Status ParseLayerField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
Status ParseLayerField(const google::protobuf::Reflection &reflection,
const google::protobuf::Message &message,
const google::protobuf::FieldDescriptor &field,
google::protobuf::Message *layer) const; google::protobuf::Message *layer) const;


Status ConvertBlobsProto(const google::protobuf::Message *message,
Status ConvertBlobsProto(const google::protobuf::Message &message,
google::protobuf::Message *blobs) const; google::protobuf::Message *blobs) const;


Status ConvertBlobShapeProto(const google::protobuf::Message *message,
Status ConvertBlobShapeProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const; google::protobuf::Message *dest_message) const;


Status ConvertInnerProdcutProto(const google::protobuf::Message *message,
Status ConvertInnerProdcutProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const; google::protobuf::Message *dest_message) const;


Status ConvertConvParamProto(const google::protobuf::Message *message,
Status ConvertConvParamProto(const google::protobuf::Message &message,
google::protobuf::Message *dest_message) const; google::protobuf::Message *dest_message) const;
/** /**
* @ingroup domi_omg * @ingroup domi_omg


+ 45
- 0
parser/common/acl_graph_parser_util.cc View File

@@ -431,6 +431,41 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra
return SUCCESS; return SUCCESS;
} }


domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph,
const std::string &input_data_names) const {
std::vector<std::string> input_names = StringUtils::Split(input_data_names, ',');
std::unordered_map<std::string, size_t> name_to_index;
for (auto &input_name : input_names) {
if (!name_to_index.emplace(input_name, name_to_index.size()).second) {
GELOGE(PARAM_INVALID, "[Check][Param] Duplicate input name[%s].", input_name.c_str());
return FAILED;
}
}

for (const NodePtr &node : graph->GetDirectNode()) {
if (node->GetType() != ge::parser::DATA) {
continue;
}
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
auto iter = name_to_index.find(node->GetName());
if (iter== name_to_index.cend()) {
GELOGE(PARAM_INVALID, "[Check][Param] Input name[%s] is not in input_data_names",
node->GetName().c_str());
return FAILED;
}
GELOGI("[SetSpecifyIndexAttr] set node(%s) index attr, index is %ld",
op_desc->GetName().c_str(), iter->second);
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, iter->second)) {
REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s",
ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str());
GELOGE(FAILED, "set attr %s failed for node:%s", ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str());
return FAILED;
}
}
return SUCCESS;
}

void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) const { std::vector<std::string> &output_nodes_name) const {
output_nodes_name.clear(); output_nodes_name.clear();
@@ -670,6 +705,16 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
return PARAM_INVALID; return PARAM_INVALID;
} }


string input_data_names;
GetAclParams(parser_params, ge::ir_option::INPUT_DATA_NAMES, input_data_names);
if (!input_data_names.empty()) {
if (SetSpecifyIndexAttrByInputNames(compute_graph, input_data_names) != SUCCESS) {
GELOGE(FAILED, "[Invoke][SetIndexAttr] set index attr failed, graph:%s",
compute_graph->GetName().c_str());
return PARAM_INVALID;
}
}

return SUCCESS; return SUCCESS;
} }




+ 1
- 0
parser/common/acl_graph_parser_util.h View File

@@ -61,6 +61,7 @@ class AclGrphParseUtil {
size_t index, OpDescPtr &op_desc); size_t index, OpDescPtr &op_desc);
domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
const string &is_input_adjust_hw_layout) const; const string &is_input_adjust_hw_layout) const;
domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const;
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const; std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const;
}; };


+ 1
- 1
parser/common/convert/message2operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 1
- 1
parser/common/convert/message2operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 4
- 6
parser/common/convert/pb2json.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -82,7 +82,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr
switch (field->type()) { switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: { case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); const ProtobufMsg &tmp_message = reflection->GetMessage(message, field);
if (0UL != tmp_message.ByteSizeLong()) {
if (tmp_message.ByteSizeLong() != 0UL) {
Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1); Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1);
} }
break; break;
@@ -122,7 +122,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr


case ProtobufFieldDescriptor::TYPE_FLOAT: case ProtobufFieldDescriptor::TYPE_FLOAT:
char str[kSignificantDigits]; char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) {
json[field->name()] = str; json[field->name()] = str;
} else { } else {
json[field->name()] = reflection->GetFloat(message, field); json[field->name()] = reflection->GetFloat(message, field);
@@ -155,10 +155,8 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) {
} }
string result = ""; string result = "";
for (char temp_value : type_bytes) { for (char temp_value : type_bytes) {
uint8_t *value = 0;
value = reinterpret_cast<uint8_t *>(&temp_value);
char str[kSignificantDigits]; char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){
if (sprintf_s(str, kSignificantDigits, "%c", temp_value) == -1) {
GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str());
continue; continue;
} }


+ 1
- 1
parser/common/convert/pb2json.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


parser/common/op_def/arg_op.cc → parser/common/op_def/arg_op_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "parser/common/op_def/arg_op.h"
#include "parser/common/op_def/arg_op_operator.h"
#include <string> #include <string>
#include "framework/common/fmk_types.h" #include "framework/common/fmk_types.h"



parser/common/op_def/arg_op.h → parser/common/op_def/arg_op_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

parser/common/op_def/constant_op.cc → parser/common/op_def/constant_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "common/op_def/constant_op.h"
#include "common/op_def/constant_operator.h"
#include <string> #include <string>
#include <vector> #include <vector>



parser/common/op_def/constant_op.h → parser/common/op_def/constant_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

parser/common/op_def/fill_op.cc → parser/common/op_def/fill_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "common/op_def/fill_op.h"
#include "common/op_def/fill_operator.h"
#include "framework/common/fmk_types.h" #include "framework/common/fmk_types.h"


namespace ge { namespace ge {

parser/common/op_def/fill_op.h → parser/common/op_def/fill_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

parser/common/op_def/frameworkop_op.cc → parser/common/op_def/framework_op_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "common/op_def/frameworkop_op.h"
#include "common/op_def/framework_op_operator.h"
#include <string> #include <string>
#include "framework/common/fmk_types.h" #include "framework/common/fmk_types.h"



parser/common/op_def/frameworkop_op.h → parser/common/op_def/framework_op_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

+ 1
- 1
parser/common/op_def/ir_pb_converter.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 1
- 1
parser/common/op_def/ir_pb_converter.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


parser/common/op_def/no_op_op.cc → parser/common/op_def/no_op_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
*/ */


// AUTO GEN PLEASE DO NOT MODIFY IT // AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/no_op_op.h"
#include "common/op_def/no_op_operator.h"
#include <string> #include <string>


namespace ge { namespace ge {

parser/common/op_def/no_op_op.h → parser/common/op_def/no_op_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
#ifndef DOMI_OP_NO_OP_OP_H_ #ifndef DOMI_OP_NO_OP_OP_H_
#define DOMI_OP_NO_OP_OP_H_ #define DOMI_OP_NO_OP_OP_H_
#include "parser/common/op_def/operator.h" #include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"


namespace ge { namespace ge {
class NoOpOperator : public ParserOperator { class NoOpOperator : public ParserOperator {

+ 1
- 1
parser/common/op_def/operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 1
- 1
parser/common/op_def/operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


parser/common/op_def/ref_switch_op.cc → parser/common/op_def/ref_switch_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
*/ */


// AUTO GEN PLEASE DO NOT MODIFY IT // AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/ref_switch_op.h"
#include "common/op_def/ref_switch_operator.h"


namespace ge { namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOperator() : ParserOperator("RefSwitch") {} FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOperator() : ParserOperator("RefSwitch") {}

parser/common/op_def/ref_switch_op.h → parser/common/op_def/ref_switch_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

parser/common/op_def/shape_n_op.cc → parser/common/op_def/shape_n_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
*/ */


// AUTO GEN PLEASE DO NOT MODIFY IT // AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/shape_n_op.h"
#include "common/op_def/shape_n_operator.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"



parser/common/op_def/shape_n_op.h → parser/common/op_def/shape_n_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

parser/common/op_def/var_is_initialized_op_op.cc → parser/common/op_def/var_is_initialized_op_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
*/ */


// AUTO GEN PLEASE DO NOT MODIFY IT // AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/var_is_initialized_op_op.h"
#include "common/op_def/var_is_initialized_op_operator.h"
#include <string> #include <string>
#include <vector> #include <vector>



parser/common/op_def/var_is_initialized_op_op.h → parser/common/op_def/var_is_initialized_op_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

parser/common/op_def/variable_op.cc → parser/common/op_def/variable_operator.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "parser/common/op_def/variable_op.h"
#include "parser/common/op_def/variable_operator.h"


#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"



parser/common/op_def/variable_op.h → parser/common/op_def/variable_operator.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

+ 11
- 11
parser/module.mk View File

@@ -92,18 +92,18 @@ PARSER_SCOPE_SRC_FILES := \
tensorflow/scope/scope_pass_manager.cc \ tensorflow/scope/scope_pass_manager.cc \


FMK_COMMON_SRC_FILES := \ FMK_COMMON_SRC_FILES := \
tensorflow/graph_functiondef.cc \
tensorflow/graph_optimizer.cc \
tensorflow/graph_to_function_def.cc \
tensorflow/parser_graph_optimizer.cc \
tensorflow/iterator_fusion_pass.cc \ tensorflow/iterator_fusion_pass.cc \
common/op_def/arg_op.cc \
common/op_def/constant_op.cc \
common/op_def/fill_op.cc \
common/op_def/frameworkop_op.cc \
common/op_def/no_op_op.cc \
common/op_def/ref_switch_op.cc \
common/op_def/shape_n_op.cc \
common/op_def/var_is_initialized_op_op.cc \
common/op_def/variable_op.cc \
common/op_def/arg_op_operator.cc \
common/op_def/constant_operator.cc \
common/op_def/fill_operator.cc \
common/op_def/framework_op_operator.cc \
common/op_def/no_op_operator.cc \
common/op_def/ref_switch_operator.cc \
common/op_def/shape_n_operator.cc \
common/op_def/var_is_initialized_op_operator.cc \
common/op_def/variable_operator.cc \


LOCAL_SRC_FILES := $(PARSER_TENSORFLOW_SRC_FILES) LOCAL_SRC_FILES := $(PARSER_TENSORFLOW_SRC_FILES)
LOCAL_SRC_FILES += $(PARSER_SCOPE_SRC_FILES) LOCAL_SRC_FILES += $(PARSER_SCOPE_SRC_FILES)


+ 3
- 5
parser/onnx/onnx_constant_parser.h View File

@@ -69,16 +69,14 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
break; break;
} }
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ #define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \
case dt_type: \
{ \
case dt_type: { \
unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \ unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \
GE_CHECK_NOTNULL(addr_trans); \ GE_CHECK_NOTNULL(addr_trans); \
for (int32_t i = 0; i < (count); i++) { \ for (int32_t i = 0; i < (count); i++) { \
*(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \ *(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \
} \ } \
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \ (tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \
break; \
} \
break; } \


CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor)
CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor)
@@ -89,7 +87,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
#undef CASE_SET_DATA #undef CASE_SET_DATA
default: default:
{ {
tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(T));
tensor.SetData(PtrToPtr<T, uint8_t>(addr.get()), count * sizeof(T));
break; break;
} }
} }


+ 19
- 19
parser/onnx/onnx_file_constant_parser.cc View File

@@ -31,11 +31,11 @@ using GeTensorDesc = ge::GeTensorDesc;
using namespace ge::parser; using namespace ge::parser;


namespace { namespace {
const std::string kAttrShape = "shape";
const std::string kAttrDataType = "dtype";
const std::string kFileConstantPath = "file_constant_path";
const std::string kLocation = "location";
const std::string kOffset = "offset";
const char *const kAttrShape = "shape";
const char *const kAttrDataType = "dtype";
const char *const kFileConstantPath = "file_constant_path";
const char *const kLocation = "location";
const char *const kOffset = "offset";
const int64_t kOffsetCoefficient = 4096; const int64_t kOffsetCoefficient = 4096;
const char *const kFileConstant = "FileConstant"; const char *const kFileConstant = "FileConstant";
} }
@@ -46,7 +46,7 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str());


ge::onnx::TensorProto tensor_proto; ge::onnx::TensorProto tensor_proto;
if (GetTensorProto(node, tensor_proto) != SUCCESS) {
if (GetTensorProto(*node, tensor_proto) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str());
return FAILED; return FAILED;
@@ -65,29 +65,29 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &
return SUCCESS; return SUCCESS;
} }


Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto,
ge::onnx::TensorProto &tensor_proto) {
for (const auto &it : node_proto->attribute()) {
Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto,
ge::onnx::TensorProto &tensor_proto) const {
for (const auto &it : node_proto.attribute()) {
if (it.name() != ge::kAttrNameValue) { if (it.name() != ge::kAttrNameValue) {
continue; continue;
} }
tensor_proto = it.t(); tensor_proto = it.t();
return SUCCESS; return SUCCESS;
} }
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str());
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str());
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str());
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str());
return FAILED; return FAILED;
} }


void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const {
std::vector<int64_t> tmp_shape; std::vector<int64_t> tmp_shape;
for (int i = 0; i < tensor_proto.dims_size(); i++) { for (int i = 0; i < tensor_proto.dims_size(); i++) {
tmp_shape.push_back(tensor_proto.dims(i)); tmp_shape.push_back(tensor_proto.dims(i));
} }
op_def.SetAttr(kAttrShape.c_str(), tmp_shape);
op_def.SetAttr(kAttrShape, tmp_shape);
} }


Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const {
int64_t data_type = tensor_proto.data_type(); int64_t data_type = tensor_proto.data_type();
ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type);
if (type >= ge::DataType::DT_UNDEFINED) { if (type >= ge::DataType::DT_UNDEFINED) {
@@ -96,11 +96,11 @@ Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor
return FAILED; return FAILED;
} }


op_def.SetAttr(kAttrDataType.c_str(), type);
op_def.SetAttr(kAttrDataType, type);
return SUCCESS; return SUCCESS;
} }


Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const {
ge::NamedAttrs attrs; ge::NamedAttrs attrs;
for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) {
const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i);
@@ -116,12 +116,12 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str());
return FAILED; return FAILED;
} }
op_def.SetAttr(kFileConstantPath.c_str(), attrs);
op_def.SetAttr(kFileConstantPath, attrs);
return SUCCESS; return SUCCESS;
} }


Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto,
ge::NamedAttrs &attrs) {
ge::NamedAttrs &attrs) const {
if (string_proto.key() == kLocation) { if (string_proto.key() == kLocation) {
AttrUtils::SetStr(attrs, kLocation, string_proto.value()); AttrUtils::SetStr(attrs, kLocation, string_proto.value());
} else { } else {
@@ -134,7 +134,7 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt
return FAILED; return FAILED;
} }
if (string_proto.key() == kOffset) { if (string_proto.key() == kOffset) {
if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) {
if (value > (std::numeric_limits<int64_t>::max() / kOffsetCoefficient)) {
REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
return FAILED; return FAILED;


+ 5
- 5
parser/onnx/onnx_file_constant_parser.h View File

@@ -26,11 +26,11 @@ class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser {
Status ParseParams(const Message *op_src, ge::Operator &op_def) override; Status ParseParams(const Message *op_src, ge::Operator &op_def) override;


private: private:
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto);
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs);
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const;
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const;
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const;
Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const;
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const;
}; };
} // namespace ge } // namespace ge




+ 10
- 10
parser/onnx/onnx_parser.cc View File

@@ -232,7 +232,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type);
if (post_func == nullptr) { if (post_func == nullptr) {
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str());
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) {
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS ||
parse_func_v2 == nullptr) {
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str());
return SUCCESS; return SUCCESS;
} }
@@ -522,9 +523,9 @@ Status OnnxModelParser::SetOperatorInputs() {
auto src_op = output_op_iter->second; auto src_op = output_op_iter->second;
int dst_index = input_node_index.second; int dst_index = input_node_index.second;
int src_index = out_node_index.second; int src_index = out_node_index.second;
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index,
ParserUtils::GetOperatorName(src_op).c_str(), dst_index,
ParserUtils::GetOperatorName(dst_op).c_str());
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.",
src_index, ParserUtils::GetOperatorName(src_op).c_str(),
dst_index, ParserUtils::GetOperatorName(dst_op).c_str());
auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op);
GE_CHECK_NOTNULL(dst_op_desc); GE_CHECK_NOTNULL(dst_op_desc);
auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op);
@@ -689,7 +690,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve
return PARAM_INVALID; return PARAM_INVALID;
} }
input_ops.emplace_back(in_op->second); input_ops.emplace_back(in_op->second);
GELOGI("Model assigned input node name: %s", ParserUtils::GetOperatorName(in_op->second).c_str());
GELOGI("Model assigned input node name: %s",
ParserUtils::GetOperatorName(in_op->second).c_str());
} }
return SUCCESS; return SUCCESS;
} }
@@ -717,7 +719,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec
int index = node_name_index.second; int index = node_name_index.second;
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)});
out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); out_tensor_to_nodes[output_name] = std::make_pair(node_name, index);
GELOGI("out node index %d, node:%s", index, node_name.c_str());
GELOGI("Out node index %d, node:%s", index, node_name.c_str());
} }
} }
return SUCCESS; return SUCCESS;
@@ -934,16 +936,13 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
cur_compute_graph->GetName().c_str()); cur_compute_graph->GetName().c_str());
return ret; return ret;
} }

} }
UpdateDataFormat(root_graph); UpdateDataFormat(root_graph);
return SUCCESS; return SUCCESS;
} }


Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {

ClearMembers(); ClearMembers();

GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX),
"Run ProtoType Pass Failed"); "Run ProtoType Pass Failed");
// 1. Get all inializer. // 1. Get all inializer.
@@ -1174,7 +1173,8 @@ Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_o
default_out_nodes.emplace_back(output_node_info); default_out_nodes.emplace_back(output_node_info);
output_tensor_names.emplace_back(tensor_name); output_tensor_names.emplace_back(tensor_name);
GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].",
output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str());
output_node_info.first.c_str(),
output_node_info.second, tensor_name.c_str());
} }
return SUCCESS; return SUCCESS;
} }


+ 3
- 1
parser/onnx/onnx_util.h View File

@@ -17,6 +17,8 @@
#ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ #ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_
#define PARSER_ONNX_ONNX_UTIL_PARSER_H_ #define PARSER_ONNX_ONNX_UTIL_PARSER_H_


#include <string>
#include <cstdint>
#include "external/graph/types.h" #include "external/graph/types.h"


namespace OnnxDataType { namespace OnnxDataType {
@@ -59,4 +61,4 @@ class OnnxUtil {
}; };
} // namespace ge } // namespace ge


#endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_
#endif // PARSER_ONNX_ONNX_UTIL_PARSER_H_

+ 13
- 12
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ using parser::IF;
namespace { namespace {
const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}};
const int kIfNodeAttrSize = 2; const int kIfNodeAttrSize = 2;
const char *kIf = "If";
} // namespace } // namespace
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
@@ -33,7 +34,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(),
parent_node->op_type().c_str()); parent_node->op_type().c_str());


auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name);
auto ret = ParseIfNodeSubgraphs(*parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Node] Parse if node failed."); GELOGE(ret, "[Parse][Node] Parse if node failed.");
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str());
@@ -44,19 +45,19 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
} }


domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const { std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const {
if (parent_node->attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
if (parent_node.attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size());
return FAILED; return FAILED;
} }


GELOGD("node attribute size:%d.", parent_node->attribute_size());
GELOGD("node attribute size:%d.", parent_node.attribute_size());
std::set<std::string> all_inputs; std::set<std::string> all_inputs;
// for onnx graph, the first attribute may be else branch and the second attribute may be then branch // for onnx graph, the first attribute may be else branch and the second attribute may be then branch
for (int i = 0; i < parent_node->attribute_size(); i++) {
ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i);
for (int i = 0; i < parent_node.attribute_size(); i++) {
ge::onnx::AttributeProto *attribute = parent_node.mutable_attribute(i);
GE_CHECK_NOTNULL(attribute); GE_CHECK_NOTNULL(attribute);
std::string attr_name = attribute->name(); std::string attr_name = attribute->name();
auto itr = kAttrNameToIndex.find(attr_name); auto itr = kAttrNameToIndex.find(attr_name);
@@ -68,7 +69,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
return FAILED; return FAILED;
} }
std::string unique_subgraph_name; std::string unique_subgraph_name;
std::string node_name = parent_node->name();
std::string node_name = parent_node.name();
if (!parent_graph_name.empty()) { if (!parent_graph_name.empty()) {
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name);
} }
@@ -90,7 +91,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
AddInputNodeForGraph(all_inputs, *onnx_graph); AddInputNodeForGraph(all_inputs, *onnx_graph);
} }


AddInputForParentNode(all_inputs, *parent_node);
AddInputForParentNode(all_inputs, parent_node);
return SUCCESS; return SUCCESS;
} }


@@ -135,5 +136,5 @@ void IfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_i
parent_node.add_input(input_name); parent_node.add_input(input_name);
} }
} }
REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter);
REGISTER_SUBGRAPH_ADAPTER_CREATOR(kIf, IfSubgraphAdapter);
} // namespace ge } // namespace ge

+ 3
- 2
parser/onnx/subgraph_adapter/if_subgraph_adapter.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "subgraph_adapter.h" #include "subgraph_adapter.h"
#include "parser/onnx/onnx_util.h"


namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
@@ -30,7 +31,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
const std::string &parent_graph_name = "") override; const std::string &parent_graph_name = "") override;


private: private:
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name) const; const std::string &parent_graph_name) const;
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const;


+ 1
- 3
parser/onnx/subgraph_adapter/subgraph_adapter.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -36,8 +36,6 @@
#include "proto/onnx/ge_onnx.pb.h" #include "proto/onnx/ge_onnx.pb.h"
#include "external/register/register_error_codes.h" #include "external/register/register_error_codes.h"
#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"
#include "parser/onnx/onnx_util.h"

namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY SubgraphAdapter { class PARSER_FUNC_VISIBILITY SubgraphAdapter {
public: public:


+ 1
- 1
parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 2
- 3
parser/onnx/subgraph_adapter/subgraph_adapter_factory.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -62,7 +62,6 @@ protected:
* @brief SubgraphAdapter creation function * @brief SubgraphAdapter creation function
* @return Created SubgraphAdapter * @return Created SubgraphAdapter
*/ */
// typedef shared_ptr<SubgraphAdapter> (*CREATOR_FUN)(void);
using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>; using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>;


/** /**
@@ -105,7 +104,7 @@ public:
* @param [in] op_type Op type * @param [in] op_type Op type
* @param [in] clazz SubgraphAdapter implementation class * @param [in] clazz SubgraphAdapter implementation class
*/ */
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \
std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \ std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \
std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \ std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \
if (ptr == nullptr) { \ if (ptr == nullptr) { \


+ 33
- 32
parser/stub/gen_stubapi.py View File

@@ -167,6 +167,33 @@ class H2CC(object):
del self.stack_template del self.stack_template
del self.func_list_exist del self.func_list_exist


@staticmethod
def implement_function(func):
function_def = ''
function_def += '{\n'

all_items = func.split()
start = 0
return_type = all_items[start]
if return_type == "const":
start += 1
return_type = all_items[start]
if return_type.startswith(('std::map', 'std::set', 'std::vector')):
return_type = "std::map"
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
return_type = "Ptr"
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
return_type += "&"
if RETURN_STATEMENTS.__contains__(return_type):
function_def += RETURN_STATEMENTS[return_type]
else:
logging.warning("Unhandled return type[%s]", return_type)

function_def += '\n'
function_def += '}\n'
function_def += '\n'
return function_def

def just_skip(self): def just_skip(self):
# skip blank line or comment # skip blank line or comment
if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
@@ -263,6 +290,7 @@ class H2CC(object):
logging.info('Added %s functions', len(self.func_list_exist)) logging.info('Added %s functions', len(self.func_list_exist))
logging.info('Successfully converted,please see ' + self.output_file) logging.info('Successfully converted,please see ' + self.output_file)



def handle_func1(self, line): def handle_func1(self, line):
""" """
:param line: :param line:
@@ -461,12 +489,6 @@ class H2CC(object):
logging.info("func_name[%s]", func_name) logging.info("func_name[%s]", func_name)
return line, func_name return line, func_name


def write_func_content(self, content, func_name, need_generate):
if not (func_name in self.func_list_exist) and need_generate:
self.output_fd.write(content)
self.func_list_exist.append(func_name)
logging.info('add func:[%s]', func_name)

def gen_comment(self, start_i): def gen_comment(self, start_i):
comment_line = '' comment_line = ''
# Function comments are on top of function declarations, copy them over # Function comments are on top of function declarations, copy them over
@@ -488,32 +510,11 @@ class H2CC(object):
break break
return comment_line return comment_line


@staticmethod
def implement_function(func):
function_def = ''
function_def += '{\n'

all_items = func.split()
start = 0
return_type = all_items[start]
if return_type == "const":
start += 1
return_type = all_items[start]
if return_type.startswith(('std::map', 'std::set', 'std::vector')):
return_type = "std::map"
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
return_type = "Ptr"
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
return_type += "&"
if RETURN_STATEMENTS.__contains__(return_type):
function_def += RETURN_STATEMENTS[return_type]
else:
logging.warning("Unhandled return type[%s]", return_type)

function_def += '\n'
function_def += '}\n'
function_def += '\n'
return function_def
def write_func_content(self, content, func_name, need_generate):
if not (func_name in self.func_list_exist) and need_generate:
self.output_fd.write(content)
self.func_list_exist.append(func_name)
logging.info('add func:[%s]', func_name)




def collect_header_files(path): def collect_header_files(path):


parser/tensorflow/graph_functiondef.cc → parser/tensorflow/graph_to_function_def.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "graph_functiondef.h"
#include "graph_to_function_def.h"
#include <iostream> #include <iostream>
#include "common/fmk_error_codes.h" #include "common/fmk_error_codes.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"

parser/tensorflow/graph_functiondef.h → parser/tensorflow/graph_to_function_def.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

+ 2
- 2
parser/tensorflow/iterator_fusion_pass.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@


#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"
#include "common/util.h" #include "common/util.h"
#include "graph_optimizer.h"
#include "parser_graph_optimizer.h"
#include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_inner_error_codes.h"


namespace ge { namespace ge {


+ 1
- 1
parser/tensorflow/iterator_fusion_pass.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


parser/tensorflow/graph_optimizer.cc → parser/tensorflow/parser_graph_optimizer.cc View File

@@ -14,14 +14,14 @@
* limitations under the License. * limitations under the License.
*/ */


#include "graph_optimizer.h"
#include "parser_graph_optimizer.h"
#include "graph/op_types.h" #include "graph/op_types.h"
#include "common/types_map.h" #include "common/types_map.h"
#include "common/util.h" #include "common/util.h"
#include "framework/omg/parser/parser_inner_ctx.h" #include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "graph_functiondef.h"
#include "graph_to_function_def.h"
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#include "register/op_registry.h" #include "register/op_registry.h"


@@ -188,7 +188,10 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library)


Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) {
ComputeGraphPtr sub_graph = nullptr; ComputeGraphPtr sub_graph = nullptr;
GE_MAKE_SHARED(sub_graph = std::make_shared<ComputeGraph>("subGraph"), sub_graph = nullptr; return PARAM_INVALID);
GE_MAKE_SHARED(
sub_graph = std::make_shared<ComputeGraph>("subGraph"),
sub_graph = nullptr;
return PARAM_INVALID);


unordered_map<string, NodePtr> node_map; unordered_map<string, NodePtr> node_map;
vector<InDataAnchorPtr> input_anchors; vector<InDataAnchorPtr> input_anchors;

parser/tensorflow/graph_optimizer.h → parser/tensorflow/parser_graph_optimizer.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

+ 1
- 1
parser/tensorflow/scope/scope_pass_manager.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 1
- 1
parser/tensorflow/scope/scope_pass_manager.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 2
- 2
parser/tensorflow/tensorflow_arg_parser.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "parser/common/op_def/arg_op.h"
#include "parser/common/op_def/arg_op_operator.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_inner_ctx.h" #include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"


+ 1
- 1
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 1
- 1
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 1
- 1
parser/tensorflow/tensorflow_constant_parser.cc View File

@@ -19,7 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#include "parser/common/op_def/constant_op.h"
#include "parser/common/op_def/constant_operator.h"
#include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/op_def/ir_pb_converter.h"
#include "parser/common/util.h" #include "parser/common/util.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"


+ 1
- 1
parser/tensorflow/tensorflow_constant_parser.h View File

@@ -17,7 +17,7 @@
#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ #ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_
#define GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ #define GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_


#include "common/op_def/constant_op.h"
#include "common/op_def/constant_operator.h"
#include "parser/common/data_op_parser.h" #include "parser/common/data_op_parser.h"
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"




+ 1
- 1
parser/tensorflow/tensorflow_fill_parser.cc View File

@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "parser/common/op_def/fill_op.h"
#include "parser/common/op_def/fill_operator.h"
#include "parser/tensorflow/tensorflow_parser_register.h" #include "parser/tensorflow/tensorflow_parser_register.h"
#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"




+ 1
- 1
parser/tensorflow/tensorflow_frameworkop_parser.cc View File

@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "parser/common/op_def/frameworkop_op.h"
#include "parser/common/op_def/framework_op_operator.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"


+ 1
- 1
parser/tensorflow/tensorflow_no_op_parser.cc View File

@@ -18,7 +18,7 @@
#include "common/util.h" #include "common/util.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/op_def/ir_pb_converter.h"
#include "parser/common/op_def/no_op_op.h"
#include "parser/common/op_def/no_op_operator.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"


using domi::TENSORFLOW; using domi::TENSORFLOW;


+ 1
- 1
parser/tensorflow/tensorflow_ref_switch_parser.cc View File

@@ -17,7 +17,7 @@
#include "parser/tensorflow/tensorflow_ref_switch_parser.h" #include "parser/tensorflow/tensorflow_ref_switch_parser.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/op_def/ir_pb_converter.h"
#include "parser/common/op_def/ref_switch_op.h"
#include "parser/common/op_def/ref_switch_operator.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/common/util.h" #include "parser/common/util.h"




+ 1
- 1
parser/tensorflow/tensorflow_ref_switch_parser.h View File

@@ -17,7 +17,7 @@
#ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_ #ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_
#define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_ #define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_


#include "common/op_def/ref_switch_op.h"
#include "common/op_def/ref_switch_operator.h"
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"


namespace ge { namespace ge {


+ 1
- 1
parser/tensorflow/tensorflow_shape_n_parser.cc View File

@@ -18,7 +18,7 @@
#include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/op_def/ir_pb_converter.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/common/op_def/shape_n_op.h"
#include "parser/common/op_def/shape_n_operator.h"
#include "parser/common/util.h" #include "parser/common/util.h"


using domi::TENSORFLOW; using domi::TENSORFLOW;


+ 1
- 1
parser/tensorflow/tensorflow_shape_n_parser.h View File

@@ -17,7 +17,7 @@
#ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ #ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_
#define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ #define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_


#include "common/op_def/shape_n_op.h"
#include "common/op_def/shape_n_operator.h"
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"


namespace ge { namespace ge {


+ 1
- 1
parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc View File

@@ -15,7 +15,7 @@
*/ */


#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "parser/common/op_def/var_is_initialized_op_op.h"
#include "parser/common/op_def/var_is_initialized_op_operator.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"
#include "parser/tensorflow/tensorflow_parser_register.h" #include "parser/tensorflow/tensorflow_parser_register.h"


+ 1
- 1
parser/tensorflow/tensorflow_variable_v2_parser.cc View File

@@ -21,7 +21,7 @@
#include "graph/op_desc.h" #include "graph/op_desc.h"
#include "graph/utils/attr_utils.h" #include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "parser/common/op_def/variable_op.h"
#include "parser/common/op_def/variable_operator.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"
#include "parser/tensorflow/tensorflow_parser_register.h" #include "parser/tensorflow/tensorflow_parser_register.h"


+ 11
- 11
tests/st/CMakeLists.txt View File

@@ -249,17 +249,17 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/convert/message2operator.cc" "${PARSER_DIR}/parser/common/convert/message2operator.cc"
"${PARSER_DIR}/parser/common/data_op_parser.cc" "${PARSER_DIR}/parser/common/data_op_parser.cc"
"${PARSER_DIR}/parser/common/model_saver.cc" "${PARSER_DIR}/parser/common/model_saver.cc"
"${PARSER_DIR}/parser/common/op_def/arg_op.cc"
"${PARSER_DIR}/parser/common/op_def/constant_op.cc"
"${PARSER_DIR}/parser/common/op_def/fill_op.cc"
"${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc"
"${PARSER_DIR}/parser/common/op_def/arg_op_operator.cc"
"${PARSER_DIR}/parser/common/op_def/constant_operator.cc"
"${PARSER_DIR}/parser/common/op_def/fill_operator.cc"
"${PARSER_DIR}/parser/common/op_def/framework_op_operator.cc"
"${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc"
"${PARSER_DIR}/parser/common/op_def/no_op_op.cc"
"${PARSER_DIR}/parser/common/op_def/no_op_operator.cc"
"${PARSER_DIR}/parser/common/op_def/operator.cc" "${PARSER_DIR}/parser/common/op_def/operator.cc"
"${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc"
"${PARSER_DIR}/parser/common/op_def/shape_n_op.cc"
"${PARSER_DIR}/parser/common/op_def/variable_op.cc"
"${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc"
"${PARSER_DIR}/parser/common/op_def/ref_switch_operator.cc"
"${PARSER_DIR}/parser/common/op_def/shape_n_operator.cc"
"${PARSER_DIR}/parser/common/op_def/variable_operator.cc"
"${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_operator.cc"
"${PARSER_DIR}/parser/common/op_map.cc" "${PARSER_DIR}/parser/common/op_map.cc"
"${PARSER_DIR}/parser/common/op_parser_factory.cc" "${PARSER_DIR}/parser/common/op_parser_factory.cc"
"${PARSER_DIR}/parser/common/parser_api.cc" "${PARSER_DIR}/parser/common/parser_api.cc"
@@ -284,8 +284,8 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/onnx/onnx_util.cc" "${PARSER_DIR}/parser/onnx/onnx_util.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc"
"${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc"
"${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc"
"${PARSER_DIR}/parser/tensorflow/graph_to_function_def.cc"
"${PARSER_DIR}/parser/tensorflow/parser_graph_optimizer.cc"
"${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc"
"${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc"


+ 48
- 2
tests/st/testcase/test_caffe_parser.cc View File

@@ -765,7 +765,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CheckLayersSize_test)
layer->set_name("Abs"); layer->set_name("Abs");
layer->set_type("AbsVal"); layer->set_type("AbsVal");


Status ret = weightParser.CheckLayersSize(layer);
Status ret = weightParser.CheckLayersSize(*layer);
EXPECT_EQ(ret, FAILED); EXPECT_EQ(ret, FAILED);
} }


@@ -809,8 +809,54 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test)
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
const google::protobuf::Message *message = proto->New(); const google::protobuf::Message *message = proto->New();


Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph);
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph);
delete message; delete message;
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }

TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_DimSize_0)
{
CaffeModelParser modelParser;
domi::caffe::NetParameter net;
net.add_input("111");
net.add_input_shape();
bool input_data_flag = true;
Status ret = modelParser.ParseInput(net, input_data_flag);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_Err1)
{
CaffeModelParser modelParser;
domi::caffe::NetParameter net;
net.add_input("111");
net.add_input("222");
net.add_input_shape();
bool input_data_flag = true;
Status ret = modelParser.ParseInput(net, input_data_flag);
EXPECT_EQ(ret, FAILED);
}

TEST_F(STestCaffeParser, CaffeModelParser_ParserLayerParameter_Succ)
{
CaffeModelParser modelParser;
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/";
const char *model_path = model_file.c_str();

std::string custom_proto = model_file;
std::string caffe_proto = model_file;
std::vector<ge::Operator> operators;
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Data", "Input");
ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);
operators.emplace_back(op_src);

model_file = case_dir + "/origin_models/caffe_add.pbtxt";
custom_proto = case_dir + "/../../../metadef/proto/caffe/caffe.proto";
model_path = model_file.c_str();
std::string caffe_proto_path = case_dir + "/../../../metadef/proto/caffe/caffe.proto";
auto ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto_path, operators);
EXPECT_EQ(ret, SUCCESS);
}
} // namespace ge } // namespace ge

+ 28
- 10
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -34,17 +34,17 @@
#include "external/parser/tensorflow_parser.h" #include "external/parser/tensorflow_parser.h"
#include "parser/tensorflow/tensorflow_constant_parser.h" #include "parser/tensorflow/tensorflow_constant_parser.h"
#include "common/types.h" #include "common/types.h"
#include "parser/common/op_def/variable_op.h"
#include "parser/common/op_def/variable_operator.h"
#include "parser/tensorflow/tensorflow_ref_switch_parser.h" #include "parser/tensorflow/tensorflow_ref_switch_parser.h"
#include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h" #include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h"
#include "parser/common/op_def/arg_op.h"
#include "parser/common/op_def/arg_op_operator.h"
#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h"
#include "parser/tensorflow/tensorflow_reshape_parser.h" #include "parser/tensorflow/tensorflow_reshape_parser.h"
#include "parser/tensorflow/tensorflow_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_custom_parser_adapter.h"
#include "parser/tensorflow/tensorflow_squeeze_parser.h" #include "parser/tensorflow/tensorflow_squeeze_parser.h"
#include "parser/tensorflow/graph_functiondef.h"
#include "parser/tensorflow/graph_optimizer.h"
#include "parser/tensorflow/graph_to_function_def.h"
#include "parser/tensorflow/parser_graph_optimizer.h"
#include "cce/dnn_base_def.hpp" #include "cce/dnn_base_def.hpp"
#include "parser/tensorflow/scope/scope_pass_manager.h" #include "parser/tensorflow/scope/scope_pass_manager.h"
#include "parser/tensorflow/tensorflow_util.h" #include "parser/tensorflow/tensorflow_util.h"
@@ -52,10 +52,10 @@
#include "parser/tensorflow/tensorflow_enter_parser.h" #include "parser/tensorflow/tensorflow_enter_parser.h"
#include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/op_def/ir_pb_converter.h"
#include "parser/common/tuple.h" #include "parser/common/tuple.h"
#include "common/op_def/frameworkop_op.h"
#include "common/op_def/shape_n_op.h"
#include "common/op_def/var_is_initialized_op_op.h"
#include "common/op_def/fill_op.h"
#include "common/op_def/framework_op_operator.h"
#include "common/op_def/shape_n_operator.h"
#include "common/op_def/var_is_initialized_op_operator.h"
#include "common/op_def/fill_operator.h"
#include "common/convert/pb2json.h" #include "common/convert/pb2json.h"
#include "common/convert/message2operator.h" #include "common/convert/message2operator.h"
#include "parser/common/proto_file_parser.h" #include "parser/common/proto_file_parser.h"
@@ -70,7 +70,7 @@
#include "parser/common/prototype_pass_manager.h" #include "parser/common/prototype_pass_manager.h"
#include "parser/common/register_tbe.h" #include "parser/common/register_tbe.h"
#include "parser/common/pass_manager.h" #include "parser/common/pass_manager.h"
#include "parser/tensorflow/graph_optimizer.h"
#include "parser/tensorflow/parser_graph_optimizer.h"
#include "metadef/inc/register/scope/scope_pass_registry_impl.h" #include "metadef/inc/register/scope/scope_pass_registry_impl.h"
#include "register/scope/scope_fusion_pass_register.h" #include "register/scope/scope_fusion_pass_register.h"
#undef protected #undef protected
@@ -678,6 +678,7 @@ namespace {


if ((_name== "S") || (_name == "K")) { if ((_name== "S") || (_name == "K")) {
int index = 0; int index = 0;

ge::AttrUtils::SetInt(opDef, "T", 1); ge::AttrUtils::SetInt(opDef, "T", 1);
ge::AttrUtils::SetInt(opDef, "arg_index", index); ge::AttrUtils::SetInt(opDef, "arg_index", index);
ge::AttrUtils::SetInt(opDef, "ret_index", index); ge::AttrUtils::SetInt(opDef, "ret_index", index);
@@ -1029,7 +1030,9 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) {
ParserOperator unused("Add"); ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/tf_add.pb"; std::string model_file = case_dir + "/origin_models/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params;
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")},
};
ge::Graph graph; ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
@@ -1043,6 +1046,21 @@ TEST_F(STestTensorflowParser, tensorflow_parser_success) {
EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); EXPECT_EQ(net_out_name.at(0), "add_test_1:0");
} }


TEST_F(STestTensorflowParser, tensorflow_parser_failed_for_input_data_names_error) {
RegisterCustomOp();

std::string case_dir = __FILE__;
ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_3")},
};
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, ge::GRAPH_FAILED);
}

TEST_F(STestTensorflowParser, tensorflow_model_Failed) { TEST_F(STestTensorflowParser, tensorflow_model_Failed) {
ge::Graph graph; ge::Graph graph;
std::string caseDir = __FILE__; std::string caseDir = __FILE__;


+ 11
- 11
tests/ut/parser/CMakeLists.txt View File

@@ -250,17 +250,17 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/convert/message2operator.cc" "${PARSER_DIR}/parser/common/convert/message2operator.cc"
"${PARSER_DIR}/parser/common/data_op_parser.cc" "${PARSER_DIR}/parser/common/data_op_parser.cc"
"${PARSER_DIR}/parser/common/model_saver.cc" "${PARSER_DIR}/parser/common/model_saver.cc"
"${PARSER_DIR}/parser/common/op_def/arg_op.cc"
"${PARSER_DIR}/parser/common/op_def/constant_op.cc"
"${PARSER_DIR}/parser/common/op_def/fill_op.cc"
"${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc"
"${PARSER_DIR}/parser/common/op_def/arg_op_operator.cc"
"${PARSER_DIR}/parser/common/op_def/constant_operator.cc"
"${PARSER_DIR}/parser/common/op_def/fill_operator.cc"
"${PARSER_DIR}/parser/common/op_def/framework_op_operator.cc"
"${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc"
"${PARSER_DIR}/parser/common/op_def/no_op_op.cc"
"${PARSER_DIR}/parser/common/op_def/no_op_operator.cc"
"${PARSER_DIR}/parser/common/op_def/operator.cc" "${PARSER_DIR}/parser/common/op_def/operator.cc"
"${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc"
"${PARSER_DIR}/parser/common/op_def/shape_n_op.cc"
"${PARSER_DIR}/parser/common/op_def/variable_op.cc"
"${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc"
"${PARSER_DIR}/parser/common/op_def/ref_switch_operator.cc"
"${PARSER_DIR}/parser/common/op_def/shape_n_operator.cc"
"${PARSER_DIR}/parser/common/op_def/variable_operator.cc"
"${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_operator.cc"
"${PARSER_DIR}/parser/common/op_map.cc" "${PARSER_DIR}/parser/common/op_map.cc"
"${PARSER_DIR}/parser/common/op_parser_factory.cc" "${PARSER_DIR}/parser/common/op_parser_factory.cc"
"${PARSER_DIR}/parser/common/parser_api.cc" "${PARSER_DIR}/parser/common/parser_api.cc"
@@ -285,8 +285,8 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/onnx/onnx_util.cc" "${PARSER_DIR}/parser/onnx/onnx_util.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc"
"${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc"
"${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc"
"${PARSER_DIR}/parser/tensorflow/graph_to_function_def.cc"
"${PARSER_DIR}/parser/tensorflow/parser_graph_optimizer.cc"
"${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc"
"${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc"


+ 11
- 8
tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc View File

@@ -739,6 +739,9 @@ TEST_F(UtestCaffeParser, CaffeModelParser_CustomProtoParse_test)
Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators); Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators);
EXPECT_EQ(ret, PARAM_INVALID); EXPECT_EQ(ret, PARAM_INVALID);


ret = modelParser.CustomProtoParse("", custom_proto, caffe_proto, operators);
EXPECT_EQ(ret, FAILED);

model_file = case_dir + "/caffe_model/caffe_add.pbtxt"; model_file = case_dir + "/caffe_model/caffe_add.pbtxt";
custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto"; custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto";
model_path = model_file.c_str(); model_path = model_file.c_str();
@@ -890,7 +893,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CheckLayersSize_test)
layer->set_name("Abs"); layer->set_name("Abs");
layer->set_type("AbsVal"); layer->set_type("AbsVal");


Status ret = weightParser.CheckLayersSize(layer);
Status ret = weightParser.CheckLayersSize(*layer);
EXPECT_EQ(ret, FAILED); EXPECT_EQ(ret, FAILED);
} }


@@ -902,7 +905,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test)
layer->set_name("Abs"); layer->set_name("Abs");
layer->set_type("AbsVal"); layer->set_type("AbsVal");


Status ret = weightParser.ConvertLayerProto(&net, &net);
Status ret = weightParser.ConvertLayerProto(net, &net);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


BlobProto* blob = layer->add_blobs(); BlobProto* blob = layer->add_blobs();
@@ -911,16 +914,16 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test)
BlobShape* shap = blob->mutable_shape(); BlobShape* shap = blob->mutable_shape();
shap->add_dim(1); shap->add_dim(1);
shap->add_dim(2); shap->add_dim(2);
ret = weightParser.ConvertBlobsProto(&net, &net);
ret = weightParser.ConvertBlobsProto(net, &net);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


ret = weightParser.ConvertBlobShapeProto(&net, &net);
ret = weightParser.ConvertBlobShapeProto(net, &net);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


ret = weightParser.ConvertConvParamProto(&net, &net);
ret = weightParser.ConvertConvParamProto(net, &net);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


ret = weightParser.ConvertInnerProdcutProto(&net, &net);
ret = weightParser.ConvertInnerProdcutProto(net, &net);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }


@@ -1133,7 +1136,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test)
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
const google::protobuf::Message *message = proto->New(); const google::protobuf::Message *message = proto->New();


Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph);
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph);
delete message; delete message;
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }
@@ -1163,7 +1166,7 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test)
google::protobuf::DynamicMessageFactory factory; google::protobuf::DynamicMessageFactory factory;
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); const google::protobuf::Message *proto = factory.GetPrototype(descriptor);
google::protobuf::Message *message = proto->New(); google::protobuf::Message *message = proto->New();
Status ret = modelParser.ParseLayerParameter(descriptor, message, operators);
Status ret = modelParser.ParseLayerParameter(*descriptor, *message, operators);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


const domi::FrameworkType fmk_type = domi::TENSORFLOW; const domi::FrameworkType fmk_type = domi::TENSORFLOW;


+ 1
- 1
tests/ut/parser/testcase/graph_optimizer_testcase/graph_optimizer_unittest.cc View File

@@ -7,7 +7,7 @@
#include "tensorflow/iterator_fusion_pass.h" #include "tensorflow/iterator_fusion_pass.h"
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#define private public #define private public
#include "tensorflow/graph_optimizer.h"
#include "tensorflow/parser_graph_optimizer.h"
#undef private #undef private
namespace ge { namespace ge {
class UtestGraphOptimizer : public testing::Test { class UtestGraphOptimizer : public testing::Test {


+ 2
- 2
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -381,7 +381,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto)
OnnxFileConstantParser parser; OnnxFileConstantParser parser;
ge::onnx::NodeProto input_node; ge::onnx::NodeProto input_node;
ge::onnx::TensorProto tensor_proto; ge::onnx::TensorProto tensor_proto;
Status ret = parser.GetTensorProto(&input_node, tensor_proto);
Status ret = parser.GetTensorProto(input_node, tensor_proto);
EXPECT_EQ(ret, FAILED); EXPECT_EQ(ret, FAILED);


ge::onnx::AttributeProto *attribute = input_node.add_attribute(); ge::onnx::AttributeProto *attribute = input_node.add_attribute();
@@ -391,7 +391,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto)


ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
*attribute_tensor = tensor_proto; *attribute_tensor = tensor_proto;
ret = parser.GetTensorProto(&input_node, tensor_proto);
ret = parser.GetTensorProto(input_node, tensor_proto);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }




+ 27
- 10
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -38,17 +38,17 @@
#include "tests/depends/ops_stub/ops_stub.h" #include "tests/depends/ops_stub/ops_stub.h"
#include "parser/tensorflow/tensorflow_constant_parser.h" #include "parser/tensorflow/tensorflow_constant_parser.h"
#include "common/types.h" #include "common/types.h"
#include "parser/common/op_def/variable_op.h"
#include "parser/common/op_def/variable_operator.h"
#include "parser/tensorflow/tensorflow_ref_switch_parser.h" #include "parser/tensorflow/tensorflow_ref_switch_parser.h"
#include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h" #include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h"
#include "parser/common/op_def/arg_op.h"
#include "parser/common/op_def/arg_op_operator.h"
#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h"
#include "parser/tensorflow/tensorflow_reshape_parser.h" #include "parser/tensorflow/tensorflow_reshape_parser.h"
#include "parser/tensorflow/tensorflow_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_custom_parser_adapter.h"
#include "parser/tensorflow/tensorflow_squeeze_parser.h" #include "parser/tensorflow/tensorflow_squeeze_parser.h"
#include "parser/tensorflow/graph_functiondef.h"
#include "parser/tensorflow/graph_optimizer.h"
#include "parser/tensorflow/graph_to_function_def.h"
#include "parser/tensorflow/parser_graph_optimizer.h"
#include "cce/dnn_base_def.hpp" #include "cce/dnn_base_def.hpp"
#include "parser/tensorflow/scope/scope_pass_manager.h" #include "parser/tensorflow/scope/scope_pass_manager.h"
#include "parser/tensorflow/tensorflow_util.h" #include "parser/tensorflow/tensorflow_util.h"
@@ -56,10 +56,10 @@
#include "parser/tensorflow/tensorflow_enter_parser.h" #include "parser/tensorflow/tensorflow_enter_parser.h"
#include "parser/common/op_def/ir_pb_converter.h" #include "parser/common/op_def/ir_pb_converter.h"
#include "parser/common/tuple.h" #include "parser/common/tuple.h"
#include "common/op_def/frameworkop_op.h"
#include "common/op_def/shape_n_op.h"
#include "common/op_def/var_is_initialized_op_op.h"
#include "common/op_def/fill_op.h"
#include "common/op_def/framework_op_operator.h"
#include "common/op_def/shape_n_operator.h"
#include "common/op_def/var_is_initialized_op_operator.h"
#include "common/op_def/fill_operator.h"
#include "common/convert/pb2json.h" #include "common/convert/pb2json.h"
#include "common/convert/message2operator.h" #include "common/convert/message2operator.h"
#include "parser/common/proto_file_parser.h" #include "parser/common/proto_file_parser.h"
@@ -73,7 +73,7 @@
#include "parser/common/prototype_pass_manager.h" #include "parser/common/prototype_pass_manager.h"
#include "parser/common/register_tbe.h" #include "parser/common/register_tbe.h"
#include "parser/common/pass_manager.h" #include "parser/common/pass_manager.h"
#include "parser/tensorflow/graph_optimizer.h"
#include "parser/tensorflow/parser_graph_optimizer.h"
#include "metadef/inc/register/scope/scope_pass_registry_impl.h" #include "metadef/inc/register/scope/scope_pass_registry_impl.h"
#include "register/scope/scope_fusion_pass_register.h" #include "register/scope/scope_fusion_pass_register.h"
#include "common/op_map.h" #include "common/op_map.h"
@@ -1032,7 +1032,9 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) {
ParserOperator unused("Add"); ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/tensorflow_model/tf_add.pb"; std::string model_file = case_dir + "/tensorflow_model/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params;
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder,Placeholder_1")},
};
ge::Graph graph; ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
@@ -1046,6 +1048,21 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) {
EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); EXPECT_EQ(net_out_name.at(0), "add_test_1:0");
} }


TEST_F(UtestTensorflowParser, tensorflow_parser_input_data_names_failed) {
RegisterCustomOp();

std::string case_dir = __FILE__;
ParserOperator unused("Add");
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/tensorflow_model/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params = {
{ge::AscendString(ge::ir_option::INPUT_DATA_NAMES), ge::AscendString("Placeholder_1,Placeholder_2")},
};
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, ge::GRAPH_FAILED);
}

TEST_F(UtestTensorflowParser, tensorflow_model_Failed) { TEST_F(UtestTensorflowParser, tensorflow_model_Failed) {
ge::Graph graph; ge::Graph graph;
std::string caseDir = __FILE__; std::string caseDir = __FILE__;


Loading…
Cancel
Save