Browse Source

!555 sync master to ge_dev

Merge pull request !555 from 王涛/master
pull/546/MERGE
王涛 Gitee 3 years ago
parent
commit
a8893867b0
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
82 changed files with 937 additions and 899 deletions
  1. +5
    -0
      OWNERS
  2. +1
    -1
      metadef
  3. +2
    -2
      parser/caffe/caffe_data_parser.cc
  4. +6
    -5
      parser/caffe/caffe_op_parser.cc
  5. +128
    -112
      parser/caffe/caffe_parser.cc
  6. +26
    -0
      parser/caffe/caffe_parser.h
  7. +1
    -1
      parser/caffe/caffe_reshape_parser.cc
  8. +121
    -84
      parser/common/acl_graph_parser_util.cc
  9. +1
    -1
      parser/common/auto_mapping_subgraph_io_index_func.cc
  10. +14
    -8
      parser/common/convert/pb2json.cc
  11. +3
    -3
      parser/common/convert/pb2json.h
  12. +2
    -2
      parser/common/data_op_parser.h
  13. +5
    -3
      parser/common/model_saver.cc
  14. +2
    -2
      parser/common/op_def/constant_op.cc
  15. +3
    -3
      parser/common/op_def/ir_pb_converter.cc
  16. +1
    -1
      parser/common/op_def/ref_switch_op.cc
  17. +4
    -4
      parser/common/op_def/shape_n_op.cc
  18. +5
    -5
      parser/common/op_parser_factory.cc
  19. +0
    -61
      parser/common/op_types.h
  20. +10
    -0
      parser/common/parser_factory.cc
  21. +56
    -55
      parser/common/parser_fp16_t.cc
  22. +16
    -16
      parser/common/parser_fp16_t.h
  23. +1
    -1
      parser/common/pass_manager.h
  24. +7
    -5
      parser/common/pre_checker.cc
  25. +1
    -1
      parser/common/pre_checker.h
  26. +4
    -2
      parser/common/proto_file_parser.cc
  27. +3
    -2
      parser/common/register_tbe.cc
  28. +1
    -1
      parser/common/tbe_plugin_loader.cc
  29. +13
    -13
      parser/func_to_graph/func2graph.py
  30. +2
    -2
      parser/onnx/onnx_constant_parser.cc
  31. +1
    -1
      parser/onnx/onnx_constant_parser.h
  32. +32
    -17
      parser/onnx/onnx_parser.cc
  33. +28
    -3
      parser/onnx/onnx_parser.h
  34. +4
    -0
      parser/onnx/onnx_util.cc
  35. +1
    -0
      parser/onnx/onnx_util.h
  36. +16
    -9
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc
  37. +5
    -3
      parser/onnx/subgraph_adapter/if_subgraph_adapter.h
  38. +3
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter.h
  39. +1
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc
  40. +13
    -17
      parser/tensorflow/graph_functiondef.cc
  41. +1
    -1
      parser/tensorflow/graph_functiondef.h
  42. +28
    -18
      parser/tensorflow/graph_optimizer.cc
  43. +1
    -1
      parser/tensorflow/tensorflow_arg_parser.cc
  44. +6
    -3
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
  45. +1
    -1
      parser/tensorflow/tensorflow_custom_parser_adapter.cc
  46. +3
    -3
      parser/tensorflow/tensorflow_data_parser.cc
  47. +1
    -1
      parser/tensorflow/tensorflow_enter_parser.cc
  48. +1
    -1
      parser/tensorflow/tensorflow_fill_parser.cc
  49. +2
    -2
      parser/tensorflow/tensorflow_frameworkop_parser.cc
  50. +1
    -1
      parser/tensorflow/tensorflow_fusion_op_parser.h
  51. +1
    -1
      parser/tensorflow/tensorflow_merge_parser.cc
  52. +1
    -1
      parser/tensorflow/tensorflow_no_op_parser.cc
  53. +0
    -19
      parser/tensorflow/tensorflow_op_parser.h
  54. +141
    -221
      parser/tensorflow/tensorflow_parser.cc
  55. +32
    -46
      parser/tensorflow/tensorflow_parser.h
  56. +0
    -2
      parser/tensorflow/tensorflow_parser_register.h
  57. +0
    -2
      parser/tensorflow/tensorflow_ref_switch_parser.h
  58. +1
    -1
      parser/tensorflow/tensorflow_reshape_parser.cc
  59. +1
    -1
      parser/tensorflow/tensorflow_reshape_parser.h
  60. +3
    -3
      parser/tensorflow/tensorflow_shape_n_parser.cc
  61. +2
    -4
      parser/tensorflow/tensorflow_shape_n_parser.h
  62. +1
    -1
      parser/tensorflow/tensorflow_squeeze_parser.cc
  63. +1
    -1
      parser/tensorflow/tensorflow_squeeze_parser.h
  64. +5
    -2
      parser/tensorflow/tensorflow_util.cc
  65. +1
    -9
      parser/tensorflow/tensorflow_util.h
  66. +1
    -1
      parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc
  67. +1
    -2
      parser/tensorflow/tensorflow_variable_v2_parser.cc
  68. +2
    -2
      tests/depends/graph/src/attr_util_stub.cc
  69. +11
    -0
      tests/depends/mmpa/src/mmpa_stub.cc
  70. +6
    -1
      tests/st/CMakeLists.txt
  71. +13
    -1
      tests/st/parser_st_utils.cc
  72. +1
    -0
      tests/st/parser_st_utils.h
  73. +25
    -1
      tests/st/testcase/test_caffe_parser.cc
  74. +2
    -1
      tests/st/testcase/test_onnx_parser.cc
  75. +17
    -47
      tests/st/testcase/test_tensorflow_parser.cc
  76. +6
    -1
      tests/ut/parser/CMakeLists.txt
  77. +13
    -0
      tests/ut/parser/parser_ut_utils.cc
  78. +1
    -0
      tests/ut/parser/parser_ut_utils.h
  79. +25
    -1
      tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc
  80. +12
    -0
      tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc
  81. +2
    -1
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc
  82. +16
    -46
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 5
- 0
OWNERS View File

@@ -3,6 +3,11 @@ approvers:
- wqtshg - wqtshg
- ljl0711 - ljl0711
- liu-jisheng - liu-jisheng
- zhangfan_hq
- lipeiyang3699
reviewers: reviewers:
- xchu42 - xchu42
- sheng-nan - sheng-nan
- tangqunzhang
- wangxiaotian22
- stevenaw

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 0a2335712484f85cd44a0f2402eac6932b22b40a
Subproject commit 8fb59a00c6291207f3491fee0c4064efff94d79f

+ 2
- 2
parser/caffe/caffe_data_parser.cc View File

@@ -94,7 +94,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l
const ge::ParserContext &ctx = GetParserContext(); const ge::ParserContext &ctx = GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
string name = layer->name(); string name = layer->name();
auto search = input_dims.find(name);
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name);
if (search == input_dims.end()) { if (search == input_dims.end()) {
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()}));
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user "
@@ -139,7 +139,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete
const ge::ParserContext &ctx = GetParserContext(); const ge::ParserContext &ctx = GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
string name = layer->name(); string name = layer->name();
auto search = input_dims.find(name);
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name);
if (search == input_dims.end()) { if (search == input_dims.end()) {
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()}));
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user "


+ 6
- 5
parser/caffe/caffe_op_parser.cc View File

@@ -19,6 +19,7 @@
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "common/util/error_manager/error_manager.h" #include "common/util/error_manager/error_manager.h"
#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"
#include "graph/def_types.h"


using namespace ge::parser; using namespace ge::parser;
using domi::caffe::BlobProto; using domi::caffe::BlobProto;
@@ -107,7 +108,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
buf[i] = proto.double_data(i); buf[i] = proto.double_data(i);
} }
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS,
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<float, uint8_t>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return GELOGW("SetData failed for GeTensor.");); // no need to return
} else if (proto.int8_data().length() > 0) { } else if (proto.int8_data().length() > 0) {
if (size != static_cast<int>(proto.int8_data().length())) { if (size != static_cast<int>(proto.int8_data().length())) {
@@ -121,7 +122,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape
const char *data_ptr = proto.int8_data().data(); const char *data_ptr = proto.int8_data().data();
GE_CHECK_NOTNULL(data_ptr); GE_CHECK_NOTNULL(data_ptr);
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS,
weight->SetData(PtrToPtr<const char, const uint8_t>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return GELOGW("SetData failed for GeTensor.");); // no need to return
dtype = ge::DT_INT8; dtype = ge::DT_INT8;
} else if (proto.int32_data_size() > 0) { } else if (proto.int32_data_size() > 0) {
@@ -139,7 +140,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape
int32_weight_buf[i] = proto.int32_data(i); int32_weight_buf[i] = proto.int32_data(i);
} }
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
weight->SetData(reinterpret_cast<uint8_t *>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS,
weight->SetData(PtrToPtr<int32_t, uint8_t>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return GELOGW("SetData failed for GeTensor.");); // no need to return
dtype = ge::DT_INT32; dtype = ge::DT_INT32;
} else if (proto.uint64_data_size() > 0) { } else if (proto.uint64_data_size() > 0) {
@@ -156,7 +157,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
uint64_weight_buf[i] = proto.uint64_data(i); uint64_weight_buf[i] = proto.uint64_data(i);
} }
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(uint64_weight_buf.get()), size * sizeof(uint64_t)) !=
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<uint64_t, uint8_t>(uint64_weight_buf.get()), size * sizeof(uint64_t)) !=
ge::GRAPH_SUCCESS, ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return GELOGW("SetData failed for GeTensor.");); // no need to return
dtype = ge::DT_UINT64; dtype = ge::DT_UINT64;
@@ -173,7 +174,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape
const float *data_ptr = proto.data().data(); const float *data_ptr = proto.data().data();
GE_CHECK_NOTNULL(data_ptr); GE_CHECK_NOTNULL(data_ptr);
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS,
weight->SetData(PtrToPtr<const float, const uint8_t>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return GELOGW("SetData failed for GeTensor.");); // no need to return
} }
ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); ge::GeTensorDesc weight_desc = ge::GeTensorDesc();


+ 128
- 112
parser/caffe/caffe_parser.cc View File

@@ -45,7 +45,6 @@
#include "parser/caffe/caffe_custom_parser_adapter.h" #include "parser/caffe/caffe_custom_parser_adapter.h"
#include "parser/caffe/caffe_op_parser.h" #include "parser/caffe/caffe_op_parser.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/prototype_pass_manager.h" #include "parser/common/prototype_pass_manager.h"
#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"
#include "parser/common/model_saver.h" #include "parser/common/model_saver.h"
@@ -61,13 +60,7 @@ using domi::caffe::InnerProductParameter;
using domi::caffe::LayerParameter; using domi::caffe::LayerParameter;
using domi::caffe::NetParameter; using domi::caffe::NetParameter;
using domi::ParseParamByOpFunc; using domi::ParseParamByOpFunc;
using ge::caffe_op_map;
using ge::CaffeOpParser;
using ge::parser::ModelSaver; using ge::parser::ModelSaver;
using ge::OpParser;
using ge::OpParserFactory;
using ge::Pb2Json;
using ge::PreChecker;
using std::ifstream; using std::ifstream;


#define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \ #define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \
@@ -299,16 +292,17 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!");
return FAILED; return FAILED;
} }
int input_dim_size = proto_message.input_dim_size();


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((input_dim_size / proto_message.input_size() != parser::DIM_DEFAULT_SIZE ||
input_dim_size % proto_message.input_size() != 0),
ErrorManager::GetInstance().ATCReportErrMessage(
"E11003", {"input_dim_size", "input_size"},
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())});
return FAILED,
"[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].",
input_dim_size, proto_message.input_size())
const int32_t input_dim_size = proto_message.input_dim_size();
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) ||
((input_dim_size % proto_message.input_size()) != 0));
if (is_input_invalid) {
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"},
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].",
input_dim_size, proto_message.input_size());
return FAILED;
}


for (int i = 0; i < proto_message.input_size(); i++) { for (int i = 0; i < proto_message.input_size(); i++) {
domi::caffe::LayerParameter *layer = proto_message.add_layer(); domi::caffe::LayerParameter *layer = proto_message.add_layer();
@@ -329,12 +323,14 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo
input_data_flag = true; input_data_flag = true;
} }
} else if (proto_message.input_shape_size() > 0) { } else if (proto_message.input_shape_size() > 0) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto_message.input_shape_size() != proto_message.input_size(),
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"},
{std::to_string(proto_message.input_shape_size()),
std::to_string(proto_message.input_size())});
return FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).",
proto_message.input_shape_size(), proto_message.input_size());
if (proto_message.input_shape_size() != proto_message.input_size()) {
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"},
{std::to_string(proto_message.input_shape_size()),
std::to_string(proto_message.input_size())});
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).",
proto_message.input_shape_size(), proto_message.input_size());
return FAILED;
}


for (int i = 0; i < proto_message.input_size(); i++) { for (int i = 0; i < proto_message.input_size(); i++) {
int dim_size = proto_message.input_shape(i).dim_size(); int dim_size = proto_message.input_shape(i).dim_size();
@@ -755,7 +751,8 @@ Status CaffeModelParser::GetCustomOp(const domi::caffe::LayerParameter &layer, v
} }


if (is_search_built_in_layer) { if (is_search_built_in_layer) {
const google::protobuf::Message *layer_message = reinterpret_cast<const google::protobuf::Message *>(&layer);
const google::protobuf::Message *layer_message = PtrToPtr<const domi::caffe::LayerParameter,
const google::protobuf::Message>(&layer);
Status status = CreateCustomOperator(op_name, op_type, layer_message, 0, operators); Status status = CreateCustomOperator(op_name, op_type, layer_message, 0, operators);
if (status != SUCCESS || operators.empty()) { if (status != SUCCESS || operators.empty()) {
GELOGE(status, "[Create][CustomOperator] failed, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); GELOGE(status, "[Create][CustomOperator] failed, name: %s, type: %s.", op_name.c_str(), op_type.c_str());
@@ -838,11 +835,11 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE); std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE);
GE_CHECK_NOTNULL(factory); GE_CHECK_NOTNULL(factory);
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_parser == nullptr,
ErrorManager::GetInstance().ATCReportErrMessage("E11009", {"opname", "optype"},
{layer.name(), op_type});
return FAILED, "op_parser is null, op_type: %s.",
op_type.c_str());
if (op_parser == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E11009", {"opname", "optype"}, {layer.name(), op_type});
GELOGE(FAILED, "op_parser is null, op_type: %s.", op_type.c_str());
return FAILED;
}


ge::OpDescPtr op; ge::OpDescPtr op;
// Process change of tensordesc initialization of opdesc, // Process change of tensordesc initialization of opdesc,
@@ -994,7 +991,7 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const
GELOGI("op [%s], type[%s], update output(%d) with name %s %s", GELOGI("op [%s], type[%s], update output(%d) with name %s %s",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
i, op_desc->GetOutputNameByIndex(i).c_str(), i, op_desc->GetOutputNameByIndex(i).c_str(),
ret == ge::GRAPH_SUCCESS ? "success" : "failed");
ret == ge::GRAPH_SUCCESS ? "success" : "not success");
} }
} }
return SUCCESS; return SUCCESS;
@@ -1025,7 +1022,8 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) {
// Find the layer for this output // Find the layer for this output
auto top_node_iter = node_map.find(top_blob_layer_pair.first); auto top_node_iter = node_map.find(top_blob_layer_pair.first);
// Find the layer for this input // Find the layer for this input
auto bottom_node_iter = node_map.find(bottom_blob_layer_pair.first);
std::map<std::string, ge::NodePtr>::const_iterator bottom_node_iter =
node_map.find(bottom_blob_layer_pair.first);
if (top_node_iter != node_map.end() && bottom_node_iter != node_map.end()) { if (top_node_iter != node_map.end() && bottom_node_iter != node_map.end()) {
// Output node top_node_iter->second, // Output node top_node_iter->second,
// Output index top_blob_layer_pair.second // Output index top_blob_layer_pair.second
@@ -1057,7 +1055,7 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) {
{top_blob_layer_pair.first}); {top_blob_layer_pair.first});
GELOGE(INTERNAL_ERROR, "[Find][TopLayer] %s failed.", top_blob_layer_pair.first.c_str()); GELOGE(INTERNAL_ERROR, "[Find][TopLayer] %s failed.", top_blob_layer_pair.first.c_str());
return ge::FAILED;) return ge::FAILED;)
GE_IF_BOOL_EXEC(top_node_iter == node_map.end(),
GE_IF_BOOL_EXEC(bottom_node_iter == node_map.end(),
ErrorManager::GetInstance().ATCReportErrMessage("E11015", {"opname"}, ErrorManager::GetInstance().ATCReportErrMessage("E11015", {"opname"},
{bottom_blob_layer_pair.first}); {bottom_blob_layer_pair.first});
GELOGE(INTERNAL_ERROR, "[Find][BottomLayer] %s failed.", bottom_blob_layer_pair.first.c_str()); GELOGE(INTERNAL_ERROR, "[Find][BottomLayer] %s failed.", bottom_blob_layer_pair.first.c_str());
@@ -1095,7 +1093,7 @@ Status CaffeModelParser::AddUserOutNodesTop() {
const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes; const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes;
int net_output_num = user_out_nodes.size(); int net_output_num = user_out_nodes.size();
for (const auto &out_pair : user_out_nodes) { for (const auto &out_pair : user_out_nodes) {
auto layer_iter = layer_tops_map_.find(out_pair.first);
std::map<std::string, std::vector<std::string>>::const_iterator layer_iter = layer_tops_map_.find(out_pair.first);
GELOGI("Add to output, node name: %s", out_pair.first.c_str()); GELOGI("Add to output, node name: %s", out_pair.first.c_str());
if (layer_iter != layer_tops_map_.end()) { if (layer_iter != layer_tops_map_.end()) {
if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) { if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) {
@@ -1110,7 +1108,7 @@ Status CaffeModelParser::AddUserOutNodesTop() {
} }


string top_name = layer_iter->second[out_pair.second]; string top_name = layer_iter->second[out_pair.second];
auto top_node_iter = node_map.find(out_pair.first);
std::map<std::string, ge::NodePtr>::const_iterator top_node_iter = node_map.find(out_pair.first);
if (top_node_iter != node_map.end()) { if (top_node_iter != node_map.end()) {
ge::GetParserContext().out_tensor_names.push_back(top_name); ge::GetParserContext().out_tensor_names.push_back(top_name);
GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str());
@@ -1142,7 +1140,8 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes
top = RemapTopNameByLayer(layer, top, i); top = RemapTopNameByLayer(layer, top, i);
} }


auto t_iter = top_blobs_map_.find(top);
std::map<std::string, std::vector<std::pair<std::string, int32_t>>>::const_iterator t_iter =
top_blobs_map_.find(top);


GE_RETURN_WITH_LOG_IF_FALSE(t_iter != top_blobs_map_.end(), GE_RETURN_WITH_LOG_IF_FALSE(t_iter != top_blobs_map_.end(),
"[Check][Param]Failed to find top: %s, layer name:%s", top.c_str(), "[Check][Param]Failed to find top: %s, layer name:%s", top.c_str(),
@@ -1156,7 +1155,7 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes


// If not found, add to the output side of the output // If not found, add to the output side of the output
// Find the layer for this output // Find the layer for this output
auto top_node_iter = node_map.find(layer.name());
std::map<std::string, ge::NodePtr>::const_iterator top_node_iter = node_map.find(layer.name());
GELOGI("output in top_blob: %s", layer.name().c_str()); GELOGI("output in top_blob: %s", layer.name().c_str());
if (top_node_iter != node_map.end()) { if (top_node_iter != node_map.end()) {
ge::GetParserContext().out_tensor_names.push_back(top_origin); ge::GetParserContext().out_tensor_names.push_back(top_origin);
@@ -1226,12 +1225,14 @@ Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) {
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&layer, layer.name(), layer.type()), GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&layer, layer.name(), layer.type()),
"[Invoke][AddOp]Add layer to PreChecker failed, layer name: %s.", "[Invoke][AddOp]Add layer to PreChecker failed, layer name: %s.",
layer.name().c_str()); layer.name().c_str());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&layer) != SUCCESS, return FAILED,
"[Invoke][CheckName]Check op[%s] failed, name repeat in caffe prototxt.",
layer.name().c_str());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&layer) != SUCCESS, return FAILED,
"[Invoke][CheckType]Check op[%s]'s optype failed, type is not supported.",
layer.name().c_str());
if (PreChecker::Instance().CheckName(&layer) != SUCCESS) {
GELOGE(FAILED, "[Invoke][CheckName]Check op[%s] failed, name repeat in caffe prototxt.", layer.name().c_str());
return FAILED;
}
if (PreChecker::Instance().CheckType(&layer) != SUCCESS) {
GELOGE(FAILED, "[Invoke][CheckType]Check op[%s]'s optype failed, type is not supported.", layer.name().c_str());
return FAILED;
}
} }


return SUCCESS; return SUCCESS;
@@ -1290,9 +1291,11 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) {
domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index));


GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue,
"[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.",
layer.name().c_str(), layer.type().c_str());
if (!CheckValidLayer(layer)) {
GELOGI("[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.",
layer.name().c_str(), layer.type().c_str());
continue;
}


CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true;
REPORT_INNER_ERROR("E19999", "net %s has input and data layer simultaneously, check invalid." REPORT_INNER_ERROR("E19999", "net %s has input and data layer simultaneously, check invalid."
@@ -1392,7 +1395,7 @@ void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer)
for (auto top : layer.top()) { for (auto top : layer.top()) {
tops.push_back(top); tops.push_back(top);
} }
auto it = layer_tops_map_.find(name);
std::map<std::string, std::vector<std::string>>::const_iterator it = layer_tops_map_.find(name);
if (it == layer_tops_map_.end()) { if (it == layer_tops_map_.end()) {
layer_tops_map_[name] = tops; layer_tops_map_[name] = tops;
} }
@@ -1431,11 +1434,23 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la
return SUCCESS; return SUCCESS;
} }


Status CaffeModelParser::ReportLayerInvalid(const domi::caffe::NetParameter &proto, const std::string &path) const {
if (proto.layers_size() > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11021", {"realpath"}, {path});
GELOGE(FAILED, "[Check][Size]The model file[%s] is consisted of layers-structure which is deprecated in Caffe "
"and unsupported in ATC. The \"layers\" should be changed to \"layer\".", path.c_str());
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E11022");
GELOGE(FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid.");
}
return FAILED;
}

Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) { Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) {
bool has_error = false; bool has_error = false;
GE_CHECK_NOTNULL(model_path); GE_CHECK_NOTNULL(model_path);
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);
GELOGI("Caffe Parse model file %s", model_path);
GELOGI("Caffe Parse model file [%s]", model_path);


PreChecker::Instance().Clear(); PreChecker::Instance().Clear();


@@ -1450,22 +1465,12 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
// parse network model by custom proto and get custom operators // parse network model by custom proto and get custom operators
string custom_proto_path = ge::GetParserContext().custom_proto_path + "custom.proto"; string custom_proto_path = ge::GetParserContext().custom_proto_path + "custom.proto";
string caffe_proto_path = ge::GetParserContext().caffe_proto_path + "caffe.proto"; string caffe_proto_path = ge::GetParserContext().caffe_proto_path + "caffe.proto";
Status result = CustomProtoParse(model_path, custom_proto_path, caffe_proto_path, custom_operator_);
if (result != SUCCESS) {
GELOGE(FAILED, "[Parse][Model] by custom proto failed, model path: %s.", model_path);
return FAILED;
}
GE_CHK_STATUS(CustomProtoParse(model_path, custom_proto_path, caffe_proto_path, custom_operator_),
"[Parse][Model] by custom proto failed, model path: %s.", model_path);


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
proto_message.layer_size() == 0 && proto_message.layers_size() > 0,
ErrorManager::GetInstance().ATCReportErrMessage("E11021", {"realpath"}, {model_path});
return FAILED,
"[Check][Size]The model file[%s] is consisted of layers-structure which is deprecated in Caffe "
"and unsupported in ATC. The \"layers\" should be changed to \"layer\".",
model_path);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto_message.layer_size() == 0),
ErrorManager::GetInstance().ATCReportErrMessage("E11022");
return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid.");
if (proto_message.layer_size() == 0) {
return ReportLayerInvalid(proto_message, model_path);
}


GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE),
"Run ProtoType Pass Failed"); "Run ProtoType Pass Failed");
@@ -1476,8 +1481,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
GE_RETURN_IF_ERROR(PreCheck(proto_message)); GE_RETURN_IF_ERROR(PreCheck(proto_message));


if (PreChecker::Instance().HasError()) { if (PreChecker::Instance().HasError()) {
REPORT_INNER_ERROR("E19999", "Precheck failed. Please read check report.");
GELOGE(INTERNAL_ERROR, "[Has][Error]Precheck failed. Please read check report.");
REPORT_INNER_ERROR("E19999", "Precheck failed. a report of json format will be create, Please read it.");
GELOGE(INTERNAL_ERROR, "[Has][Error]Precheck failed. a report of json format will be create, Please read it.");
return FAILED; return FAILED;
} }


@@ -1512,9 +1517,11 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) {
domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index));
SaveOrigionLayerTops(layer); SaveOrigionLayerTops(layer);
GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue,
"[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.",
layer.name().c_str(), layer.type().c_str());
if (!CheckValidLayer(layer)) {
GELOGI("[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.",
layer.name().c_str(), layer.type().c_str());
continue;
}


CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true;
GELOGE(FAILED, "[Check][Layer]net %s has input and data layer simultaneously, check invalid." GELOGE(FAILED, "[Check][Layer]net %s has input and data layer simultaneously, check invalid."
@@ -1679,7 +1686,7 @@ Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::


// Resolve proto file to netparameter // Resolve proto file to netparameter
NetParameter proto; NetParameter proto;
bool success = ge::parser::ReadProtoFromArray(reinterpret_cast<const char *>(data), static_cast<int>(size), &proto);
bool success = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &proto);
if (!success) { if (!success) {
REPORT_CALL_ERROR("E19999", "ReadProtoFromArray failed."); REPORT_CALL_ERROR("E19999", "ReadProtoFromArray failed.");
GELOGE(domi::PARSE_WEIGHTS_FAILED, "[Read][Proto] from Memory fail"); GELOGE(domi::PARSE_WEIGHTS_FAILED, "[Read][Proto] from Memory fail");
@@ -1920,7 +1927,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r
const google::protobuf::FieldDescriptor *field, const google::protobuf::FieldDescriptor *field,
google::protobuf::Message *layer) { google::protobuf::Message *layer) {
GELOGD("Start to parse field: %s.", field->name().c_str()); GELOGD("Start to parse field: %s.", field->name().c_str());
domi::caffe::LayerParameter *layer_proto = reinterpret_cast<domi::caffe::LayerParameter *>(layer);
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer);
string filed_name = field->name(); string filed_name = field->name();
#define CASE_FIELD_NAME(kName, method) \ #define CASE_FIELD_NAME(kName, method) \
if (filed_name == kField##kName) { \ if (filed_name == kField##kName) { \
@@ -1975,8 +1982,7 @@ Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *me
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
blobs_reflection->ListFields(*message, &field_desc); blobs_reflection->ListFields(*message, &field_desc);

domi::caffe::BlobProto *blobs_proto = reinterpret_cast<domi::caffe::BlobProto *>(blobs);
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs);


for (auto &field : field_desc) { for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field); GE_CHECK_NOTNULL(field);
@@ -2025,7 +2031,7 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message
vector<const google::protobuf::FieldDescriptor *> field_desc; vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc); reflection->ListFields(*message, &field_desc);


domi::caffe::BlobShape *shape_proto = reinterpret_cast<domi::caffe::BlobShape *>(dest_message);
domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message);


for (auto &field : field_desc) { for (auto &field : field_desc) {
if (field->name() != kFieldDim) { if (field->name() != kFieldDim) {
@@ -2048,7 +2054,7 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message
reflection->ListFields(*message, &field_desc); reflection->ListFields(*message, &field_desc);


domi::caffe::ConvolutionParameter *conv_param_proto = domi::caffe::ConvolutionParameter *conv_param_proto =
reinterpret_cast<domi::caffe::ConvolutionParameter *>(dest_message);
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message);


for (auto &field : field_desc) { for (auto &field : field_desc) {
if (field->name() != kFieldBiasTerm) { if (field->name() != kFieldBiasTerm) {
@@ -2068,7 +2074,7 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess
reflection->ListFields(*message, &field_desc); reflection->ListFields(*message, &field_desc);


domi::caffe::InnerProductParameter *inner_product_proto = domi::caffe::InnerProductParameter *inner_product_proto =
reinterpret_cast<domi::caffe::InnerProductParameter *>(dest_message);
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message);


for (auto &field : field_desc) { for (auto &field : field_desc) {
if (field->name() != kFieldBiasTerm) { if (field->name() != kFieldBiasTerm) {
@@ -2110,22 +2116,25 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess
} }
} }


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0,
ErrorManager::GetInstance().ATCReportErrMessage("E11023");
return FAILED,
"[Check][Param]The weight file is consisted of layers-structure which is deprecated "
"in Caffe and unsupported in ATC. The \"layers\" should be changed to \"layer\".");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024");
return FAILED,
"[Check][Param] Weight layer num is zero, weight file may be invalid.");

if (num_layer == 0 && num_layers > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11023");
GELOGE(FAILED, "[Check][Param]The weight file is consisted of layers-structure which is deprecated "
"in Caffe and unsupported in ATC. The \"layers\" should be changed to \"layer\".");
return FAILED;
}
if (num_layer == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11024");
GELOGE(FAILED, "[Check][Param] Weight layer num is zero, weight file may be invalid.");
return FAILED;
}
return SUCCESS; return SUCCESS;
} }


Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message *layer_message, Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message *layer_message,
ge::ComputeGraphPtr &graph) { ge::ComputeGraphPtr &graph) {
vector<string> need_share_layers; vector<string> need_share_layers;
const domi::caffe::LayerParameter *layer = reinterpret_cast<const domi::caffe::LayerParameter *>(layer_message);
const domi::caffe::LayerParameter *layer =
PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer_message);
const string &shared_layer_name = layer->name(); const string &shared_layer_name = layer->name();
const string &layer_type = layer->type(); const string &layer_type = layer->type();
for (auto p_iter = params_share_map.begin(); p_iter != params_share_map.end(); ++p_iter) { for (auto p_iter = params_share_map.begin(); p_iter != params_share_map.end(); ++p_iter) {
@@ -2159,7 +2168,7 @@ Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message
} }


// The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. // The weight processing also needs to judge the duplicate operator, which is reserved here and processed later.
auto iter = caffe_op_map.find(layer_type);
std::map<std::string, std::string>::const_iterator iter = caffe_op_map.find(layer_type);
if (iter == caffe_op_map.end()) { if (iter == caffe_op_map.end()) {
GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer_type.c_str(), layer_name.c_str()); GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer_type.c_str(), layer_name.c_str());
continue; continue;
@@ -2172,20 +2181,20 @@ Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message
GE_CHECK_NOTNULL(factory); GE_CHECK_NOTNULL(factory);
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(op_parser.get() == nullptr),
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}),
std::vector<std::string>({layer_name, op_type}));
return FAILED,
"[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());
if (op_parser.get() == nullptr) {
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}),
std::vector<std::string>({layer_name, op_type}));
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());
return FAILED;
}


// Parsing weight information through op parser // Parsing weight information through op parser
Status status = op_parser->ParseWeights(layer_message, node); Status status = op_parser->ParseWeights(layer_message, node);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(status != SUCCESS),
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str());
return status,
"[Parse][Weights] for op[%s] failed", layer_name.c_str());
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str());
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str());
return status;
}
} }
return SUCCESS; return SUCCESS;
} }
@@ -2233,13 +2242,18 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
// Operator name and occurrence map, handle duplicate operators // Operator name and occurrence map, handle duplicate operators
std::map<std::string, int32_t> layer_name_map; std::map<std::string, int32_t> layer_name_map;


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0,
ErrorManager::GetInstance().ATCReportErrMessage("E11023");
return FAILED, "[Check][Param] The weight file is consisted of layers-structure "
"which is deprecated in Caffe and unsupported in ATC. "
"The \"layers\" should be changed to \"layer\".");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024");
return FAILED, "weight layer num is zero, weight file may be invalid.");
if (num_layer == 0 && num_layers > 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11023");
GELOGE(FAILED, "[Check][Param] The weight file is consisted of layers-structure "
"which is deprecated in Caffe and unsupported in ATC. "
"The \"layers\" should be changed to \"layer\".");
return FAILED;
}
if (num_layer == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E11024");
GELOGE(FAILED, "weight layer num is zero, weight file may be invalid.");
return FAILED;
}


for (int i = 0; i < num_layer; ++i) { for (int i = 0; i < num_layer; ++i) {
const LayerParameter &layer = param.layer(i); const LayerParameter &layer = param.layer(i);
@@ -2285,7 +2299,7 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
} }


// The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. // The weight processing also needs to judge the duplicate operator, which is reserved here and processed later.
auto iter = caffe_op_map.find(layer.type());
std::map<std::string, std::string>::const_iterator iter = caffe_op_map.find(layer.type());
if (iter == caffe_op_map.end()) { if (iter == caffe_op_map.end()) {
GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer.type().c_str(), layer_name.c_str()); GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer.type().c_str(), layer_name.c_str());
continue; continue;
@@ -2298,18 +2312,20 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
GE_CHECK_NOTNULL(factory); GE_CHECK_NOTNULL(factory);
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(op_parser.get() == nullptr),
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}),
std::vector<std::string>({layer_name, op_type}));
return FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());
if (op_parser.get() == nullptr) {
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}),
std::vector<std::string>({layer_name, op_type}));
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str());
return FAILED;
}


// Parsing weight information through op parser // Parsing weight information through op parser
Status status = op_parser->ParseWeights(&layer, node); Status status = op_parser->ParseWeights(&layer, node);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(status != SUCCESS),
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str());
return status, "[Parse][Weights] for op[%s] failed", layer_name.c_str());
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str());
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str());
return status;
}
} }
} }




+ 26
- 0
parser/caffe/caffe_parser.h View File

@@ -40,6 +40,7 @@
#include "omg/parser/op_parser.h" #include "omg/parser/op_parser.h"
#include "omg/parser/model_parser.h" #include "omg/parser/model_parser.h"
#include "omg/parser/weights_parser.h" #include "omg/parser/weights_parser.h"
#include "common/pre_checker.h"
#include "proto/caffe/caffe.pb.h" #include "proto/caffe/caffe.pb.h"
#include "proto/om.pb.h" #include "proto/om.pb.h"


@@ -123,6 +124,17 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
return domi::SUCCESS; return domi::SUCCESS;
} }


bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
private: private:
Status Parse(const char *model_path, ge::ComputeGraphPtr &graph); Status Parse(const char *model_path, ge::ComputeGraphPtr &graph);


@@ -313,6 +325,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {


Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer); Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer);


Status ReportLayerInvalid(const domi::caffe::NetParameter &proto, const std::string &path) const;

std::map<std::string, ge::NodePtr> node_map; std::map<std::string, ge::NodePtr> node_map;


// key: blob name, value: layer name and index // key: blob name, value: layer name and index
@@ -344,6 +358,18 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {


Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;


bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}

private: private:
Status CheckNodes(ge::ComputeGraphPtr &graph); Status CheckNodes(ge::ComputeGraphPtr &graph);
/** /**


+ 1
- 1
parser/caffe/caffe_reshape_parser.cc View File

@@ -128,7 +128,7 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) {
data[i] = attr_shape[i]; data[i] = attr_shape[i];
} }
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
constTensor->SetData(reinterpret_cast<uint8_t *>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS,
constTensor->SetData(PtrToPtr<int64_t, uint8_t>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return GELOGW("SetData failed for GeTensor.");); // no need to return


// construct const node and add edge // construct const node and add edge


+ 121
- 84
parser/common/acl_graph_parser_util.cc View File

@@ -492,7 +492,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
} }


domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes;
if (!default_out_nodes.empty()) { if (!default_out_nodes.empty()) {
for (size_t i = 0; i < default_out_nodes.size(); ++i) { for (size_t i = 0; i < default_out_nodes.size(); ++i) {
@@ -613,24 +613,27 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin
ge::GetParserContext().out_tensor_names.clear(); ge::GetParserContext().out_tensor_names.clear();
ge::GetParserContext().data_tensor_names.clear(); ge::GetParserContext().data_tensor_names.clear();


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS,
return PARAM_INVALID, "[Check][Options] Parse paragrams invalid, graph:%s.",
graph_name.c_str());
if (CheckOptions(parser_params) != SUCCESS) {
GELOGE(FAILED, "[Check][Options] Parse paragrams invalid, graph:%s.", graph_name.c_str());
return PARAM_INVALID;
}
// support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes // support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes
SetDefaultFormat(); SetDefaultFormat();


string out_nodes; string out_nodes;
GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputNodes(out_nodes) != SUCCESS,
return PARAM_INVALID,
"[Invoke][ParseAclOutputNodes] Parse out_nodes failed, graph:%s.", graph_name.c_str());
if (ParseAclOutputNodes(out_nodes) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ParseAclOutputNodes] Parse out_nodes failed, graph:%s.", graph_name.c_str());
return PARAM_INVALID;
}


string is_output_adjust_hw_layout; string is_output_adjust_hw_layout;
GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout); GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS,
return PARAM_INVALID,
"[Invoke][ParseAclOutputFp16NodesFormat] Parse is_output_adjust_hw_layout failed, "
"graph:%s.", graph_name.c_str());
if (ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ParseAclOutputFp16NodesFormat] Parse is_output_adjust_hw_layout failed, graph:%s.",
graph_name.c_str());
return PARAM_INVALID;
}


string tmp_name; string tmp_name;
GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name);
@@ -638,10 +641,11 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin


string enable_scope_fusion_passes; string enable_scope_fusion_passes;
GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes); GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS,
return PARAM_INVALID,
"[Invoke][ParseAclEnableScope] Parse enable_scope_fusion_passes failed, graph:%s.",
graph_name.c_str());
if (ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ParseAclEnableScope] Parse enable_scope_fusion_passes failed, graph:%s.",
graph_name.c_str());
return PARAM_INVALID;
}


return SUCCESS; return SUCCESS;
} }
@@ -657,10 +661,11 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,


string is_input_adjust_hw_layout; string is_input_adjust_hw_layout;
GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout); GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS,
return PARAM_INVALID, "[Invoke][ParseAclInputFp16Nodes] Parse input_fp16_nodes failed, graph:%s",
compute_graph->GetName().c_str());
if (ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS) {
GELOGE(FAILED, "[Invoke][ParseAclInputFp16Nodes] Parse input_fp16_nodes failed, graph:%s",
compute_graph->GetName().c_str());
return PARAM_INVALID;
}


return SUCCESS; return SUCCESS;
} }
@@ -689,30 +694,35 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char


// Get file length // Get file length
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(),
REPORT_INNER_ERROR("E19999", "input_file path is null, check invalid.");
return -1, "[Check][Param] input_file path is null.");
if (input_file.empty()) {
REPORT_INNER_ERROR("E19999", "input_file path is null, check invalid.");
GELOGE(FAILED, "[Check][Param] input_file path is null.");
return -1;
}


std::string real_path = RealPath(input_file.c_str()); std::string real_path = RealPath(input_file.c_str());
char_t err_buf[kMaxErrStrLen + 1U] = {}; char_t err_buf[kMaxErrStrLen + 1U] = {};
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
REPORT_INPUT_ERROR("E19000", std::vector<std::string>({"path", "errmsg"}),
std::vector<std::string>({real_path, err_msg}));
return -1, "[Get][Path] input_file path '%s' not valid", input_file.c_str());
if (real_path.empty()) {
REPORT_INPUT_ERROR("E19000", std::vector<std::string>({"path", "errmsg"}),
std::vector<std::string>({real_path, err_msg}));
GELOGE(FAILED, "[Get][Path] input_file path '%s' not valid", input_file.c_str());
return -1;
}
unsigned long long file_length = 0; unsigned long long file_length = 0;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"},
{input_file, err_msg});
return -1, "[Open][File] [%s] failed. %s", input_file.c_str(), err_msg);

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0 || file_length > kMaxFileSizeLimit),
REPORT_INPUT_ERROR(
"E19015", std::vector<std::string>({"file", "size", "maxsize"}),
std::vector<std::string>({input_file, std::to_string(file_length),
std::to_string(kMaxFileSizeLimit)}));
return -1, "[Check][Param] File[%s] size %lld is out of range(0,%d).",
input_file.c_str(), file_length, kMaxFileSizeLimit);
if (mmGetFileSize(input_file.c_str(), &file_length) != EN_OK) {
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, err_msg});
GELOGE(FAILED, "[Open][File] [%s] failed. %s", input_file.c_str(), err_msg);
return -1;
}

if ((file_length == 0) || (file_length > kMaxFileSizeLimit)) {
REPORT_INPUT_ERROR("E19015", std::vector<std::string>({ "file", "size", "maxsize" }),
std::vector<std::string>({ input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit) }));
GELOGE(FAILED, "[Check][Param] File[%s] size %lld is out of range(0,%d).",
input_file.c_str(), file_length, kMaxFileSizeLimit);
return -1;
}
return static_cast<long>(file_length); return static_cast<long>(file_length);
} }


@@ -725,9 +735,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp()
} }


static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr,
REPORT_INNER_ERROR("E19999", "param proto is nullptr, check invalid");
return false, "[Check][Param] incorrect parameter. nullptr == proto");
if (proto == nullptr) {
REPORT_INNER_ERROR("E19999", "param proto is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] incorrect parameter. nullptr == proto");
return false;
}


coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit); coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit);
return proto->ParseFromCodedStream(&coded_stream); return proto->ParseFromCodedStream(&coded_stream);
@@ -743,17 +755,23 @@ static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Messag
*/ */
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer,
int &length) { int &length) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr),
REPORT_INNER_ERROR("E19999", "param file_name is nullptr, check invalid");
return false, "[Check][Param] incorrect parameter. file is nullptr");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr),
REPORT_INNER_ERROR("E19999", "param buffer is nullptr, check invalid");
return false, "[Check][Param] incorrect parameter. buffer is nullptr");
if (file_name == nullptr) {
REPORT_INNER_ERROR("E19999", "param file_name is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] incorrect parameter. file is nullptr");
return false;
}
if (buffer == nullptr) {
REPORT_INNER_ERROR("E19999", "param buffer is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] incorrect parameter. buffer is nullptr");
return false;
}


std::string real_path = RealPath(file_name); std::string real_path = RealPath(file_name);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file_name);
return false, "[Check][Param]file path '%s' not valid, realpath failed", file_name);
if (real_path.empty()) {
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file_name);
GELOGE(FAILED, "[Check][Param]file path '%s' not valid, realpath failed", file_name);
return false;
}


std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate);
if (!file.is_open()) { if (!file.is_open()) {
@@ -763,16 +781,22 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co
} }


length = static_cast<int>(file.tellg()); length = static_cast<int>(file.tellg());

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); REPORT_INNER_ERROR("E19999", "file length <= 0");
return false, "[Check][Param] file length <= 0");
if ((length <= 0)) {
file.close();
REPORT_INNER_ERROR("E19999", "file length <= 0");
GELOGE(FAILED, "[Check][Param] file length <= 0");
return false;
}


file.seekg(0, std::ios::beg); file.seekg(0, std::ios::beg);


*buffer = new(std::nothrow) char[length](); *buffer = new(std::nothrow) char[length]();
GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close();
REPORT_CALL_ERROR("E19999", "new an object failed."),
"[Create][Buffer] new an object failed.");
if (*buffer == nullptr) {
REPORT_INNER_ERROR("E19999", "[Create][Buffer] new an object failed, length=%d.", length);
GELOGE(FAILED, "[Create][Buffer] new an object failed, length=%d.", length);
file.close();
return false;
}


file.read(*buffer, length); file.read(*buffer, length);
file.close(); file.close();
@@ -780,16 +804,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr),
REPORT_INNER_ERROR("E19999", "param file or proto is nullptr, check invalid");
return false, "[Check][Param] Input parameter file or proto is nullptr!");
if ((file == nullptr) || (proto == nullptr)) {
REPORT_INNER_ERROR("E19999", "param file or proto is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] Input parameter file or proto is nullptr!");
return false;
}


std::string real_path = RealPath(file); std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file);
return false, "[Check][Param]pb file path '%s' not valid, realpath failed", file);
if (real_path.empty()) {
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file);
GELOGE(FAILED, "[Check][Param]pb file path '%s' not valid, realpath failed", file);
return false;
}


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "[Get][FileLength]file size not valid.");
if (GetFileLength(real_path) == -1) {
GELOGE(FAILED, "[Get][FileLength]file size not valid.");
return false;
}


std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) { if (!fs.is_open()) {
@@ -815,32 +846,37 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0),
REPORT_INNER_ERROR("E19999", "param proto or data is nullptr "
"or size is 0, check invalid"); return false,
"[Check][Param]incorrect parameter. proto is nullptr || data is nullptr || size is 0");
if ((proto == nullptr) || (data == nullptr) || (size == 0)) {
REPORT_INNER_ERROR("E19999", "param proto or data is nullptr or size is 0, check invalid");
GELOGE(FAILED, "[Check][Param]incorrect parameter. proto is nullptr || data is nullptr || size is 0");
return false;
}


google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size);
google::protobuf::io::CodedInputStream coded_stream(PtrToPtr<void, uint8_t>(const_cast<void *>(data)), size);
return ReadProtoFromCodedInputStream(coded_stream, proto); return ReadProtoFromCodedInputStream(coded_stream, proto);
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file,
google::protobuf::Message *message) { google::protobuf::Message *message) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr),
REPORT_INNER_ERROR("E19999", "param file or message is nullptr, check invalid");
return false,
"[Check][Param]incorrect parameter. nullptr == file || nullptr == message");
if ((file == nullptr) || (message == nullptr)) {
REPORT_INNER_ERROR("E19999", "param file or message is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param]incorrect parameter. nullptr == file || nullptr == message");
return false;
}


std::string real_path = RealPath(file); std::string real_path = RealPath(file);
char_t err_buf[kMaxErrStrLen + 1U] = {}; char_t err_buf[kMaxErrStrLen + 1U] = {};
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"},
{file, err_msg});
return false, "[Check][Param]Path[%s]'s realpath is empty, errmsg[%s]", file,
err_msg);
if (real_path.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file, err_msg});
GELOGE(FAILED, "[Check][Param]Path[%s]'s realpath is empty, errmsg[%s]", file, err_msg);
return false;
}


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "[Check][Param] file size not valid.");
if (GetFileLength(real_path) == -1) {
GELOGE(FAILED, "[Check][Param] file size not valid.");
return false;
}


std::ifstream fs(real_path.c_str(), std::ifstream::in); std::ifstream fs(real_path.c_str(), std::ifstream::in);


@@ -863,10 +899,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size,
google::protobuf::Message *message) { google::protobuf::Message *message) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr),
REPORT_INNER_ERROR("E19999", "param data or message is nullptr,check invalid");
return false,
"[Check][Param] incorrect parameter. data is nullptr || message is nullptr");
if ((data == nullptr) || (message == nullptr)) {
REPORT_INNER_ERROR("E19999", "param data or message is nullptr,check invalid");
GELOGE(FAILED, "[Check][Param] incorrect parameter. data is nullptr || message is nullptr");
return false;
}
std::string str(data, static_cast<size_t>(size)); std::string str(data, static_cast<size_t>(size));
std::istringstream fs(str); std::istringstream fs(str);


@@ -901,7 +938,7 @@ Status GetOriginalType(const ge::NodePtr &node, string &type) {
return SUCCESS; return SUCCESS;
} }


FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) {
FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &filePath, const std::string &mode) {
char ebuff[kMaxBuffSize]; char ebuff[kMaxBuffSize];
regex_t reg; regex_t reg;
int cflags = REG_EXTENDED | REG_NOSUB; int cflags = REG_EXTENDED | REG_NOSUB;
@@ -913,7 +950,7 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::str
return true; return true;
} }


ret = regexec(&reg, str.c_str(), 0, nullptr, 0);
ret = regexec(&reg, filePath.c_str(), 0, nullptr, 0);
if (ret) { if (ret) {
regerror(ret, &reg, ebuff, kMaxBuffSize); regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff); GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff);


+ 1
- 1
parser/common/auto_mapping_subgraph_io_index_func.cc View File

@@ -21,11 +21,11 @@
#include "graph/op_desc.h" #include "graph/op_desc.h"
#include "graph/utils/attr_utils.h" #include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h" #include "graph/utils/node_utils.h"
#include "register/register_fmk_types.h" #include "register/register_fmk_types.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"


namespace ge { namespace ge {
namespace { namespace {


+ 14
- 8
parser/common/convert/pb2json.cc View File

@@ -31,11 +31,17 @@ using std::string;
namespace ge { namespace ge {
namespace { namespace {
const int kSignificantDigits = 10; const int kSignificantDigits = 10;
const int kMaxParseDepth = 20;
} }
// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields // JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message,
const set<string> &black_fields, Json &json, const set<string> &black_fields, Json &json,
bool enum2str) {
bool enum2str, int depth) {
if (depth > kMaxParseDepth) {
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth);
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth);
return;
}
auto descriptor = message.GetDescriptor(); auto descriptor = message.GetDescriptor();
auto reflection = message.GetReflection(); auto reflection = message.GetReflection();
if (descriptor == nullptr || reflection == nullptr) { if (descriptor == nullptr || reflection == nullptr) {
@@ -57,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons


if (field->is_repeated()) { if (field->is_repeated()) {
if (reflection->FieldSize(message, field) > 0) { if (reflection->FieldSize(message, field) > 0) {
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str);
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str, depth);
} }
continue; continue;
} }
@@ -66,18 +72,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons
continue; continue;
} }


OneField2Json(message, field, reflection, black_fields, json, enum2str);
OneField2Json(message, field, reflection, black_fields, json, enum2str, depth);
} }
} }


void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
bool enum2str, int depth) {
switch (field->type()) { switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: { case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); const ProtobufMsg &tmp_message = reflection->GetMessage(message, field);
if (0UL != tmp_message.ByteSizeLong()) { if (0UL != tmp_message.ByteSizeLong()) {
Message2Json(tmp_message, black_fields, json[field->name()], enum2str);
Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1);
} }
break; break;
} }
@@ -163,9 +169,9 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) {


void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
bool enum2str, int depth) {
if ((field == nullptr) || (reflection == nullptr)) { if ((field == nullptr) || (reflection == nullptr)) {
Message2Json(message, black_fields, json, enum2str);
Message2Json(message, black_fields, json, enum2str, depth + 1);
return; return;
} }


@@ -175,7 +181,7 @@ void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFie
case ProtobufFieldDescriptor::TYPE_MESSAGE: { case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i);
if (0UL != tmp_message.ByteSizeLong()) { if (0UL != tmp_message.ByteSizeLong()) {
Message2Json(tmp_message, black_fields, tmp_json, enum2str);
Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1);
} }
} break; } break;




+ 3
- 3
parser/common/convert/pb2json.h View File

@@ -45,11 +45,11 @@ class Pb2Json {
* @author * @author
*/ */
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json,
bool enum2str = false);
bool enum2str = false, int depth = 0);


static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, const ProtobufReflection *reflection, const std::set<std::string> &black_fields,
Json &json, bool enum2str);
Json &json, bool enum2str, int depth = 0);


protected: protected:
static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
@@ -59,7 +59,7 @@ class Pb2Json {


static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json, const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json,
bool enum2str);
bool enum2str, int depth);


static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes);
}; };


+ 2
- 2
parser/common/data_op_parser.h View File

@@ -60,7 +60,7 @@ class DataOpParser {
* @param [in] 4D shape information (dimensions) * @param [in] 4D shape information (dimensions)
* @param [out] Save converted shap information * @param [out] Save converted shap information
*/ */
static Status Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensorDesc);
static Status Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensor_desc);


/** /**
* @ingroup domi_omg * @ingroup domi_omg
@@ -98,7 +98,7 @@ class DataOpParser {
* @return SUCCESS Convert success * @return SUCCESS Convert success
* @return FAILED Convert failed * @return FAILED Convert failed
*/ */
static Status InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &desc);
static Status InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &tensor_desc);
}; };
} // namespace ge } // namespace ge



+ 5
- 3
parser/common/model_saver.cc View File

@@ -55,9 +55,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi
} }


char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX,
REPORT_INNER_ERROR("E19999", "file path %s is too long!", file_path);
return FAILED, "[Check][Param] file path %s is too long!", file_path);
if (strlen(file_path) >= PATH_MAX) {
REPORT_INNER_ERROR("E19999", "file path %s is too long!", file_path);
GELOGE(FAILED, "[Check][Param] file path %s is too long!", file_path);
return FAILED;
}
if (realpath(file_path, real_path) == nullptr) { if (realpath(file_path, real_path) == nullptr) {
GELOGI("File %s does not exit, it will be created.", file_path); GELOGI("File %s does not exit, it will be created.", file_path);
} }


+ 2
- 2
parser/common/op_def/constant_op.cc View File

@@ -32,9 +32,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOpera
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) {
Attr(VAR_ATTR_DTYPE, (int64_t)t);
Attr(VAR_ATTR_DTYPE, static_cast<int64_t>(t));
return *this; return *this;
} }


ge::DataType ConstantOperator::GetDType() const { return (ge::DataType)GetIntAttr(VAR_ATTR_DTYPE); }
ge::DataType ConstantOperator::GetDType() const { return static_cast<ge::DataType>(GetIntAttr(VAR_ATTR_DTYPE)); }
} // namespace ge } // namespace ge

+ 3
- 3
parser/common/op_def/ir_pb_converter.cc View File

@@ -32,7 +32,7 @@ static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair,


vector<int64_t> v_i; vector<int64_t> v_i;
for (int32_t i = 0; i < a_list.i_size(); i++) { for (int32_t i = 0; i < a_list.i_size(); i++) {
v_i.push_back((int64_t)a_list.i(i));
v_i.push_back(static_cast<int64_t>(a_list.i(i)));
} }
if (v_i.size() > 0) { if (v_i.size() > 0) {
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i);
@@ -56,7 +56,7 @@ static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair,
} }
vector<int32_t> v_u; vector<int32_t> v_u;
for (int32_t i = 0; i < a_list.u_size(); i++) { for (int32_t i = 0; i < a_list.u_size(); i++) {
v_u.push_back((int32_t)a_list.u(i));
v_u.push_back(static_cast<int32_t>(a_list.u(i)));
} }
if (v_u.size() > 0) { if (v_u.size() > 0) {
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u); (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u);
@@ -114,7 +114,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(co
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) { if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) {
auto &buffer = op_attr_pair.second.value_.bt(); auto &buffer = op_attr_pair.second.value_.bt();
(void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first, (void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first,
ge::Buffer::CopyFrom(reinterpret_cast<uint8_t *>(const_cast<char *>(buffer.data())), buffer.size()));
ge::Buffer::CopyFrom(PtrToPtr<void, uint8_t>(const_cast<char *>(buffer.data())), buffer.size()));
} }


if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) { if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) {


+ 1
- 1
parser/common/op_def/ref_switch_op.cc View File

@@ -23,7 +23,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOpe
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {} FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {}


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) {
Attr("T", (int64_t)t);
Attr("T", static_cast<int64_t>(t));
return *this; return *this;
} }
} // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT } // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT

+ 4
- 4
parser/common/op_def/shape_n_op.cc View File

@@ -32,20 +32,20 @@ FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::N(int64_t n) {
FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); } FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); }


FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) { FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) {
Attr(SHAPEN_ATTR_IN_TYPE, (int64_t)t);
Attr(SHAPEN_ATTR_IN_TYPE, static_cast<int64_t>(t));
return *this; return *this;
} }


FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const { FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const {
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_IN_TYPE);
return static_cast<ge::DataType>(GetIntAttr(SHAPEN_ATTR_IN_TYPE));
} }


FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) { FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) {
Attr(SHAPEN_ATTR_OUT_TYPE, (int64_t)t);
Attr(SHAPEN_ATTR_OUT_TYPE, static_cast<int64_t>(t));
return *this; return *this;
} }


FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const { FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const {
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_OUT_TYPE);
return static_cast<ge::DataType>(GetIntAttr(SHAPEN_ATTR_OUT_TYPE));
} }
} // namespace ge } // namespace ge

+ 5
- 5
parser/common/op_parser_factory.cc View File

@@ -50,7 +50,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Insta
// Instances cannot be a member of a class because they may be used before initialization, resulting in a run error. // Instances cannot be a member of a class because they may be used before initialization, resulting in a run error.
static std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>> instances; static std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>> instances;


auto iter = instances.find(framework);
std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>>::const_iterator iter = instances.find(framework);
if (iter == instances.end()) { if (iter == instances.end()) {
std::shared_ptr<OpParserFactory> instance(new (std::nothrow) OpParserFactory()); std::shared_ptr<OpParserFactory> instance(new (std::nothrow) OpParserFactory());
if (instance == nullptr) { if (instance == nullptr) {
@@ -67,7 +67,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Insta


FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpParser(const std::string &op_type) { FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpParser(const std::string &op_type) {
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser.
auto iter = op_parser_creator_map_.find(op_type);
std::map<std::string, CREATOR_FUN>::const_iterator iter = op_parser_creator_map_.find(op_type);
if (iter != op_parser_creator_map_.end()) { if (iter != op_parser_creator_map_.end()) {
return iter->second(); return iter->second();
} }
@@ -78,7 +78,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpPars


FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateFusionOpParser(const std::string &op_type) { FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateFusionOpParser(const std::string &op_type) {
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser.
auto iter = fusion_op_parser_creator_map_.find(op_type);
std::map<std::string, CREATOR_FUN>::const_iterator iter = fusion_op_parser_creator_map_.find(op_type);
if (iter != fusion_op_parser_creator_map_.end()) { if (iter != fusion_op_parser_creator_map_.end()) {
return iter->second(); return iter->second();
} }
@@ -102,12 +102,12 @@ FMK_FUNC_HOST_VISIBILITY void OpParserFactory::RegisterCreator(const std::string


FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) { FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) {
if (is_fusion_op) { if (is_fusion_op) {
auto iter = fusion_op_parser_creator_map_.find(op_type);
std::map<std::string, CREATOR_FUN>::const_iterator iter = fusion_op_parser_creator_map_.find(op_type);
if (iter != fusion_op_parser_creator_map_.end()) { if (iter != fusion_op_parser_creator_map_.end()) {
return true; return true;
} }
} else { } else {
auto iter = op_parser_creator_map_.find(op_type);
std::map<std::string, CREATOR_FUN>::const_iterator iter = op_parser_creator_map_.find(op_type);
if (iter != op_parser_creator_map_.end()) { if (iter != op_parser_creator_map_.end()) {
return true; return true;
} }


+ 0
- 61
parser/common/op_types.h View File

@@ -1,61 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_OP_TYPES_H_
#define PARSER_COMMON_OP_TYPES_H_

#include <set>
#include <string>

namespace ge {
class GE_FUNC_VISIBILITY OpTypeContainer {
public:
static OpTypeContainer *Instance() {
static OpTypeContainer instance;
return &instance;
}
~OpTypeContainer() = default;

void Register(const std::string &op_type) { op_type_list_.insert(op_type); }

bool IsExisting(const std::string &op_type) {
return op_type_list_.count(op_type) > 0UL;
}

protected:
OpTypeContainer() {}

private:
std::set<std::string> op_type_list_;
};

class GE_FUNC_VISIBILITY OpTypeRegistrar {
public:
explicit OpTypeRegistrar(const std::string &op_type) { OpTypeContainer::Instance()->Register(op_type); }
~OpTypeRegistrar() {}
};

#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *var_name;

#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \
const char *var_name = str_name; \
const OpTypeRegistrar g_##var_name##_reg(str_name);

#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name))
} // namespace ge

#endif // PARSER_COMMON_OP_TYPES_H_

+ 10
- 0
parser/common/parser_factory.cc View File

@@ -16,6 +16,7 @@


#include "omg/parser/parser_factory.h" #include "omg/parser/parser_factory.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "common/register_tbe.h"


namespace domi { namespace domi {
FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() { FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() {
@@ -77,4 +78,13 @@ FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::Fr
ModelParserFactory::~ModelParserFactory() { ModelParserFactory::~ModelParserFactory() {
creator_map_.clear(); creator_map_.clear();
} }

FMK_FUNC_HOST_VISIBILITY OpRegTbeParserFactory *OpRegTbeParserFactory::Instance() {
static OpRegTbeParserFactory instance;
return &instance;
}

void OpRegTbeParserFactory::Finalize(const domi::OpRegistrationData &reg_data) {
(void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data);
}
} // namespace domi } // namespace domi

+ 56
- 55
parser/common/parser_fp16_t.cc View File

@@ -17,15 +17,16 @@
#include "parser/common/parser_fp16_t.h" #include "parser/common/parser_fp16_t.h"


#include "external/register/register_types.h" #include "external/register/register_types.h"
#include "graph/def_types.h"


namespace { namespace {
constexpr uint16_t kManBitLength = 11;
constexpr uint16_t kManBitLength = 11U;
} }
namespace ge { namespace ge {
namespace parser { namespace parser {
/// @ingroup fp16_t global filed /// @ingroup fp16_t global filed
/// @brief round mode of last valid digital /// @brief round mode of last valid digital
enum TagFp16RoundMode g_round_mode = TagFp16RoundMode::kRoundToNearest;
const TagFp16RoundMode g_round_mode = TagFp16RoundMode::kRoundToNearest;


void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) {
// 1.Extract // 1.Extract
@@ -99,12 +100,12 @@ static float Fp16ToFloat(const uint16_t &fp_val) {
e_ret = 0; e_ret = 0;
m_ret = 0; m_ret = 0;
} else { } else {
e_ret = hf_exp - kFp16ExpBias + kFp32ExpBias;
e_ret = (static_cast<uint32_t>(hf_exp) - static_cast<uint32_t>(kFp16ExpBias)) + static_cast<uint32_t>(kFp32ExpBias);
m_ret = hf_man & kFp16ManMask; m_ret = hf_man & kFp16ManMask;
m_ret = m_ret << (kFp32ManLen - kFp16ManLen); m_ret = m_ret << (kFp32ManLen - kFp16ManLen);
} }
uint32_t f_val = FP32_CONSTRUCTOR(s_ret, e_ret, m_ret); uint32_t f_val = FP32_CONSTRUCTOR(s_ret, e_ret, m_ret);
auto p_ret_v = reinterpret_cast<float *>(&f_val);
auto p_ret_v = ge::PtrToPtr<uint32_t, float>(&f_val);


return *p_ret_v; return *p_ret_v;
} }
@@ -131,12 +132,12 @@ static double Fp16ToDouble(const uint16_t &fp_val) {
e_ret = 0; e_ret = 0;
m_ret = 0; m_ret = 0;
} else { } else {
e_ret = hf_exp - kFp16ExpBias + kFp64ExpBias;
e_ret = (static_cast<uint64_t>(hf_exp) - static_cast<uint64_t>(kFp16ExpBias)) + static_cast<uint64_t>(kFp64ExpBias);
m_ret = hf_man & kFp16ManMask; m_ret = hf_man & kFp16ManMask;
m_ret = m_ret << (kFp64ManLen - kFp16ManLen); m_ret = m_ret << (kFp64ManLen - kFp16ManLen);
} }
uint64_t f_val = (s_ret << kFp64SignIndex) | (e_ret << kFp64ManLen) | (m_ret); uint64_t f_val = (s_ret << kFp64SignIndex) | (e_ret << kFp64ManLen) | (m_ret);
auto p_ret_v = reinterpret_cast<double *>(&f_val);
auto p_ret_v = ge::PtrToPtr<uint64_t, double>(&f_val);


return *p_ret_v; return *p_ret_v;
} }
@@ -154,13 +155,13 @@ static uint8_t GetUint8ValByMan(uint8_t s_ret, const uint64_t &long_int_m, const
if (need_round) { if (need_round) {
m_ret++; m_ret++;
} }
if (s_ret) {
m_ret = (~m_ret) + 1;
if (static_cast<bool>(s_ret)) {
m_ret = (~m_ret) + 1U;
} }
if (m_ret == 0) { if (m_ret == 0) {
s_ret = 0; s_ret = 0;
} }
return static_cast<uint8_t>((s_ret << kBitShift7) | (m_ret));
return static_cast<uint8_t>((s_ret << static_cast<uint8_t>(kBitShift7)) | (m_ret));
} }


/// @ingroup fp16_t math conversion static method /// @ingroup fp16_t math conversion static method
@@ -178,7 +179,7 @@ static int8_t Fp16ToInt8(const uint16_t &fp_val) {


if (FP16_IS_DENORM(fp_val)) { // Denormalized number if (FP16_IS_DENORM(fp_val)) { // Denormalized number
ret_v = 0; ret_v = 0;
ret = *(reinterpret_cast<uint8_t *>(&ret_v));
ret = *(ge::PtrToPtr<uint8_t, uint8_t>(&ret_v));
return ret; return ret;
} }


@@ -207,14 +208,14 @@ static int8_t Fp16ToInt8(const uint16_t &fp_val) {
} }
} }
} }
if (overflow_flag) {
if (static_cast<bool>(overflow_flag)) {
ret_v = kInt8Max + s_ret; ret_v = kInt8Max + s_ret;
} else { } else {
// Generate final result // Generate final result
ret_v = GetUint8ValByMan(s_ret, long_int_m, shift_out); ret_v = GetUint8ValByMan(s_ret, long_int_m, shift_out);
} }


ret = *(reinterpret_cast<uint8_t *>(&ret_v));
ret = *(ge::PtrToPtr<uint8_t, uint8_t>(&ret_v));
return ret; return ret;
} }


@@ -283,8 +284,8 @@ static uint16_t GetUint16ValByMan(uint16_t s_ret, const uint64_t &long_int_m, co
if (need_round && m_ret < kInt16Max) { if (need_round && m_ret < kInt16Max) {
m_ret++; m_ret++;
} }
if (s_ret) {
m_ret = (~m_ret) + 1;
if (static_cast<bool>(s_ret)) {
m_ret = (~m_ret) + 1U;
} }
if (m_ret == 0) { if (m_ret == 0) {
s_ret = 0; s_ret = 0;
@@ -307,7 +308,7 @@ static int16_t Fp16ToInt16(const uint16_t &fp_val) {


if (FP16_IS_DENORM(fp_val)) { // Denormalized number if (FP16_IS_DENORM(fp_val)) { // Denormalized number
ret_v = 0; ret_v = 0;
ret = *(reinterpret_cast<uint8_t *>(&ret_v));
ret = *(ge::PtrToPtr<uint16_t, uint8_t>(&ret_v));
return ret; return ret;
} }


@@ -336,13 +337,13 @@ static int16_t Fp16ToInt16(const uint16_t &fp_val) {
} }
} }
} }
if (overflow_flag) {
if (static_cast<bool>(overflow_flag)) {
ret_v = kInt16Max + s_ret; ret_v = kInt16Max + s_ret;
} else { } else {
// Generate final result // Generate final result
ret_v = GetUint16ValByMan(s_ret, long_int_m, shift_out); ret_v = GetUint16ValByMan(s_ret, long_int_m, shift_out);
} }
ret = *(reinterpret_cast<int16_t *>(&ret_v));
ret = *(ge::PtrToPtr<uint16_t, uint16_t>(&ret_v));
return ret; return ret;
} }


@@ -433,7 +434,7 @@ static int32_t Fp16ToInt32(const uint16_t &fp_val) {
ret_v = (s_ret << kBitShift31) | (m_ret); ret_v = (s_ret << kBitShift31) | (m_ret);
} }


return *(reinterpret_cast<int32_t *>(&ret_v));
return *(ge::PtrToPtr<uint32_t, uint32_t>(&ret_v));
} }


/// @ingroup fp16_t math conversion static method /// @ingroup fp16_t math conversion static method
@@ -498,8 +499,8 @@ static uint16_t Fp16AddCalVal(uint16_t s_ret, int16_t e_ret, uint16_t m_ret, uin
} }


bool b_last_bit = ((m_ret & 1) > 0); bool b_last_bit = ((m_ret & 1) > 0);
bool b_trunc_high = 0;
bool b_trunc_left = 0;
bool b_trunc_high = false;
bool b_trunc_left = false;
b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0);
b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0);
m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out);
@@ -561,7 +562,7 @@ static uint16_t Fp16Add(uint16_t v_1, uint16_t v_2) {
int16_t e_ret = std::max(e_a, e_b); int16_t e_ret = std::max(e_a, e_b);
int16_t e_tmp = std::abs(e_a - e_b); int16_t e_tmp = std::abs(e_a - e_b);
if (e_a > e_b) { if (e_a > e_b) {
m_trunc = (m_b << (kBitShift32 - static_cast<uint16_t>(e_tmp)));
m_trunc = (m_b << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp)));
m_b = RightShift(m_b, e_tmp); m_b = RightShift(m_b, e_tmp);
} else if (e_a < e_b) { } else if (e_a < e_b) {
m_trunc = (m_a << (kBitShift32 - static_cast<uint16_t>(e_tmp))); m_trunc = (m_a << (kBitShift32 - static_cast<uint16_t>(e_tmp)));
@@ -602,7 +603,7 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) {
m_a = m_a_tmp; m_a = m_a_tmp;
m_b = m_b_tmp; m_b = m_b_tmp;


e_ret = e_a + e_b - kFp16ExpBias - kDim10;
e_ret = ((e_a + e_b) - kFp16ExpBias) - kDim10;
mul_m = m_a * m_b; mul_m = m_a * m_b;
s_ret = s_a ^ s_b; s_ret = s_a ^ s_b;


@@ -621,8 +622,8 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) {
e_ret = e_ret + 1; e_ret = e_ret + 1;
} }
bool b_last_bit = ((mul_m & 1) > 0); bool b_last_bit = ((mul_m & 1) > 0);
bool b_trunc_high = 0;
bool b_trunc_left = 0;
bool b_trunc_high = false;
bool b_trunc_left = false;
b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0);
b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0);
mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m);
@@ -675,14 +676,14 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) {
uint64_t m_tmp; uint64_t m_tmp;
if (e_a > e_b) { if (e_a > e_b) {
m_tmp = m_a; m_tmp = m_a;
uint16_t tmp = e_a - e_b;
uint16_t tmp = static_cast<uint16_t>(e_a - e_b);
for (int i = 0; i < tmp; i++) { for (int i = 0; i < tmp; i++) {
m_tmp = m_tmp << 1; m_tmp = m_tmp << 1;
} }
m_a = m_tmp; m_a = m_tmp;
} else if (e_a < e_b) { } else if (e_a < e_b) {
m_tmp = m_b; m_tmp = m_b;
uint16_t tmp = e_b - e_a;
uint16_t tmp = static_cast<uint16_t>(e_b - e_a);
for (int i = 0; i < tmp; i++) { for (int i = 0; i < tmp; i++) {
m_tmp = m_tmp << 1; m_tmp = m_tmp << 1;
} }
@@ -853,7 +854,7 @@ fp16_t &fp16_t::operator=(const float &f_val) {
uint16_t s_ret, m_ret; uint16_t s_ret, m_ret;
int16_t e_ret; int16_t e_ret;
uint32_t e_f, m_f; uint32_t e_f, m_f;
const uint32_t ui32_v = *(reinterpret_cast<const uint32_t *>(&f_val)); // 1:8:23bit sign:exp:man
const uint32_t ui32_v = *(ge::PtrToPtr<const float, const uint32_t>(&f_val)); // 1:8:23bit sign:exp:man
uint32_t m_len_delta; uint32_t m_len_delta;


s_ret = static_cast<uint16_t>((ui32_v & kFp32SignMask) >> kFp32SignIndex); // 4Byte->2Byte s_ret = static_cast<uint16_t>((ui32_v & kFp32SignMask) >> kFp32SignIndex); // 4Byte->2Byte
@@ -891,7 +892,7 @@ fp16_t &fp16_t::operator=(const float &f_val) {
if (need_round) { if (need_round) {
m_ret++; m_ret++;
} }
if (m_ret & kFp16ManHideBit) {
if (static_cast<bool>(m_ret & kFp16ManHideBit)) {
e_ret++; e_ret++;
} }
} }
@@ -910,14 +911,14 @@ fp16_t &fp16_t::operator=(const int8_t &i_val) {
if (m_ret == 0) { if (m_ret == 0) {
e_ret = 0; e_ret = 0;
} else { } else {
if (s_ret) { // negative number(<0)
if (static_cast<bool>(s_ret)) { // negative number(<0)
m_ret = static_cast<uint16_t>(std::abs(i_val)); // complement m_ret = static_cast<uint16_t>(std::abs(i_val)); // complement
} }


e_ret = kFp16ManLen; e_ret = kFp16ManLen;
while ((m_ret & kFp16ManHideBit) == 0) { while ((m_ret & kFp16ManHideBit) == 0) {
m_ret = m_ret << 1; m_ret = m_ret << 1;
e_ret = e_ret - 1;
e_ret = e_ret - 1U;
} }
e_ret = e_ret + kFp16ExpBias; e_ret = e_ret + kFp16ExpBias;
} }
@@ -931,11 +932,11 @@ fp16_t &fp16_t::operator=(const uint8_t &ui_val) {
s_ret = 0; s_ret = 0;
e_ret = 0; e_ret = 0;
m_ret = ui_val; m_ret = ui_val;
if (m_ret) {
if (static_cast<bool>(m_ret)) {
e_ret = kFp16ManLen; e_ret = kFp16ManLen;
while ((m_ret & kFp16ManHideBit) == 0) { while ((m_ret & kFp16ManHideBit) == 0) {
m_ret = m_ret << 1; m_ret = m_ret << 1;
e_ret = e_ret - 1;
e_ret = e_ret - 1U;
} }
e_ret = e_ret + kFp16ExpBias; e_ret = e_ret + kFp16ExpBias;
} }
@@ -949,11 +950,11 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u
uint16_t m_min = kFp16ManHideBit; uint16_t m_min = kFp16ManHideBit;
uint16_t m_max = m_min << 1; uint16_t m_max = m_min << 1;
uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp)); uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp));
if (m_tmp) {
if (static_cast<bool>(m_tmp)) {
int16_t e_ret; int16_t e_ret;
if (len > kDim11) { if (len > kDim11) {
e_ret = kFp16ExpBias + kFp16ManLen; e_ret = kFp16ExpBias + kFp16ManLen;
uint16_t e_tmp = len - kDim11;
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11);
uint32_t trunc_mask = 1; uint32_t trunc_mask = 1;
for (int i = 1; i < e_tmp; i++) { for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1; trunc_mask = (trunc_mask << 1) + 1;
@@ -964,8 +965,8 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u
e_ret = e_ret + 1; e_ret = e_ret + 1;
} }
bool b_last_bit = ((m_tmp & 1) > 0); bool b_last_bit = ((m_tmp & 1) > 0);
bool b_trunc_high = 0;
bool b_trunc_left = 0;
bool b_trunc_high = false;
bool b_trunc_left = false;
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); b_trunc_high = ((m_trunc & kFp32SignMask) > 0);
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); b_trunc_left = ((m_trunc & kFp32AbsMax) > 0);
@@ -976,7 +977,7 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u
e_ret = e_ret + 1; e_ret = e_ret + 1;
} }
} else { } else {
e_ret = kFp16ExpBias;
e_ret = static_cast<int16_t>(kFp16ExpBias);
m_tmp = m_tmp << (kManBitLength - len); m_tmp = m_tmp << (kManBitLength - len);
e_ret = e_ret + (len - 1); e_ret = e_ret + (len - 1);
} }
@@ -989,11 +990,11 @@ fp16_t &fp16_t::operator=(const int16_t &i_val) {
if (i_val == 0) { if (i_val == 0) {
val = 0; val = 0;
} else { } else {
uint16_t ui_val = *(reinterpret_cast<const uint16_t *>(&i_val));
uint16_t ui_val = *(ge::PtrToPtr<const int16_t, const int16_t>(&i_val));
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift15); auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift15);
if (s_ret) {
if (static_cast<bool>(s_ret)) {
int16_t iValM = -i_val; int16_t iValM = -i_val;
ui_val = *(reinterpret_cast<uint16_t *>(&iValM));
ui_val = *(ge::PtrToPtr<int16_t, uint16_t>(&iValM));
} }
SetValByUint16Val(ui_val, s_ret, val); SetValByUint16Val(ui_val, s_ret, val);
} }
@@ -1023,8 +1024,8 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) {
e_ret = e_ret + 1; e_ret = e_ret + 1;
} }
bool b_last_bit = ((m_ret & 1) > 0); bool b_last_bit = ((m_ret & 1) > 0);
bool b_trunc_high = 0;
bool b_trunc_left = 0;
bool b_trunc_high = false;
bool b_trunc_left = false;
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); b_trunc_high = ((m_trunc & kFp32SignMask) > 0);
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); b_trunc_left = ((m_trunc & kFp32AbsMax) > 0);
@@ -1038,7 +1039,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) {
val = kFp16Max; val = kFp16Max;
} }
} else { } else {
e_ret = kFp16ExpBias;
e_ret = static_cast<int16_t>(kFp16ExpBias);
m_ret = m_ret << (kDim11 - len); m_ret = m_ret << (kDim11 - len);
e_ret = e_ret + (len - 1); e_ret = e_ret + (len - 1);
} }
@@ -1057,7 +1058,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u
e_ret = kFp16ExpBias + kFp16ManLen; e_ret = kFp16ExpBias + kFp16ManLen;
uint32_t m_trunc = 0; uint32_t m_trunc = 0;
uint32_t trunc_mask = 1; uint32_t trunc_mask = 1;
uint16_t e_tmp = len - kDim11;
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11);
for (int i = 1; i < e_tmp; i++) { for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1; trunc_mask = (trunc_mask << 1) + 1;
} }
@@ -1067,8 +1068,8 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u
e_ret = e_ret + 1; e_ret = e_ret + 1;
} }
bool b_last_bit = ((m_tmp & 1) > 0); bool b_last_bit = ((m_tmp & 1) > 0);
bool b_trunc_high = 0;
bool b_trunc_left = 0;
bool b_trunc_high = false;
bool b_trunc_left = false;
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); b_trunc_high = ((m_trunc & kFp32SignMask) > 0);
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); b_trunc_left = ((m_trunc & kFp32AbsMax) > 0);
@@ -1083,7 +1084,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u
m_tmp = kFp16MaxMan; m_tmp = kFp16MaxMan;
} }
} else { } else {
e_ret = kFp16ExpBias;
e_ret = static_cast<int16_t>(kFp16ExpBias);
m_tmp = m_tmp << (kDim11 - len); m_tmp = m_tmp << (kDim11 - len);
e_ret = e_ret + (len - 1); e_ret = e_ret + (len - 1);
} }
@@ -1095,11 +1096,11 @@ fp16_t &fp16_t::operator=(const int32_t &i_val) {
if (i_val == 0) { if (i_val == 0) {
val = 0; val = 0;
} else { } else {
uint32_t ui_val = *(reinterpret_cast<const uint32_t *>(&i_val));
uint32_t ui_val = *(ge::PtrToPtr<const int32_t, const uint32_t>(&i_val));
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift31); auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift31);
if (s_ret) {
if (static_cast<bool>(s_ret)) {
int32_t iValM = -i_val; int32_t iValM = -i_val;
ui_val = *(reinterpret_cast<uint32_t *>(&iValM));
ui_val = *(ge::PtrToPtr<int32_t, uint32_t>(&iValM));
} }
SetValByUint32Val(ui_val, s_ret, val); SetValByUint32Val(ui_val, s_ret, val);
} }
@@ -1119,7 +1120,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) {
e_ret = kFp16ExpBias + kFp16ManLen; e_ret = kFp16ExpBias + kFp16ManLen;
uint32_t m_trunc = 0; uint32_t m_trunc = 0;
uint32_t trunc_mask = 1; uint32_t trunc_mask = 1;
uint16_t e_tmp = len - kDim11;
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11);
for (int i = 1; i < e_tmp; i++) { for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1; trunc_mask = (trunc_mask << 1) + 1;
} }
@@ -1145,7 +1146,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) {
m_tmp = kFp16MaxMan; m_tmp = kFp16MaxMan;
} }
} else { } else {
e_ret = kFp16ExpBias;
e_ret = static_cast<int16_t>(kFp16ExpBias);
m_tmp = m_tmp << (kDim11 - len); m_tmp = m_tmp << (kDim11 - len);
e_ret = e_ret + (len - 1); e_ret = e_ret + (len - 1);
} }
@@ -1161,7 +1162,7 @@ fp16_t &fp16_t::operator=(const double &d_val) {
int16_t e_ret; int16_t e_ret;
uint64_t e_d; uint64_t e_d;
uint64_t m_d; uint64_t m_d;
uint64_t ui64_v = *(reinterpret_cast<const uint64_t *>(&d_val)); // 1:11:52bit sign:exp:man
uint64_t ui64_v = *(ge::PtrToPtr<const double, const uint64_t>(&d_val)); // 1:11:52bit sign:exp:man
uint32_t m_len_delta; uint32_t m_len_delta;


s_ret = static_cast<uint16_t>((ui64_v & kFp64SignMask) >> kFp64SignIndex); // 4Byte s_ret = static_cast<uint16_t>((ui64_v & kFp64SignMask) >> kFp64SignIndex); // 4Byte
@@ -1204,7 +1205,7 @@ fp16_t &fp16_t::operator=(const double &d_val) {
if (need_round) { if (need_round) {
m_ret++; m_ret++;
} }
if (m_ret & kFp16ManHideBit) {
if (static_cast<bool>(m_ret & kFp16ManHideBit)) {
e_ret++; e_ret++;
} }
} }
@@ -1239,7 +1240,7 @@ fp16_t::operator uint64_t() const { return 0; }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() const { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() const {
if ((val & kFp16AbsMax) == kFp16ExpMask) { if ((val & kFp16AbsMax) == kFp16ExpMask) {
if (val & kFp16SignMask) {
if (static_cast<bool>(val & kFp16SignMask)) {
return -1; return -1;
} else { } else {
return 1; return 1;


+ 16
- 16
parser/common/parser_fp16_t.h View File

@@ -91,16 +91,16 @@ using BitShift = enum {
}; };
/// @ingroup fp16 basic parameter /// @ingroup fp16 basic parameter
/// @brief fp16 exponent bias /// @brief fp16 exponent bias
constexpr uint16_t kFp16ExpBias = 15;
constexpr uint16_t kFp16ExpBias = 15U;
/// @ingroup fp16 basic parameter /// @ingroup fp16 basic parameter
/// @brief the exponent bit length of fp16 is 5 /// @brief the exponent bit length of fp16 is 5
constexpr uint16_t kFp16ExpLen = 5;
constexpr uint16_t kFp16ExpLen = 5U;
/// @ingroup fp16 basic parameter /// @ingroup fp16 basic parameter
/// @brief the mantissa bit length of fp16 is 10 /// @brief the mantissa bit length of fp16 is 10
constexpr uint16_t kFp16ManLen = 10;
constexpr uint16_t kFp16ManLen = 10U;
/// @ingroup fp16 basic parameter /// @ingroup fp16 basic parameter
/// @brief bit index of sign in fp16 /// @brief bit index of sign in fp16
constexpr uint16_t kFp16SignIndex = 15;
constexpr uint16_t kFp16SignIndex = 15U;
/// @ingroup fp16 basic parameter /// @ingroup fp16 basic parameter
/// @brief sign mask of fp16 (1 00000 00000 00000) /// @brief sign mask of fp16 (1 00000 00000 00000)
constexpr uint16_t kFp16SignMask = 0x8000; constexpr uint16_t kFp16SignMask = 0x8000;
@@ -164,16 +164,16 @@ constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14);
#define FP16_IS_INVALID(x) (((x) & kFp16ExpMask) == kFp16ExpMask) #define FP16_IS_INVALID(x) (((x) & kFp16ExpMask) == kFp16ExpMask)
/// @ingroup fp32 basic parameter /// @ingroup fp32 basic parameter
/// @brief fp32 exponent bias /// @brief fp32 exponent bias
constexpr uint16_t kFp32ExpBias = 127;
constexpr uint16_t kFp32ExpBias = 127U;
/// @ingroup fp32 basic parameter /// @ingroup fp32 basic parameter
/// @brief the exponent bit length of float/fp32 is 8 /// @brief the exponent bit length of float/fp32 is 8
constexpr uint16_t kFp32ExpLen = 8;
constexpr uint16_t kFp32ExpLen = 8U;
/// @ingroup fp32 basic parameter /// @ingroup fp32 basic parameter
/// @brief the mantissa bit length of float/fp32 is 23 /// @brief the mantissa bit length of float/fp32 is 23
constexpr uint16_t kFp32ManLen = 23;
constexpr uint16_t kFp32ManLen = 23U;
/// @ingroup fp32 basic parameter /// @ingroup fp32 basic parameter
/// @brief bit index of sign in float/fp32 /// @brief bit index of sign in float/fp32
constexpr uint16_t kFp32SignIndex = 31;
constexpr uint16_t kFp32SignIndex = 31U;
/// @ingroup fp32 basic parameter /// @ingroup fp32 basic parameter
/// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) /// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000)
constexpr uint32_t kFp32SignMask = 0x80000000u; constexpr uint32_t kFp32SignMask = 0x80000000u;
@@ -191,10 +191,10 @@ constexpr uint32_t kFp32ManHideBit = 0x00800000u;
constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu;
/// @ingroup fp32 basic parameter /// @ingroup fp32 basic parameter
/// @brief maximum exponent value of fp32 is 255(1111 1111) /// @brief maximum exponent value of fp32 is 255(1111 1111)
constexpr uint32_t kFp32MaxExp = 0xFF;
constexpr uint32_t kFp32MaxExp = 0xFFU;
/// @ingroup fp32 basic parameter /// @ingroup fp32 basic parameter
/// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) /// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111)
constexpr uint32_t kFp32MaxMan = 0x7FFFFF;
constexpr uint32_t kFp32MaxMan = 0x7FFFFFU;
/// @ingroup fp32 special value judgment /// @ingroup fp32 special value judgment
/// @brief whether a fp32 is NaN /// @brief whether a fp32 is NaN
#define FP32_IS_NAN(x) ((((x) & kFp32ExpMask) == kFp32ExpMask) && ((x) & kFp32ManMask)) #define FP32_IS_NAN(x) ((((x) & kFp32ExpMask) == kFp32ExpMask) && ((x) & kFp32ManMask))
@@ -218,16 +218,16 @@ constexpr uint32_t kFp32MaxMan = 0x7FFFFF;
#define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m) & kFp32MaxMan)) #define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m) & kFp32MaxMan))
/// @ingroup fp64 basic parameter /// @ingroup fp64 basic parameter
/// @brief fp64 exponent bias /// @brief fp64 exponent bias
constexpr uint16_t kFp64ExpBias = 1023;
constexpr uint16_t kFp64ExpBias = 1023U;
/// @ingroup fp64 basic parameter /// @ingroup fp64 basic parameter
/// @brief the exponent bit length of double/fp64 is 11 /// @brief the exponent bit length of double/fp64 is 11
constexpr uint16_t kFp64ExpLen = 11;
constexpr uint16_t kFp64ExpLen = 11U;
/// @ingroup fp64 basic parameter /// @ingroup fp64 basic parameter
/// @brief the mantissa bit length of double/fp64 is 52 /// @brief the mantissa bit length of double/fp64 is 52
constexpr uint16_t kFp64ManLen = 52;
constexpr uint16_t kFp64ManLen = 52U;
/// @ingroup fp64 basic parameter /// @ingroup fp64 basic parameter
/// @brief bit index of sign in double/fp64 is 63 /// @brief bit index of sign in double/fp64 is 63
constexpr uint16_t kFp64SignIndex = 63;
constexpr uint16_t kFp64SignIndex = 63U;
/// @ingroup fp64 basic parameter /// @ingroup fp64 basic parameter
/// @brief sign mask of fp64 (1 000 (total 63bits 0)) /// @brief sign mask of fp64 (1 000 (total 63bits 0))
constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu;
@@ -269,14 +269,14 @@ constexpr int16_t kInt16Max = 0x7FFF;
constexpr uint16_t kBitLen16Max = 0xFFFF; constexpr uint16_t kBitLen16Max = 0xFFFF;
/// @ingroup integer special value judgment /// @ingroup integer special value judgment
/// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) /// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111)
constexpr int32_t kInt32Max = 0x7FFFFFFFu;
constexpr int32_t kInt32Max = 0x7FFFFFFF;
/// @ingroup integer special value judgment /// @ingroup integer special value judgment
/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) /// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111)
constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu;
/// @ingroup integer special value judgment /// @ingroup integer special value judgment
/// @brief maximum positive value of int64_t /// @brief maximum positive value of int64_t
/// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) /// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111)
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu;
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFF;
/// @ingroup integer special value judgment /// @ingroup integer special value judgment
/// @brief maximum value of a data with 64 bits length /// @brief maximum value of a data with 64 bits length
/// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) /// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111)


+ 1
- 1
parser/common/pass_manager.h View File

@@ -62,7 +62,7 @@ public:
/// @return others optimized failed /// @return others optimized failed
/// @author /// @author
/// ///
static Status Run(const ge::ComputeGraphPtr &graph, std::vector<std::pair<std::string, GraphPass *>> &passes);
static Status Run(const ge::ComputeGraphPtr &graph, std::vector<std::pair<std::string, GraphPass *>> &names_to_passes);


~PassManager(); ~PassManager();




+ 7
- 5
parser/common/pre_checker.cc View File

@@ -99,9 +99,9 @@ Status PreChecker::CheckName(OpId id) {
if (id != v.first && info.name == v.second.name) { if (id != v.first && info.name == v.second.name) {
Cause cause; Cause cause;
cause.code = ErrorCode::NAME_REPEATED; cause.code = ErrorCode::NAME_REPEATED;
cause.message = "The name is repeated.";
cause.message = "The name is repeated in the graph.";


GELOGI("Name %s repeated.", info.name.c_str());
GELOGE(FAILED, "opname %s repeated, same name op in the graph", info.name.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name}); ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name});
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed.");
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "[Add][Cause] failed."); GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "[Add][Cause] failed.");
@@ -200,7 +200,7 @@ FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() {
return false; return false;
} }


Status PreChecker::Save(string file) {
Status PreChecker::Save(const string &file) {
uint32_t fail_num = 0; uint32_t fail_num = 0;
for (auto id : ops_) { for (auto id : ops_) {
if (HasError(id)) { if (HasError(id)) {
@@ -250,7 +250,7 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string
Cause cause; Cause cause;
cause.code = ErrorCode::TYPE_UNSUPPORTED; cause.code = ErrorCode::TYPE_UNSUPPORTED;
cause.message = "The type is not supported."; cause.message = "The type is not supported.";
GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
GELOGI("Check op[%s]'s type[%s], The type is not supported.", name.c_str(), type.c_str());
if (!is_tensorflow) { if (!is_tensorflow) {
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type});
} }
@@ -265,9 +265,11 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string
cause.code = ErrorCode::TYPE_UNSUPPORTED; cause.code = ErrorCode::TYPE_UNSUPPORTED;
cause.message = "The type is not supported."; cause.message = "The type is not supported.";


GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
if (!is_tensorflow) { if (!is_tensorflow) {
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type});
GELOGE(FAILED, "Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
} else {
GELOGI("Check op[%s]'s type[%s] is not supported.", name.c_str(), type.c_str());
} }
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed.");
} }


+ 1
- 1
parser/common/pre_checker.h View File

@@ -142,7 +142,7 @@ class PreChecker {
* @ingroup domi_omg * @ingroup domi_omg
* @brief Save inspection results(JSON) * @brief Save inspection results(JSON)
*/ */
Status Save(string file);
Status Save(const string &file);


private: private:
/** /**


+ 4
- 2
parser/common/proto_file_parser.cc View File

@@ -425,7 +425,8 @@ Status ProtoFileParser::FindConflictLine(const char *proto_file, int identifier,
void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file,
std::map<std::string, std::pair<int, string>> &caffe_op_identifier_map, std::map<std::string, std::pair<int, string>> &caffe_op_identifier_map,
std::map<std::string, std::pair<int, string>> &custom_op_identifier_map) { std::map<std::string, std::pair<int, string>> &custom_op_identifier_map) {
for (auto iter = custom_op_identifier_map.begin(); iter != custom_op_identifier_map.end(); ++iter) {
std::map<std::string, std::pair<int, string>>::const_iterator iter = custom_op_identifier_map.begin();
for (; iter != custom_op_identifier_map.end(); ++iter) {
if (caffe_op_identifier_map.count(iter->first) > 0) { if (caffe_op_identifier_map.count(iter->first) > 0) {
string message_name = iter->first; string message_name = iter->first;
auto caffe_pair = caffe_op_identifier_map[iter->first]; auto caffe_pair = caffe_op_identifier_map[iter->first];
@@ -452,7 +453,8 @@ void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *
void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file,
std::map<int, std::pair<string, string>> caffe_identifier_op_map, std::map<int, std::pair<string, string>> caffe_identifier_op_map,
std::map<int, std::pair<string, string>> custom_identifier_op_map) { std::map<int, std::pair<string, string>> custom_identifier_op_map) {
for (auto iter = custom_identifier_op_map.begin(); iter != custom_identifier_op_map.end(); ++iter) {
std::map<int, std::pair<string, string>>::const_iterator iter = custom_identifier_op_map.begin();
for (; iter != custom_identifier_op_map.end(); ++iter) {
if (caffe_identifier_op_map.count(iter->first) > 0) { if (caffe_identifier_op_map.count(iter->first) > 0) {
int identifier = iter->first; int identifier = iter->first;
auto caffe_pair = caffe_identifier_op_map[iter->first]; auto caffe_pair = caffe_identifier_op_map[iter->first];


+ 3
- 2
parser/common/register_tbe.cc View File

@@ -97,7 +97,8 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) {
return false; return false;
} }
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar(
domi::TENSORFLOW, GetOmOptype(reg_data), [=]() -> std::shared_ptr<OpParser> { return tf_parser_adapter; });
domi::TENSORFLOW, GetOmOptype(reg_data), [tf_parser_adapter]() -> std::shared_ptr<OpParser>
{ return tf_parser_adapter; });
} }
if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) { if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) {
bool is_registed = factory->OpParserIsRegistered(GetOmOptype(reg_data), true); bool is_registed = factory->OpParserIsRegistered(GetOmOptype(reg_data), true);
@@ -115,7 +116,7 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) {
} }
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar(
domi::TENSORFLOW, GetOmOptype(reg_data), domi::TENSORFLOW, GetOmOptype(reg_data),
[=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true);
[tf_fusion_parser_adapter]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true);
} }
} else { } else {
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType());


+ 1
- 1
parser/common/tbe_plugin_loader.cc View File

@@ -105,7 +105,7 @@ void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) {
GELOGI("Enter get custom op path schedule"); GELOGI("Enter get custom op path schedule");
std::string fmk_type; std::string fmk_type;
domi::FrameworkType type = domi::TENSORFLOW; domi::FrameworkType type = domi::TENSORFLOW;
auto it = options_.find(FRAMEWORK_TYPE);
std::map<string, string>::const_iterator it = options_.find(FRAMEWORK_TYPE);
if (it != options_.end()) { if (it != options_.end()) {
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10)); type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10));
} }


+ 13
- 13
parser/func_to_graph/func2graph.py View File

@@ -227,9 +227,9 @@ def convert_subgraphs(graph_def, filename):
print(graph_def_library.graph_def[i]) print(graph_def_library.graph_def[i])


# Write to prototxt # Write to prototxt
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename)))
print("graph_def_file: ", graph_def_file)
try: try:
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename)))
print("graph_def_file: ", graph_def_file)
with open(graph_def_file, "w") as f: with open(graph_def_file, "w") as f:
print(graph_def_library, file=f) print(graph_def_library, file=f)
except IOError: except IOError:
@@ -261,18 +261,18 @@ if __name__ == '__main__':
model = '' model = ''
try: try:
opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model=']) opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model='])
for opt_name, opt_value in opts:
if opt_name in ('-m', '--model'):
model = opt_value
print("INFO: Input model file is", model)
convert_graphs(model)
elif opt_name in ('-h', '--help'):
usage()
break
elif opt_name in ('-v', '--version'):
print("version 1.0.0")
break
except getopt.GetoptError: except getopt.GetoptError:
print("ERROR: Input parameters is invalid, use '--help' to view the help.") print("ERROR: Input parameters is invalid, use '--help' to view the help.")
for opt_name, opt_value in opts:
if opt_name in ('-m', '--model'):
model = opt_value
print("INFO: Input model file is", model)
convert_graphs(model)
elif opt_name in ('-h', '--help'):
usage()
break
elif opt_name in ('-v', '--version'):
print("version 1.0.0")
break
if len(sys.argv) == 1: if len(sys.argv) == 1:
print("INFO: Please specify the input parameters, and use '--help' to view the help.") print("INFO: Please specify the input parameters, and use '--help' to view the help.")

+ 2
- 2
parser/onnx/onnx_constant_parser.cc View File

@@ -71,7 +71,7 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_
}; };


int32_t datatype_val_size = 0; int32_t datatype_val_size = 0;
auto iter = datatype_val_size_map.find(data_type);
std::map<uint32_t, int32_t>::const_iterator iter = datatype_val_size_map.find(data_type);
if (iter != datatype_val_size_map.end()) { if (iter != datatype_val_size_map.end()) {
datatype_val_size = iter->second; datatype_val_size = iter->second;
} else { } else {
@@ -91,7 +91,7 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_
if (data_type == OnnxDataType::STRING) { if (data_type == OnnxDataType::STRING) {
tensor.SetData(tensor_proto.raw_data().c_str()); tensor.SetData(tensor_proto.raw_data().c_str());
} else { } else {
tensor.SetData(reinterpret_cast<const uint8_t *>(tensor_proto.raw_data().c_str()),
tensor.SetData(PtrToPtr<const char_t, const uint8_t>(tensor_proto.raw_data().c_str()),
tensor_proto.raw_data().size()); tensor_proto.raw_data().size());
} }
GELOGD("Raw data size is : %zu", tensor_proto.raw_data().size()); GELOGD("Raw data size is : %zu", tensor_proto.raw_data().size());


+ 1
- 1
parser/onnx/onnx_constant_parser.h View File

@@ -65,7 +65,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
*(addr_trans.get() + i) = static_cast<bool>( *(addr_trans.get() + i) = static_cast<bool>(
std::fabs(*((addr).get() + i)) > std::numeric_limits<T>::epsilon()); std::fabs(*((addr).get() + i)) > std::numeric_limits<T>::epsilon());
} }
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(bool));
(tensor).SetData(PtrToPtr<bool, uint8_t>(addr_trans.get()), (count) * sizeof(bool));
break; break;
} }
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ #define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \


+ 32
- 17
parser/onnx/onnx_parser.cc View File

@@ -32,7 +32,6 @@
#include "onnx_op_parser.h" #include "onnx_op_parser.h"
#include "onnx_util.h" #include "onnx_util.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#include "parser/common/model_saver.h" #include "parser/common/model_saver.h"
#include "parser/common/parser_utils.h" #include "parser/common/parser_utils.h"
@@ -240,7 +239,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra
if (node->GetOpDesc() == nullptr) { if (node->GetOpDesc() == nullptr) {
continue; continue;
} }
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName());
node->GetOpDesc()->SetName(OnnxUtil::GenUniqueNodeName(sub_graph->GetName(), node->GetName()));
} }


auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph);
@@ -384,7 +383,7 @@ Status OnnxModelParser::ConstructOriType(const ge::onnx::NodeProto *node_proto,
std::string domain = node_proto->domain(); std::string domain = node_proto->domain();
int64_t version = 0; int64_t version = 0;
if (!domain.empty()) { if (!domain.empty()) {
auto it = domain_verseion_.find(domain);
std::map<std::string, int64_t>::const_iterator it = domain_verseion_.find(domain);
if (it != domain_verseion_.end()) { if (it != domain_verseion_.end()) {
version = it->second; version = it->second;
} else { } else {
@@ -493,14 +492,14 @@ Status OnnxModelParser::SetOperatorInputs() {
std::vector<std::pair<std::string, int>> &output_node_indexs = out_iter->second; std::vector<std::pair<std::string, int>> &output_node_indexs = out_iter->second;
for (auto input_node_index : input_node_indexs) { for (auto input_node_index : input_node_indexs) {
for (auto out_node_index : output_node_indexs) { for (auto out_node_index : output_node_indexs) {
auto input_op_iter = name_operator_.find(input_node_index.first);
std::map<std::string, ge::Operator>::const_iterator input_op_iter = name_operator_.find(input_node_index.first);
if (input_op_iter == name_operator_.end()) { if (input_op_iter == name_operator_.end()) {
REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", input_node_index.first.c_str()); REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", input_node_index.first.c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.",
input_node_index.first.c_str()); input_node_index.first.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
auto output_op_iter = name_operator_.find(out_node_index.first);
std::map<std::string, ge::Operator>::const_iterator output_op_iter = name_operator_.find(out_node_index.first);
if (output_op_iter == name_operator_.end()) { if (output_op_iter == name_operator_.end()) {
REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", out_node_index.first.c_str()); REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", out_node_index.first.c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.",
@@ -594,6 +593,7 @@ Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::
} }


Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {
bool has_error = false;
for (int i = 0; i < onnx_graph.node_size(); i++) { for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
std::string node_name = node_proto->name(); std::string node_name = node_proto->name();
@@ -605,7 +605,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (status != SUCCESS) { if (status != SUCCESS) {
GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str()); GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str());
REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str()); REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str());
return status;
has_error = true;
continue;
} }
node_proto->set_op_type(ori_type); node_proto->set_op_type(ori_type);


@@ -616,7 +617,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (status != SUCCESS) { if (status != SUCCESS) {
GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str());
REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str());
return status;
has_error = true;
continue;
} }


// 7. op parser // 7. op parser
@@ -627,7 +629,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
status = ParseOpParam(node_proto, op, op_parser); status = ParseOpParam(node_proto, op, op_parser);
if (status != SUCCESS) { if (status != SUCCESS) {
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status);
return status;
has_error = true;
continue;
} }


GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu",
@@ -638,7 +641,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (graph_status != ge::GRAPH_SUCCESS) { if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str());
REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str());
return FAILED;
has_error = true;
continue;
} }
name_operator_[ParserUtils::GetOperatorName(op)] = op; name_operator_[ParserUtils::GetOperatorName(op)] = op;


@@ -647,11 +651,12 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
if (status != SUCCESS) { if (status != SUCCESS) {
REPORT_INNER_ERROR("E19999", "ConstructInputOutputContext failed."); REPORT_INNER_ERROR("E19999", "ConstructInputOutputContext failed.");
GELOGE(status, "[Construct][RelationMap] to input and output failed."); GELOGE(status, "[Construct][RelationMap] to input and output failed.");
return status;
has_error = true;
continue;
} }
} }
GELOGI("Parse all node proto success.");
return SUCCESS;
GELOGI("Parse all node proto end.");
return has_error ? FAILED : SUCCESS;
} }


Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) {
@@ -665,7 +670,7 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve
} }
} }
for (auto in_name : input_node_names_) { for (auto in_name : input_node_names_) {
auto in_op = name_operator_.find(in_name);
std::map<std::string, ge::Operator>::const_iterator in_op = name_operator_.find(in_name);
if (in_op == name_operator_.end()) { if (in_op == name_operator_.end()) {
GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.", GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.",
in_name.c_str()); in_name.c_str());
@@ -682,7 +687,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve
Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops, Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops,
ParserUtils::OutputMapping &out_tensor_to_nodes) { ParserUtils::OutputMapping &out_tensor_to_nodes) {
for (auto output_name : output_node_names_) { for (auto output_name : output_node_names_) {
auto itr = outputs_map_.find(output_name);
std::map<std::string, std::vector<std::pair<std::string, int>>>::const_iterator itr =
outputs_map_.find(output_name);
if (itr == outputs_map_.end()) { if (itr == outputs_map_.end()) {
GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str());
REPORT_INNER_ERROR("E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); REPORT_INNER_ERROR("E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str());
@@ -692,7 +698,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec
std::vector<std::pair<std::string, int>> node_names_indexes = itr->second; std::vector<std::pair<std::string, int>> node_names_indexes = itr->second;
for (const auto &node_name_index : node_names_indexes) { for (const auto &node_name_index : node_names_indexes) {
auto node_name = node_name_index.first; auto node_name = node_name_index.first;
auto out_op_itr = name_operator_.find(node_name);
std::map<std::string, ge::Operator>::const_iterator out_op_itr = name_operator_.find(node_name);
if (out_op_itr == name_operator_.end()) { if (out_op_itr == name_operator_.end()) {
GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str()); GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str());
REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str()); REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str());
@@ -749,6 +755,14 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph(
while (!onnx_graph_tasks.empty()) { while (!onnx_graph_tasks.empty()) {
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front();
onnx_graph_tasks.pop(); onnx_graph_tasks.pop();
std::string graph_name;
for (const auto &graph_iter : name_to_onnx_graph) {
if (graph_iter.second == onnx_graph) {
graph_name = graph_iter.first;
break;
}
}

for (int i = 0; i < onnx_graph->node_size(); i++) { for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i);
if (node_proto->name().empty()) { if (node_proto->name().empty()) {
@@ -766,7 +780,8 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph(
} }
std::vector<ge::onnx::GraphProto *> onnx_graphs; std::vector<ge::onnx::GraphProto *> onnx_graphs;
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph; std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph;
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) {
if (subgraph_adapter->AdaptAndFindAllSubgraphs(
node_proto, onnx_graphs, name_to_onnx_subgraph, graph_name) != SUCCESS) {
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str());
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str());
return FAILED; return FAILED;
@@ -815,7 +830,7 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
bool is_subgraph = (arg.parent_node != nullptr) ? true : false; bool is_subgraph = (arg.parent_node != nullptr) ? true : false;


if (arg.onnx_graph == nullptr) { if (arg.onnx_graph == nullptr) {
auto itr = name_to_onnx_graph.find(arg.graph_name);
std::map<std::string, ge::onnx::GraphProto *>::const_iterator itr = name_to_onnx_graph.find(arg.graph_name);
if (itr == name_to_onnx_graph.end()) { if (itr == name_to_onnx_graph.end()) {
GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str());
REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str());


+ 28
- 3
parser/onnx/onnx_parser.h View File

@@ -38,6 +38,7 @@
#include "omg/parser/op_parser.h" #include "omg/parser/op_parser.h"
#include "omg/parser/weights_parser.h" #include "omg/parser/weights_parser.h"
#include "common/parser_utils.h" #include "common/parser_utils.h"
#include "common/pre_checker.h"
#include "proto/onnx/ge_onnx.pb.h" #include "proto/onnx/ge_onnx.pb.h"


namespace ge { namespace ge {
@@ -81,6 +82,18 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
return domi::SUCCESS; return domi::SUCCESS;
} }


bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}

private: private:
Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph);


@@ -96,7 +109,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {


Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type);


Status AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, std::string &om_type);
Status AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, std::string &op_type);


Status TransNodeToOperator(const ge::onnx::NodeProto *node_proto, ge::Operator &op, const string &op_type) const; Status TransNodeToOperator(const ge::onnx::NodeProto *node_proto, ge::Operator &op, const string &op_type) const;


@@ -106,7 +119,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {


Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops); Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops);


Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs,
Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops,
ParserUtils::OutputMapping &out_tensor_to_nodes); ParserUtils::OutputMapping &out_tensor_to_nodes);


Status Prechecker(ge::onnx::GraphProto &onnx_graph); Status Prechecker(ge::onnx::GraphProto &onnx_graph);
@@ -115,7 +128,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {


Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const;


Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph);
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph);


Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph);


@@ -161,6 +174,18 @@ class PARSER_FUNC_VISIBILITY OnnxWeightsParser : public domi::WeightsParser {
(void)graph; (void)graph;
return domi::SUCCESS; return domi::SUCCESS;
} }

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
}; };
} // namespace domi } // namespace domi
#endif // PARSER_ONNX_ONNX_PARSER_H_ #endif // PARSER_ONNX_ONNX_PARSER_H_

+ 4
- 0
parser/onnx/onnx_util.cc View File

@@ -45,4 +45,8 @@ void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &orig
const std::string &parent_node_name, std::string &unique_subgraph_name) { const std::string &parent_node_name, std::string &unique_subgraph_name) {
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name;
} }

std::string OnnxUtil::GenUniqueNodeName(const std::string &graph_name, const std::string &node_name) {
return graph_name + "/" + node_name;
}
} // namespace ge } // namespace ge

+ 1
- 0
parser/onnx/onnx_util.h View File

@@ -54,6 +54,7 @@ class OnnxUtil {
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type);
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name,
const std::string &parent_node_name, std::string &unique_subgraph_name); const std::string &parent_node_name, std::string &unique_subgraph_name);
static std::string GenUniqueNodeName(const std::string &graph_name, const std::string &node_name);
}; };
} // namespace ge } // namespace ge




+ 16
- 9
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -15,6 +15,7 @@
*/ */


#include "if_subgraph_adapter.h" #include "if_subgraph_adapter.h"
#include <unordered_set>
#include "subgraph_adapter_factory.h" #include "subgraph_adapter_factory.h"
#include "common/util.h" #include "common/util.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
@@ -27,12 +28,12 @@ const int kIfNodeAttrSize = 2;
} // namespace } // namespace
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) {
GE_CHECK_NOTNULL(parent_node); GE_CHECK_NOTNULL(parent_node);
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(),
parent_node->op_type().c_str()); parent_node->op_type().c_str());


auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph);
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Node] Parse if node failed."); GELOGE(ret, "[Parse][Node] Parse if node failed.");
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str());
@@ -44,7 +45,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(


domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) {
if (parent_node->attribute_size() != kIfNodeAttrSize) { if (parent_node->attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
@@ -67,7 +68,11 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
return FAILED; return FAILED;
} }
std::string unique_subgraph_name; std::string unique_subgraph_name;
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name);
std::string node_name = parent_node->name();
if (!parent_graph_name.empty()) {
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name);
}
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, node_name, unique_subgraph_name);
GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str());
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); ge::onnx::GraphProto *onnx_graph = attribute->mutable_g();
name_to_onnx_graph[unique_subgraph_name] = onnx_graph; name_to_onnx_graph[unique_subgraph_name] = onnx_graph;
@@ -91,8 +96,8 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(


domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph,
std::set<std::string> &all_inputs) const { std::set<std::string> &all_inputs) const {
std::set<std::string> graph_inputs;
std::set<std::string> graph_outputs;
std::unordered_set<std::string> graph_inputs;
std::unordered_set<std::string> graph_outputs;
for (int i = 0; i < onnx_graph.node_size(); i++) { for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
for (int j = 0; j < node_proto->input_size(); j++) { for (int j = 0; j < node_proto->input_size(); j++) {
@@ -102,10 +107,12 @@ domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx
graph_outputs.emplace(node_proto->output(j)); graph_outputs.emplace(node_proto->output(j));
} }
} }

std::unordered_set<std::string> graph_initializer_tensors;
for (int32_t i = 0; i < onnx_graph.initializer_size(); i++) {
graph_initializer_tensors.emplace(onnx_graph.initializer(i).name());
}
for (const auto &input : graph_inputs) { for (const auto &input : graph_inputs) {
auto out_iter = graph_outputs.find(input);
if (out_iter == graph_outputs.end()) {
if (graph_outputs.count(input) == 0 && graph_initializer_tensors.count(input) == 0) {
// Record input node need to be constructed // Record input node need to be constructed
all_inputs.emplace(input); all_inputs.emplace(input);
} }


+ 5
- 3
parser/onnx/subgraph_adapter/if_subgraph_adapter.h View File

@@ -24,13 +24,15 @@
namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
public: public:
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op,
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node,
std::vector<ge::onnx::GraphProto *> &onnx_graphs, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override;
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name = "") override;


private: private:
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name);
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const;
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const; void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const;
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const; void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const;


+ 3
- 1
parser/onnx/subgraph_adapter/subgraph_adapter.h View File

@@ -49,10 +49,12 @@ class PARSER_FUNC_VISIBILITY SubgraphAdapter {
/// @return FAILED Parse failed /// @return FAILED Parse failed
virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op,
std::vector<ge::onnx::GraphProto *> &onnx_graphs, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name = "") {
(void)parent_op; (void)parent_op;
(void)onnx_graphs; (void)onnx_graphs;
(void)name_to_onnx_graph; (void)name_to_onnx_graph;
(void)parent_graph_name;
return domi::SUCCESS; return domi::SUCCESS;
} }
}; };


+ 1
- 1
parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc View File

@@ -26,7 +26,7 @@ SubgraphAdapterFactory* SubgraphAdapterFactory::Instance() {
std::shared_ptr<SubgraphAdapter> SubgraphAdapterFactory::CreateSubgraphAdapter( std::shared_ptr<SubgraphAdapter> SubgraphAdapterFactory::CreateSubgraphAdapter(
const std::string &op_type) { const std::string &op_type) {
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter. // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter.
auto iter = subgraph_adapter_creator_map_.find(op_type);
std::map<std::string, CREATOR_FUN>::const_iterator iter = subgraph_adapter_creator_map_.find(op_type);
if (iter != subgraph_adapter_creator_map_.end()) { if (iter != subgraph_adapter_creator_map_.end()) {
return iter->second(); return iter->second();
} }


+ 13
- 17
parser/tensorflow/graph_functiondef.cc View File

@@ -161,7 +161,6 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp
if (node_def->input(i).find("^") != string::npos) { if (node_def->input(i).find("^") != string::npos) {
// Control input // Control input
const string normalized = node_names.Renormalize(node_def->input(i).substr(1)); const string normalized = node_names.Renormalize(node_def->input(i).substr(1));

GE_IF_BOOL_EXEC(normalized.empty(), GE_IF_BOOL_EXEC(normalized.empty(),
REPORT_INNER_ERROR("E19999", "Could not remap control input %s of node %s in function %s", REPORT_INNER_ERROR("E19999", "Could not remap control input %s of node %s in function %s",
node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); node_def->input(i).c_str(), node_def->name().c_str(), name.c_str());
@@ -172,7 +171,6 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp
*node_def->mutable_input(i) = "^" + normalized; *node_def->mutable_input(i) = "^" + normalized;
} else { } else {
const auto iter = tensor_renaming.find(node_def->input(i)); const auto iter = tensor_renaming.find(node_def->input(i));

GE_IF_BOOL_EXEC(iter == tensor_renaming.end(), GE_IF_BOOL_EXEC(iter == tensor_renaming.end(),
REPORT_INNER_ERROR("E19999", "Could not remap input %s of node %s in function %s", REPORT_INNER_ERROR("E19999", "Could not remap input %s of node %s in function %s",
node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); node_def->input(i).c_str(), node_def->name().c_str(), name.c_str());
@@ -188,14 +186,12 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp
// Remap return values. // Remap return values.
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
const string &ret_name = fdef->signature().output_arg(r).name(); const string &ret_name = fdef->signature().output_arg(r).name();

GE_IF_BOOL_EXEC(ret_name.empty(), GE_IF_BOOL_EXEC(ret_name.empty(),
REPORT_INNER_ERROR("E19999", "Missing output %d to function %s", r, name.c_str()); REPORT_INNER_ERROR("E19999", "Missing output %d to function %s", r, name.c_str());
GELOGE(domi::INTERNAL_ERROR, "Missing output %d to function %s .", r, name.c_str()); GELOGE(domi::INTERNAL_ERROR, "Missing output %d to function %s .", r, name.c_str());
return domi::INTERNAL_ERROR); return domi::INTERNAL_ERROR);


const string &return_value = return_values[ret_name]; const string &return_value = return_values[ret_name];

GE_IF_BOOL_EXEC(return_value.empty(), GE_IF_BOOL_EXEC(return_value.empty(),
REPORT_INNER_ERROR("E19999", "Could not remap return value %d ,%s of %s in function %s", r, REPORT_INNER_ERROR("E19999", "Could not remap return value %d ,%s of %s in function %s", r,
ret_name.c_str(), return_value.c_str(), name.c_str()); ret_name.c_str(), return_value.c_str(), name.c_str());
@@ -204,12 +200,11 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp
return domi::INTERNAL_ERROR); return domi::INTERNAL_ERROR);


const auto iter = tensor_renaming.find(return_value); const auto iter = tensor_renaming.find(return_value);

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == tensor_renaming.end(),
REPORT_INNER_ERROR("E19999", "can not find value[%s] in tensor_renaming map",
return_value.c_str());
return domi::INTERNAL_ERROR,
"can not find value[%s] in tensor_renaming map.", return_value.c_str());
if (iter == tensor_renaming.end()) {
REPORT_INNER_ERROR("E19999", "can not find value[%s] in tensor_renaming map", return_value.c_str());
GELOGE(FAILED, "can not find value[%s] in tensor_renaming map.", return_value.c_str());
return domi::INTERNAL_ERROR;
}


(*fdef->mutable_ret())[ret_name] = iter->second; (*fdef->mutable_ret())[ret_name] = iter->second;
} }
@@ -227,7 +222,7 @@ domi::Status GraphToFunctionDef::RecordResult(ge::ComputeGraphPtr graph,
GE_CHECK_NOTNULL(anchor); GE_CHECK_NOTNULL(anchor);
GE_CHECK_NOTNULL(anchor->GetOwnerNode()->GetOpDesc()); GE_CHECK_NOTNULL(anchor->GetOwnerNode()->GetOpDesc());
int32_t type = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()).GetDataType(); int32_t type = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()).GetDataType();
auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type);
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type);
GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(),
REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported",
type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(),
@@ -304,7 +299,7 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect
GE_CHECK_NOTNULL_EXEC(tensor_desc_ptr, return domi::FAILED); GE_CHECK_NOTNULL_EXEC(tensor_desc_ptr, return domi::FAILED);


int32_t type = tensor_desc_ptr->GetDataType(); int32_t type = tensor_desc_ptr->GetDataType();
auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type);
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type);
GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(),
REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported",
type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(),
@@ -325,8 +320,8 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect
return FAILED; return FAILED;
} }


(void)ge::AttrUtils::SetInt(op, "T", (int32_t)dtype);
(void)ge::AttrUtils::SetInt(op, "arg_index", (int32_t)index);
(void)ge::AttrUtils::SetInt(op, "T", static_cast<int32_t>(dtype));
(void)ge::AttrUtils::SetInt(op, "arg_index", static_cast<int32_t>(index));
ge::NodePtr arg_node = graph->AddNode(op); ge::NodePtr arg_node = graph->AddNode(op);
GE_CHECK_NOTNULL(arg_node); GE_CHECK_NOTNULL(arg_node);
bool node_exists = false; bool node_exists = false;
@@ -378,7 +373,6 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
if (node->GetOpDesc()->GetType() == ge::parser::DATA) { if (node->GetOpDesc()->GetType() == ge::parser::DATA) {
int64_t index = 0; int64_t index = 0;

int64_t type = 1; int64_t type = 1;
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID,
"Get type attr failed"); "Get type attr failed");
@@ -400,7 +394,6 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph
if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) { if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) {
int64_t index = 0; int64_t index = 0;
int64_t type = 1; int64_t type = 1;

GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID,
"Get type attr failed"); "Get type attr failed");


@@ -589,7 +582,10 @@ bool GraphToFunctionDef::FindAttrValue(const domi::tensorflow::NodeDef *node_def


void GraphToFunctionDef::AddNodeAttr(const string &attr_name, const domi::tensorflow::AttrValue &value, void GraphToFunctionDef::AddNodeAttr(const string &attr_name, const domi::tensorflow::AttrValue &value,
domi::tensorflow::NodeDef *node_def) { domi::tensorflow::NodeDef *node_def) {
GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null.");
if (node_def == nullptr) {
GELOGI("input parameter is null.");
return;
}
node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value));
} }
} // namespace ge } // namespace ge

+ 1
- 1
parser/tensorflow/graph_functiondef.h View File

@@ -52,7 +52,7 @@ class GraphToFunctionDef {
const string &name, FunctionDef *fdef); const string &name, FunctionDef *fdef);


static domi::Status BuildFunctionDef(ge::ComputeGraphPtr &graph, static domi::Status BuildFunctionDef(ge::ComputeGraphPtr &graph,
const string &nme_in,
const string &name_in,
FunctionDefLibrary *library, FunctionDefLibrary *library,
NodeDef *call_node_def, NodeDef *call_node_def,
vector<ge::InDataAnchorPtr> &in_anchor, vector<ge::InDataAnchorPtr> &in_anchor,


+ 28
- 18
parser/tensorflow/graph_optimizer.cc View File

@@ -15,7 +15,7 @@
*/ */


#include "graph_optimizer.h" #include "graph_optimizer.h"
#include "common/op_types.h"
#include "graph/op_types.h"
#include "common/types_map.h" #include "common/types_map.h"
#include "common/util.h" #include "common/util.h"
#include "framework/omg/parser/parser_inner_ctx.h" #include "framework/omg/parser/parser_inner_ctx.h"
@@ -93,6 +93,7 @@ Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const boo
GE_CHECK_NOTNULL(graph_); GE_CHECK_NOTNULL(graph_);
for (auto node : graph_->GetDirectNode()) { for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue)
string type; string type;
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
@@ -178,7 +179,7 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library)
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib;
GE_CHECK_NOTNULL(funcDefBytes.GetData()); GE_CHECK_NOTNULL(funcDefBytes.GetData());
string str(reinterpret_cast<char *>(funcDefBytes.GetData()), funcDefBytes.GetSize());
string str(PtrToPtr<uint8_t, char_t>(funcDefBytes.GetData()), funcDefBytes.GetSize());
GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC( GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC(
funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib))); funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib)));
} }
@@ -206,9 +207,11 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) {
std::unique_ptr<FunctionDefLibrary> func_def_lib(new (std::nothrow) FunctionDefLibrary()); std::unique_ptr<FunctionDefLibrary> func_def_lib(new (std::nothrow) FunctionDefLibrary());
GE_CHECK_NOTNULL(func_def_lib); GE_CHECK_NOTNULL(func_def_lib);
// convert graph to FunctionDef // convert graph to FunctionDef
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0,
REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0");
return PARAM_INVALID, "node size must greater than 0 .");
if (nodes.size() == 0) {
REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0");
GELOGE(FAILED, "node size must greater than 0 .");
return PARAM_INVALID;
}
GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed."); GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed.");
GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, nodes[0]->GetName(), func_def_lib.get(), GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, nodes[0]->GetName(), func_def_lib.get(),
node_def.get(), input_anchors, output_anchors), node_def.get(), input_anchors, output_anchors),
@@ -226,7 +229,10 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) {
GELOGE(PARAM_INVALID, "Serialize func_def to string failed."); GELOGE(PARAM_INVALID, "Serialize func_def to string failed.");
return PARAM_INVALID); return PARAM_INVALID);


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0, return PARAM_INVALID, "nodes is empty.");
if (nodes.size() == 0) {
GELOGE(FAILED, "nodes is empty.");
return PARAM_INVALID;
}


std::string fusion_op_name; std::string fusion_op_name;
for (auto node : nodes) { for (auto node : nodes) {
@@ -250,10 +256,10 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) {


(void)AttrUtils::SetZeroCopyBytes( (void)AttrUtils::SetZeroCopyBytes(
fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(funcdefStr.data()), funcdefStr.length()));
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(funcdefStr.data()), funcdefStr.length()));
(void)AttrUtils::SetZeroCopyBytes( (void)AttrUtils::SetZeroCopyBytes(
fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length()));
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(nodefStr.data()), nodefStr.length()));


(void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type); (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type);


@@ -284,6 +290,7 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
OpDescPtr op_def = node->GetOpDesc(); OpDescPtr op_def = node->GetOpDesc();
NodePtr new_node = sub_graph->AddNode(op_def); NodePtr new_node = sub_graph->AddNode(op_def);
GE_CHECK_NOTNULL(new_node);
node_map[node->GetName()] = new_node; node_map[node->GetName()] = new_node;


// Input // Input
@@ -381,7 +388,8 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> &
GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED); GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED);


ge::DataType data_type = src_out_desc.GetDataType(); ge::DataType data_type = src_out_desc.GetDataType();
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type);
const std::map<int32_t, int32_t>::const_iterator iter =
GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type));
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(),
REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported",
@@ -390,7 +398,7 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> &
return PARAM_INVALID); return PARAM_INVALID);


int32_t dtype = iter->second; int32_t dtype = iter->second;
output_list.push_back((int64_t)dtype);
output_list.push_back(static_cast<int64_t>(dtype));
GELOGI("FUNCDEF: output_list push_back %d.", dtype); GELOGI("FUNCDEF: output_list push_back %d.", dtype);
} }
GE_IF_BOOL_EXEC(!output_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list)); GE_IF_BOOL_EXEC(!output_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list));
@@ -410,14 +418,15 @@ Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &in
auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx()); auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx());
GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED); GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED);


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((fusion_op_desc->AddInputDesc(*tensorDescPtr)) != GRAPH_SUCCESS,
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
fusion_op_desc->GetName().c_str(),
fusion_op_desc->GetType().c_str());
return FAILED,
"Add fusion_op_desc AddInputDesc failed");
if (fusion_op_desc->AddInputDesc(*tensorDescPtr) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
fusion_op_desc->GetName().c_str(), fusion_op_desc->GetType().c_str());
GELOGE(FAILED, "Add fusion_op_desc AddInputDesc failed");
return FAILED;
}
ge::DataType data_type = tensorDescPtr->GetDataType(); ge::DataType data_type = tensorDescPtr->GetDataType();
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type);
const std::map<int32_t, int32_t>::const_iterator iter =
GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type));
GE_IF_BOOL_EXEC( GE_IF_BOOL_EXEC(
iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(),
REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported",
@@ -426,7 +435,7 @@ Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &in
return PARAM_INVALID); return PARAM_INVALID);


int32_t dtype = iter->second; int32_t dtype = iter->second;
input_list.push_back((int64_t)dtype);
input_list.push_back(static_cast<int64_t>(dtype));
GELOGI("FUNCDEF: input_list push_back %d.", dtype); GELOGI("FUNCDEF: input_list push_back %d.", dtype);
} }
GE_IF_BOOL_EXEC(!input_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list)); GE_IF_BOOL_EXEC(!input_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list));
@@ -440,6 +449,7 @@ Status ParserGraphOptimizer::RebuildFusionNode(vector<ge::InDataAnchorPtr> &inpu
vector<ge::InControlAnchorPtr> &input_control_anchors, vector<ge::InControlAnchorPtr> &input_control_anchors,
vector<ge::OutControlAnchorPtr> &output_control_anchors, vector<ge::OutControlAnchorPtr> &output_control_anchors,
ge::NodePtr fusion_node) { ge::NodePtr fusion_node) {
GE_CHECK_NOTNULL(fusion_node);
int32_t src_index = 0; int32_t src_index = 0;


for (auto out_anchor : output_anchors) { for (auto out_anchor : output_anchors) {


+ 1
- 1
parser/tensorflow/tensorflow_arg_parser.cc View File

@@ -32,7 +32,7 @@ const char *const kSerializeFormat = "serialize_format";
Status ParseParams(const Message *op_src, ArgOpOperator *const op) { Status ParseParams(const Message *op_src, ArgOpOperator *const op) {
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
domi::tensorflow::AttrValue output_attr_value; domi::tensorflow::AttrValue output_attr_value;
if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {


+ 6
- 3
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc View File

@@ -44,7 +44,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge
GELOGE(PARAM_INVALID, "Op src is null"); GELOGE(PARAM_INVALID, "Op src is null");
return PARAM_INVALID; return PARAM_INVALID;
} }
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src);
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
if (op_dest == nullptr) { if (op_dest == nullptr) {
REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid");
@@ -109,8 +109,11 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge
return FAILED; return FAILED;
} }
} }
const auto out_desc = op_dest->MutableOutputDesc(0);
GE_CHECK_NOTNULL(out_desc);
out_desc->SetDataType(out_type);


std::shared_ptr<NodeDef> pkg_node = ge::parser::MakeShared<NodeDef>();
std::shared_ptr<domi::tensorflow::NodeDef> pkg_node = ge::parser::MakeShared<domi::tensorflow::NodeDef>();
GE_CHECK_NOTNULL(pkg_node); GE_CHECK_NOTNULL(pkg_node);
pkg_node->CopyFrom(*node); pkg_node->CopyFrom(*node);


@@ -130,7 +133,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge


(void)AttrUtils::SetZeroCopyBytes( (void)AttrUtils::SetZeroCopyBytes(
op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(serialized_node.data()), serialized_node.length()));
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(serialized_node.data()), serialized_node.length()));
GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str()); GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str());
} }




+ 1
- 1
parser/tensorflow/tensorflow_custom_parser_adapter.cc View File

@@ -27,7 +27,7 @@ using domi::ParseParamByOpFunc;
namespace ge { namespace ge {
Status TensorFlowCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { Status TensorFlowCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
const NodeDef *node_src = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node_src); GE_CHECK_NOTNULL(node_src);
GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str());
GE_CHECK_NOTNULL(op_dest); GE_CHECK_NOTNULL(op_dest);


+ 3
- 3
parser/tensorflow/tensorflow_data_parser.cc View File

@@ -93,7 +93,7 @@ Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, const ge
GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE), GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE),
"check Attr %s failed", TENSORFLOW_ATTR_SHAPE.c_str()); "check Attr %s failed", TENSORFLOW_ATTR_SHAPE.c_str());


const TensorShapeProto &data_shape = attr_value.shape();
const domi::tensorflow::TensorShapeProto &data_shape = attr_value.shape();
for (auto i = 0; i < data_shape.dim_size(); i++) { for (auto i = 0; i < data_shape.dim_size(); i++) {
model_input_dims_v.push_back(data_shape.dim(i).size()); model_input_dims_v.push_back(data_shape.dim(i).size());
} }
@@ -110,7 +110,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge:
std::string name = op_def->GetName(); std::string name = op_def->GetName();
if (input_dims.count(name) == 0) { if (input_dims.count(name) == 0) {
GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str()); GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str());
for (uint32_t i = 0; i < model_input_dims_v.size(); i++) {
for (size_t i = 0; i < model_input_dims_v.size(); ++i) {
user_input_dims_v.push_back(model_input_dims_v[i]); user_input_dims_v.push_back(model_input_dims_v[i]);
} }


@@ -138,7 +138,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge:
} }


Status TensorFlowDataParser::CheckInputShape(const std::string &name) { Status TensorFlowDataParser::CheckInputShape(const std::string &name) {
for (uint32_t i = 0; i < user_input_dims_v.size(); i++) {
for (size_t i = 0; i < user_input_dims_v.size(); ++i) {
// if input_shape has some placeholders, user should designate them. // if input_shape has some placeholders, user should designate them.
// dim i = 0, means empty tensor. // dim i = 0, means empty tensor.
// dim i = -1 or -2, means unknown shape. // dim i = -1 or -2, means unknown shape.


+ 1
- 1
parser/tensorflow/tensorflow_enter_parser.cc View File

@@ -31,7 +31,7 @@ Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr &
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
const std::string name = op_desc->GetName(); const std::string name = op_desc->GetName();


const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src);
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src);
domi::tensorflow::AttrValue str_attr; domi::tensorflow::AttrValue str_attr;
if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) {
REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid",


+ 1
- 1
parser/tensorflow/tensorflow_fill_parser.cc View File

@@ -38,7 +38,7 @@ node {
} }
} }
*/ */
domi::Status ParseParams(const NodeDef *node, FillOperator *op) {
domi::Status ParseParams(const domi::tensorflow::NodeDef *node, FillOperator *op) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);
op->Name(node->name()); op->Name(node->name());


+ 2
- 2
parser/tensorflow/tensorflow_frameworkop_parser.cc View File

@@ -31,7 +31,7 @@ namespace ge {
Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { Status ParseParams(const Message *op_src, FrameworkOpOperator *op) {
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
string type = node->op(); string type = node->op();


@@ -64,7 +64,7 @@ Status ParseParams(const Message *op_src, FrameworkOpOperator *op) {
GE_IF_BOOL_EXEC(((type == "_Retval") && (TensorFlowUtil::FindAttrValue(node, ATTR_NAME_INDEX, index_attr_value))), GE_IF_BOOL_EXEC(((type == "_Retval") && (TensorFlowUtil::FindAttrValue(node, ATTR_NAME_INDEX, index_attr_value))),
op->Index(index_attr_value.i())); op->Index(index_attr_value.i()));


NodeDef *pkg_node = new (std::nothrow) NodeDef();
domi::tensorflow::NodeDef *pkg_node = new (std::nothrow) domi::tensorflow::NodeDef();
GE_CHECK_NOTNULL(pkg_node); GE_CHECK_NOTNULL(pkg_node);


pkg_node->CopyFrom(*node); pkg_node->CopyFrom(*node);


+ 1
- 1
parser/tensorflow/tensorflow_fusion_op_parser.h View File

@@ -44,7 +44,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowFusionOpParser : public TensorFlowOpParse
* @return SUCCESS Parsing success * @return SUCCESS Parsing success
* @return FAILED Parsing failed * @return FAILED Parsing failed
*/ */
virtual Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &node) const;
virtual Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &op_dest) const;


/** /**
* @ingroup domi_omg * @ingroup domi_omg


+ 1
- 1
parser/tensorflow/tensorflow_merge_parser.cc View File

@@ -31,7 +31,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);


const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src);
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src);
domi::tensorflow::AttrValue attr_num; domi::tensorflow::AttrValue attr_num;
if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) {
GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str());


+ 1
- 1
parser/tensorflow/tensorflow_no_op_parser.cc View File

@@ -26,7 +26,7 @@ using namespace ge::parser;


namespace ge { namespace ge {
Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
NoOpOperator op; NoOpOperator op;


+ 0
- 19
parser/tensorflow/tensorflow_op_parser.h View File

@@ -42,25 +42,6 @@
#include "proto/tensorflow/graph.pb.h" #include "proto/tensorflow/graph.pb.h"
#include "proto/tensorflow/node_def.pb.h" #include "proto/tensorflow/node_def.pb.h"



using domi::tensorflow::NodeDef;
using domi::tensorflow::TensorProto;
using google::protobuf::int32;
using google::protobuf::int64;
using google::protobuf::Message;
using std::string;
using std::vector;
using Status = domi::Status;
using domi::tensorflow::AttrValue;
using domi::tensorflow::DataType;
using domi::tensorflow::DT_BOOL;
using domi::tensorflow::DT_FLOAT;
using domi::tensorflow::DT_INT32;
using domi::tensorflow::DT_INT64;
using domi::tensorflow::DT_INVALID;
using domi::tensorflow::TensorShapeProto;
using domi::tensorflow::TensorShapeProto_Dim;

namespace ge { namespace ge {
/** /**
* @ingroup domi_omg * @ingroup domi_omg


+ 141
- 221
parser/tensorflow/tensorflow_parser.cc View File

@@ -40,7 +40,6 @@
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/common/parser_fp16_t.h" #include "parser/common/parser_fp16_t.h"
#include "parser/common/pass_manager.h" #include "parser/common/pass_manager.h"
#include "parser/common/pre_checker.h"
#include "parser/common/prototype_pass_manager.h" #include "parser/common/prototype_pass_manager.h"
#include "parser/common/thread_pool.h" #include "parser/common/thread_pool.h"
#include "parser/common/parser_utils.h" #include "parser/common/parser_utils.h"
@@ -54,6 +53,7 @@
#include "register/register_utils.h" #include "register/register_utils.h"
#include "register/scope/scope_pass_registry_impl.h" #include "register/scope/scope_pass_registry_impl.h"
#include "parser/common/auto_mapping_subgraph_io_index_func.h" #include "parser/common/auto_mapping_subgraph_io_index_func.h"
#include "graph/def_types.h"


using ge::OpParserFactory; using ge::OpParserFactory;
using ge::Pb2Json; using ge::Pb2Json;
@@ -507,8 +507,11 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def,


ge::NodePtr node = nullptr; ge::NodePtr node = nullptr;
node = graph->AddNode(op_desc); node = graph->AddNode(op_desc);

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed.");
if (node == nullptr) {
DeleteFuisonNodeDef();
GELOGE(FAILED, "add node failed.");
return INTERNAL_ERROR;
}


node_map_[node_name] = node; node_map_[node_name] = node;


@@ -545,7 +548,11 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def,
// checkout op input number with IR // checkout op input number with IR
GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def)); GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def));
ge::NodePtr node = graph->AddNode(op); ge::NodePtr node = graph->AddNode(op);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed.");
if (node == nullptr) {
DeleteFuisonNodeDef();
GELOGE(FAILED, "add node failed.");
return INTERNAL_ERROR;
}


node_map_[node_name] = node; node_map_[node_name] = node;


@@ -794,22 +801,24 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(out_archor_ptr); GE_CHECK_NOTNULL(out_archor_ptr);
ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second); ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second);
GE_CHECK_NOTNULL(in_archor_ptr); GE_CHECK_NOTNULL(in_archor_ptr);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS,
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
src->GetName().c_str(), dest->GetName().c_str());
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].",
src->GetName().c_str(), dest->GetName().c_str());
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
src->GetName().c_str(), dest->GetName().c_str());
GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str());
return INTERNAL_ERROR;
}
} else { } else {
GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str());
ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor();
GE_CHECK_NOTNULL(in_archor_ptr); GE_CHECK_NOTNULL(in_archor_ptr);
ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_archor_ptr); GE_CHECK_NOTNULL(out_archor_ptr);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS,
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
src->GetName().c_str(), dest->GetName().c_str());
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].",
src->GetName().c_str(), dest->GetName().c_str());
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed",
src->GetName().c_str(), dest->GetName().c_str());
GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str());
return INTERNAL_ERROR;
}
} }
} }
dest_input_map.erase(input_iter); dest_input_map.erase(input_iter);
@@ -921,10 +930,11 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co
} }


std::map<std::string, std::string>::const_iterator iterator = parser->adaptedOpTypeMap_.find(node_name); std::map<std::string, std::string>::const_iterator iterator = parser->adaptedOpTypeMap_.find(node_name);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
iterator == parser->adaptedOpTypeMap_.end(),
REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str());
return FAILED, "get adapted op type failed, node name = %s", node_name.c_str());
if (iterator == parser->adaptedOpTypeMap_.cend()) {
REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str());
GELOGE(FAILED, "get adapted op type failed, node name = %s", node_name.c_str());
return FAILED;
}


string op_type = iterator->second; string op_type = iterator->second;
// Log printing for determining operator type // Log printing for determining operator type
@@ -1017,10 +1027,12 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co
node = graph->AddNode(op); node = graph->AddNode(op);
} }


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(node == nullptr), REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(),
op->GetType().c_str(), graph->GetName().c_str());
return INTERNAL_ERROR, "add node failed.");
if (node == nullptr) {
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(),
op->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "add node failed.");
return INTERNAL_ERROR;
}


if (needFusion) { if (needFusion) {
shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type); shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type);
@@ -1116,10 +1128,11 @@ Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &g
for (size_t j = 0; j < op_node_list_size; j++) { for (size_t j = 0; j < op_node_list_size; j++) {
const string op_node_name = op_node_name_list[j]; const string op_node_name = op_node_name_list[j];
auto iterator = node_map_.find(op_node_name); auto iterator = node_map_.find(op_node_name);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(iterator == node_map_.end()),
REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str());
return INTERNAL_ERROR, "add node failed.");
if (iterator == node_map_.end()) {
REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str());
GELOGE(FAILED, "add node failed.");
return INTERNAL_ERROR;
}
GE_CHECK_NOTNULL(iterator->second); GE_CHECK_NOTNULL(iterator->second);
GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed"); GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed");
graph->AddNode(iterator->second); graph->AddNode(iterator->second);
@@ -1178,15 +1191,22 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
domi::tensorflow::GraphDef OriDef; domi::tensorflow::GraphDef OriDef;


bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef); bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, REPORT_INNER_ERROR("E19999", "read graph proto from binary failed");
return INTERNAL_ERROR, "read_proto_from_binary failed.");
if (!read) {
REPORT_INNER_ERROR("E19999", "read graph proto from binary failed");
GELOGE(FAILED, "read_proto_from_binary failed.");
return INTERNAL_ERROR;
}


domi::tensorflow::GraphDef graph_def; domi::tensorflow::GraphDef graph_def;
if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) {
const bool is_empty_input = GetParserContext().input_dims.empty() && GetParserContext().out_nodes_map.empty();
if (is_empty_input) {
graph_def = OriDef; graph_def = OriDef;
} else { } else {
GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size()); GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(OriDef, &graph_def), return INTERNAL_ERROR, "Trim Graph fail.");
if (static_cast<bool>(TrimGraph(OriDef, &graph_def))) {
GELOGE(FAILED, "Trim Graph fail.");
return INTERNAL_ERROR;
}
GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size());
} }


@@ -1236,9 +1256,6 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
// This function call affects the return value of prechecker::instance().Haserror() // This function call affects the return value of prechecker::instance().Haserror()
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));


// Check the input validity of the node, the input attribute must have a corresponding node
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def));

// Building input and input relationships for all OP nodes // Building input and input relationships for all OP nodes
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
GELOGD("[TF ParseFromMemory] get op nodes context from graph success"); GELOGD("[TF ParseFromMemory] get op nodes context from graph success");
@@ -1269,9 +1286,12 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
const string &node_op = node_def->op(); const string &node_op = node_def->op();
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((tensorflow_op_map.find(node_op) == tensorflow_op_map.end()), DeleteFuisonNodeDef();
REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str());
return INTERNAL_ERROR, "Unsupport op type %s", node_op.c_str());
if (tensorflow_op_map.find(node_op) == tensorflow_op_map.cend()) {
DeleteFuisonNodeDef();
REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str());
GELOGE(FAILED, "Unsupport op type %s", node_op.c_str());
return INTERNAL_ERROR;
}


ret = AddNode(node_def, graph, scope_graph); ret = AddNode(node_def, graph, scope_graph);
if (ret != SUCCESS) { if (ret != SUCCESS) {
@@ -1339,7 +1359,10 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr
// Store objects parsed from pb files // Store objects parsed from pb files
domi::tensorflow::GraphDef ori_def; domi::tensorflow::GraphDef ori_def;
bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def); bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, return INTERNAL_ERROR, "read_proto_from_binary failed.");
if (!read) {
GELOGE(FAILED, "read tensorflow file failed when the inupt param value of --framework is 3.");
return INTERNAL_ERROR;
}


// Trim graph by user input and output. // Trim graph by user input and output.
domi::tensorflow::GraphDef graph_def; domi::tensorflow::GraphDef graph_def;
@@ -1347,7 +1370,10 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr
graph_def = ori_def; graph_def = ori_def;
} else { } else {
GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size()); GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(ori_def, &graph_def), return INTERNAL_ERROR, "Trim Graph fail.");
if (static_cast<bool>(TrimGraph(ori_def, &graph_def))) {
GELOGE(FAILED, "Trim Graph fail.");
return INTERNAL_ERROR;
}
GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size()); GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size());
} }


@@ -1375,7 +1401,7 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr
} }
} }


std::map<std::string, domi::tensorflow::GraphDef>::const_iterator
const std::map<std::string, domi::tensorflow::GraphDef>::const_iterator
iter = function_name_to_graphdef.find(arg.function_name); iter = function_name_to_graphdef.find(arg.function_name);
if (iter == function_name_to_graphdef.end()) { if (iter == function_name_to_graphdef.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name}); ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name});
@@ -1415,7 +1441,8 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
GE_CHECK_NOTNULL(proto); GE_CHECK_NOTNULL(proto);
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);


const domi::tensorflow::GraphDef *ori_graph = reinterpret_cast<const domi::tensorflow::GraphDef *>(proto);
const domi::tensorflow::GraphDef *ori_graph =
ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto);
// Make a copy for operation without modifying the original graph def. // Make a copy for operation without modifying the original graph def.
domi::tensorflow::GraphDef graph_def = *ori_graph; domi::tensorflow::GraphDef graph_def = *ori_graph;


@@ -1441,12 +1468,16 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()), GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()),
"Add node_def to PreChecker failed, node name: %s.", node.name().c_str()); "Add node_def to PreChecker failed, node name: %s.", node.name().c_str());


GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&node) != SUCCESS, return FAILED,
"Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str());
GE_CHK_BOOL_EXEC_NOLOG(
node.op() == TENSORFLOWF_NODE_OP_IDENTITY,
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&node, true) != SUCCESS, return FAILED,
"Check op[%s]'s optype failed, type is not supported.", node.name().c_str());)
if (PreChecker::Instance().CheckName(&node) != SUCCESS) {
GELOGE(FAILED, "Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str());
return FAILED;
}
if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) {
if (PreChecker::Instance().CheckType(&node, true) != SUCCESS) {
GELOGE(FAILED, "Check op[%s]'s optype failed, type is not supported.", node.name().c_str());
return FAILED;
}
}
} }


bool has_error = false; bool has_error = false;
@@ -1471,10 +1502,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
// This function call affects the return value of prechecker::instance().Haserror() // This function call affects the return value of prechecker::instance().Haserror()
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));


// Check the input validity of the node, the input attribute must have a corresponding node
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def));
GELOGD("[TF Parse] check graph success");

// Building input and input relationships for all OP nodes // Building input and input relationships for all OP nodes
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
GELOGD("[TF Parse] get op nodes context from graph success"); GELOGD("[TF Parse] get op nodes context from graph success");
@@ -1547,37 +1574,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
return SUCCESS; return SUCCESS;
} }


Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const {
// Number of data nodes
uint32_t data_node_count = 0;
for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) {
// Check that all input is valid
for (const string &node_name : node_def.input()) {
string tmp_node_name;
GE_RETURN_IF_ERROR(CheckInputNodeName(node_name, &tmp_node_name, nullptr, nullptr));

if (nodedef_map_.find(tmp_node_name) == nodedef_map_.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E12009", {"opname", "inputopname"},
{node_def.name(), node_name});
GELOGE(INTERNAL_ERROR, "Op[%s]'s input op[%s] is not exist in the graph_def.", node_def.name().c_str(),
node_name.c_str());
return INTERNAL_ERROR;
}
}

if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::parser::ARG) {
data_node_count++;
}
}
if (data_node_count == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E12010");
GELOGE(INTERNAL_ERROR, "Model has no Placeholder node.");
return INTERNAL_ERROR;
}

return SUCCESS;
}

Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) { Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) {
// Build the input relationship first // Build the input relationship first
for (auto &iter : op_node_context_map_) { for (auto &iter : op_node_context_map_) {
@@ -1868,7 +1864,7 @@ Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph>
ge::ScopeFusionOpInfo info; ge::ScopeFusionOpInfo info;
if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) { if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) {
// This node is a fusion operator // This node is a fusion operator
std::map<std::string, OpNodeContext>::const_iterator
const std::map<std::string, OpNodeContext>::const_iterator
fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name); fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name);
if (fusion_iter == tmp_fusion_op_node_context_map.end()) { if (fusion_iter == tmp_fusion_op_node_context_map.end()) {
OpNodeContext op_node_context; OpNodeContext op_node_context;
@@ -2108,10 +2104,10 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap(
std::set<std::string> compare_set; std::set<std::string> compare_set;


for (auto &pair : pairs) { for (auto &pair : pairs) {
bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.end()) ||
(fusion_op_children_.find(iter->first) != fusion_op_children_.end());
bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.end()) ||
(fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.end());
bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.cend()) ||
(fusion_op_children_.find(iter->first) != fusion_op_children_.cend());
bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.cend()) ||
(fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.cend());
if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) && if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) &&
(is_fusion_child || is_fusion_op)) { (is_fusion_child || is_fusion_op)) {
// The edge will be cut off at the back, ignoring // The edge will be cut off at the back, ignoring
@@ -2119,7 +2115,7 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap(
} }


string name = to_string(pair.first) + ":" + to_string(pair.second); string name = to_string(pair.first) + ":" + to_string(pair.second);
std::set<std::string>::const_iterator compare_iter = compare_set.find(name);
const std::set<std::string>::const_iterator compare_iter = compare_set.find(name);
if (compare_iter != compare_set.end()) { if (compare_iter != compare_set.end()) {
// pair<from,to> repeat, ignore // pair<from,to> repeat, ignore
continue; continue;
@@ -2158,7 +2154,7 @@ void TensorFlowModelParser::SaveEdgesControlInfo(const string &node_name, const
} }


void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) { void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) {
std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name);
const std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name);
if (iter != edges_control_map.end()) { if (iter != edges_control_map.end()) {
// Delete the original fusion operator node information and add the fusion operator control edge information // Delete the original fusion operator node information and add the fusion operator control edge information
edges_control_map.erase(iter); edges_control_map.erase(iter);
@@ -2228,7 +2224,8 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);
ge::GetParserContext().train_flag = true; ge::GetParserContext().train_flag = true;


const domi::tensorflow::GraphDef *graph_def_in = reinterpret_cast<const domi::tensorflow::GraphDef *>(proto);
const domi::tensorflow::GraphDef *graph_def_in =
ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto);
// Make a copy for operation without modifying the original graph def. // Make a copy for operation without modifying the original graph def.
domi::tensorflow::GraphDef graph_def_operation = *graph_def_in; domi::tensorflow::GraphDef graph_def_operation = *graph_def_in;
domi::tensorflow::GraphDef *graph_def = &graph_def_operation; domi::tensorflow::GraphDef *graph_def = &graph_def_operation;
@@ -2415,7 +2412,7 @@ Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge
GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed"); GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed");
return FAILED; return FAILED;
} }
return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph);
return ParseProto(ge::PtrToPtr<domi::tensorflow::GraphDef, const google::protobuf::Message>(&graph_def), graph);
} }


Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback,
@@ -2472,95 +2469,18 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro
return SUCCESS; return SUCCESS;
} }


// For the identity operator whose output is "_retval", optimize it.
Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map,
const string &curr_node_name, bool &clear_input_flag) {
auto context_iter = op_node_context_map_.find(curr_node_name);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(context_iter == op_node_context_map_.end()),
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str());
return INTERNAL_ERROR, "Can't find op node context.");
OpNodeContext op_node_context = context_iter->second;

std::map<std::string, NodeDef *>::const_iterator node_def_iter = nodedef_map.find(curr_node_name);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(node_def_iter == nodedef_map.end()),
REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map, check invalid", curr_node_name.c_str());
return INTERNAL_ERROR, "Can't find nodedef");
domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second;
GE_CHECK_NOTNULL(curr_node_def);
bool has_out_retval = false;
// For the identity operator whose output is "_retval", optimize it
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map;
for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) {
const string &output_node_name = output_iter->first;
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
GE_CHECK_NOTNULL(output_node_def);
if (output_node_def->op() == "_Retval") {
GELOGD("_Retval Identity need optimize.");
output_node_def->set_input(0, curr_node_def->input(0).c_str());
has_out_retval = true;
GELOGD("op %s set input(0):%s.", output_node_def->name().c_str(), curr_node_def->input(0).c_str());
}
}

// Deal with non _Retval output operator of Identity.
if (has_out_retval) {
for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) {
const string &output_node_name = output_iter->first;
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
GE_CHECK_NOTNULL(output_node_def);
GE_IF_BOOL_EXEC(output_node_def->op() == "_Retval", continue);
for (int k = 0; k < output_node_def->input_size(); ++k) {
GE_IF_BOOL_EXEC(
output_node_def->input(k) == curr_node_name, output_node_def->set_input(k, curr_node_def->input(0).c_str());
GELOGD("%s op set input(%d):%s.", output_node_def->name().c_str(), k, curr_node_def->input(0).c_str());)
}
}
clear_input_flag = true;
}
return SUCCESS;
}

Status TensorFlowModelParser::GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def,
map<string, NodeDef *> &nodedef_map,
const vector<NodeDef *> &nodedef_to_optimize) {
GE_CHECK_NOTNULL(graph_def);
if (!nodedef_to_optimize.empty()) {
// Building input and input relationships for all OP nodes
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def));
} else {
return SUCCESS;
}
for (auto &curr_node_def : nodedef_to_optimize) {
GE_CHECK_NOTNULL(curr_node_def);
bool clear_input_flag = false;
const string &curr_node_name = curr_node_def->name();
GE_RETURN_IF_ERROR(OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag));
if (clear_input_flag) {
curr_node_def->clear_input();
}
}
GELOGI("GraphDefOptimizeIdentity success.");
return SUCCESS;
}

Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def,
map<string, NodeDef *> &nodedef_map, map<string, NodeDef *> &nodedef_map,
const std::pair<string, int> &input_data, const std::pair<string, int> &input_data,
const std::vector<string> &control_list) { const std::vector<string> &control_list) {
GE_CHECK_NOTNULL(curr_mode_def); GE_CHECK_NOTNULL(curr_mode_def);
if (curr_mode_def == nullptr) {
REPORT_INNER_ERROR("E19999", "Param curr_mode_def is nullptr, check invalid");
GELOGE(FAILED, "input param is nullptr.");
return PARAM_INVALID;
}
string curr_node_name = curr_mode_def->name(); string curr_node_name = curr_mode_def->name();
auto context_iter = op_node_context_map_.find(curr_node_name); auto context_iter = op_node_context_map_.find(curr_node_name);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
(context_iter == op_node_context_map_.end()),
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str());
return INTERNAL_ERROR, "Can't find op node context.");
if (context_iter == op_node_context_map_.end()) {
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str());
GELOGE(FAILED, "Can't find op node context.");
return INTERNAL_ERROR;
}
OpNodeContext op_node_context = context_iter->second; OpNodeContext op_node_context = context_iter->second;


std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map;
@@ -2773,7 +2693,7 @@ struct DelTransposeInfo {
int inputIdx; int inputIdx;
}; };


Status GetTransposeInfo(GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo,
Status GetTransposeInfo(domi::tensorflow::GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo,
std::map<std::string, DelTransposeInfo> &transposeInfo) { std::map<std::string, DelTransposeInfo> &transposeInfo) {
GE_CHECK_NOTNULL(graph_def); GE_CHECK_NOTNULL(graph_def);
for (int i = 0; i < graph_def->node_size(); ++i) { for (int i = 0; i < graph_def->node_size(); ++i) {
@@ -2826,7 +2746,7 @@ Status EraseTransposeNode(std::map<std::string, std::string> &softmaxInfo,
itTranspose->second.node_def->input(0).c_str()); itTranspose->second.node_def->input(0).c_str());
itTranspose = transposeInfo.erase(itTranspose); itTranspose = transposeInfo.erase(itTranspose);
} else { } else {
itTranspose++;
++itTranspose;
} }
} }


@@ -2847,7 +2767,7 @@ void TensorFlowModelParser::OptimizeTranspose(std::map<std::string, DelTranspose
} }
} }


void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *const graph_def) {
void TensorFlowModelParser::SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def) {
// The caller guarantees that the pointer is not null // The caller guarantees that the pointer is not null
for (int i = 0; i < graph_def->node_size(); ++i) { for (int i = 0; i < graph_def->node_size(); ++i) {
auto node_def = graph_def->mutable_node(i); auto node_def = graph_def->mutable_node(i);
@@ -2864,8 +2784,6 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph
GE_CHECK_NOTNULL(graph_def); GE_CHECK_NOTNULL(graph_def);
map<string, NodeDef *> nodedef_map; map<string, NodeDef *> nodedef_map;
vector<string> op_node_name_list; vector<string> op_node_name_list;
// Save Identity and ReadVariableOp
vector<NodeDef *> identity_to_optimize;
// Save Snapshot // Save Snapshot
vector<NodeDef *> snapshot_to_optimize; vector<NodeDef *> snapshot_to_optimize;


@@ -2875,16 +2793,12 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph
const string &node_name = node_def->name(); const string &node_name = node_def->name();
Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list); Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) {
identity_to_optimize.push_back(node_def);
} else if (node_def->op() == ge::parser::SNAPSHOT) {
if (node_def->op() == ge::parser::SNAPSHOT) {
snapshot_to_optimize.push_back(node_def); snapshot_to_optimize.push_back(node_def);
} }
nodedef_map[node_name] = node_def; nodedef_map[node_name] = node_def;
} }


// Optimize for Identity/ReadVariableOp
GE_RETURN_IF_ERROR(GraphDefOptimizeIdentity(graph_def, nodedef_map, identity_to_optimize));
// Optimize for Snapshot // Optimize for Snapshot
GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize)); GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize));


@@ -3055,7 +2969,7 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node,
GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS, GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS,
return FAILED); return FAILED);
const TensorProto &tensor = attr_value.tensor(); const TensorProto &tensor = attr_value.tensor();
const TensorShapeProto &tensor_shape = tensor.tensor_shape();
const domi::tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape();
GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE, GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE,
return SUCCESS); return SUCCESS);
GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS); GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS);
@@ -3110,10 +3024,10 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef
std::set<string> next_inputs; std::set<string> next_inputs;
for (const string &current_input : current_inputs) { for (const string &current_input : current_inputs) {
delete_nodes.insert(current_input); delete_nodes.insert(current_input);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input),
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"input_shape", current_input});
return FAILED, "Input op[%s] not found in graph.", current_input.c_str());
GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U,
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"input_shape", current_input});
return FAILED, "Input op[%s] not found in graph.", current_input.c_str());
const NodeDef *current_node = node_lookup[current_input]; const NodeDef *current_node = node_lookup[current_input];
GE_CHECK_NOTNULL(current_node); GE_CHECK_NOTNULL(current_node);
for (const string &input_name : current_node->input()) { for (const string &input_name : current_node->input()) {
@@ -3128,7 +3042,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef
domi::tensorflow::GraphDef filtered_graph_def; domi::tensorflow::GraphDef filtered_graph_def;
filtered_graph_def.mutable_node()->Clear(); filtered_graph_def.mutable_node()->Clear();
for (const NodeDef &node : input_graph_def.node()) { for (const NodeDef &node : input_graph_def.node()) {
if (input_nodes.count(node.name())) {
if (static_cast<bool>(input_nodes.count(node.name()))) {
*(filtered_graph_def.mutable_node()->Add()) = node; *(filtered_graph_def.mutable_node()->Add()) = node;
} }
if (!delete_nodes.count(node.name())) { if (!delete_nodes.count(node.name())) {
@@ -3137,12 +3051,12 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef
} }
output_graph_def->Clear(); output_graph_def->Clear();
for (const NodeDef &node : filtered_graph_def.node()) { for (const NodeDef &node : filtered_graph_def.node()) {
if (input_nodes.count(node.name())) {
if (static_cast<bool>(input_nodes.count(node.name()))) {
NodeDef placeholder_node = node; NodeDef placeholder_node = node;
placeholder_node.clear_input(); placeholder_node.clear_input();
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder"));
domi::tensorflow::AttrValue attr_value; domi::tensorflow::AttrValue attr_value;
TensorShapeProto *data_shape = attr_value.mutable_shape();
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape();
GE_CHECK_NOTNULL(data_shape); GE_CHECK_NOTNULL(data_shape);
const ge::ParserContext &ctx = ge::GetParserContext(); const ge::ParserContext &ctx = ge::GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
@@ -3185,11 +3099,11 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef
std::set<string> next_inputs; std::set<string> next_inputs;
for (const string &current_input : current_inputs) { for (const string &current_input : current_inputs) {
required_nodes.insert(current_input); required_nodes.insert(current_input);
GE_IF_BOOL_EXEC(input_nodes.count(current_input), continue);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input),
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"out_nodes", current_input});
return FAILED, "Input op[%s] not found in graph.", current_input.c_str());
GE_IF_BOOL_EXEC(static_cast<bool>(input_nodes.count(current_input)), continue);
GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U,
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"out_nodes", current_input});
return FAILED, "op[%s] not found in graph.", current_input.c_str());
const NodeDef *current_node = node_lookup[current_input]; const NodeDef *current_node = node_lookup[current_input];
GE_CHECK_NOTNULL(current_node); GE_CHECK_NOTNULL(current_node);
for (const string &input_name : current_node->input()) { for (const string &input_name : current_node->input()) {
@@ -3204,18 +3118,18 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef
domi::tensorflow::GraphDef filtered_graph_def; domi::tensorflow::GraphDef filtered_graph_def;
filtered_graph_def.mutable_node()->Clear(); filtered_graph_def.mutable_node()->Clear();
for (const NodeDef &node : input_graph_def.node()) { for (const NodeDef &node : input_graph_def.node()) {
if (required_nodes.count(node.name())) {
if (static_cast<bool>(required_nodes.count(node.name()))) {
*(filtered_graph_def.mutable_node()->Add()) = node; *(filtered_graph_def.mutable_node()->Add()) = node;
} }
} }
output_graph_def->Clear(); output_graph_def->Clear();
for (const NodeDef &node : filtered_graph_def.node()) { for (const NodeDef &node : filtered_graph_def.node()) {
if (input_nodes.count(node.name())) {
if (static_cast<bool>(input_nodes.count(node.name()))) {
NodeDef placeholder_node = node; NodeDef placeholder_node = node;
placeholder_node.clear_input(); placeholder_node.clear_input();
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder"));
domi::tensorflow::AttrValue attr_value; domi::tensorflow::AttrValue attr_value;
TensorShapeProto *data_shape = attr_value.mutable_shape();
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape();
GE_CHECK_NOTNULL(data_shape); GE_CHECK_NOTNULL(data_shape);
const ge::ParserContext &ctx = ge::GetParserContext(); const ge::ParserContext &ctx = ge::GetParserContext();
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
@@ -3265,11 +3179,12 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par


// Find all children of the fusion operator // Find all children of the fusion operator
auto iter = fusion_op_nodedef_map_.find(node_def->name()); auto iter = fusion_op_nodedef_map_.find(node_def->name());
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
iter == fusion_op_nodedef_map_.end(),
REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid",
node_def->name().c_str());
return INTERNAL_ERROR, "FusionOp node %s has no children node!", node_def->name().c_str());
if (iter == fusion_op_nodedef_map_.end()) {
REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid",
node_def->name().c_str());
GELOGE(FAILED, "FusionOp node %s has no children node!", node_def->name().c_str());
return INTERNAL_ERROR;
}


(void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op()); (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op());
vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second; vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second;
@@ -3325,7 +3240,7 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par
* @return false optimize failed * @return false optimize failed
* *
*/ */
Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) {
Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const {
GE_CHECK_NOTNULL(graph_def); GE_CHECK_NOTNULL(graph_def);
// 1. find all the nodes in the graph and save them to all_nodedef_map // 1. find all the nodes in the graph and save them to all_nodedef_map
map<string, NodeDef *> all_nodedef_map; map<string, NodeDef *> all_nodedef_map;
@@ -3336,7 +3251,7 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap
string node_name = current_node->name(); string node_name = current_node->name();
all_nodedef_map[node_name] = current_node; all_nodedef_map[node_name] = current_node;
} }
GE_CHK_BOOL_EXEC_INFO(!all_nodedef_map.empty(), return SUCCESS, "all_nodedef_map is empty");
GELOGD("node size is: %zu", all_nodedef_map.size());


// 2. move input to attr. // 2. move input to attr.
for (auto &it_node_map : all_nodedef_map) { for (auto &it_node_map : all_nodedef_map) {
@@ -3347,14 +3262,14 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap
// 2.1. check whether the current op is register for move to attr. // 2.1. check whether the current op is register for move to attr.
const std::vector<domi::RemoveInputConfigure> &move_input_vec = const std::vector<domi::RemoveInputConfigure> &move_input_vec =
domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name); domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name);
GE_CHK_BOOL_EXEC_NOLOG(!move_input_vec.empty(), continue);
GELOGD("Current op %s is registered for remove input.", current_op_name.c_str());

// 2.2 check whether the current op is a TVM op. // 2.2 check whether the current op is a TVM op.
GE_CHK_BOOL_EXEC_INFO(
domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) == domi::ImplyType::TVM, continue,
"op %s is not TVM op", current_op_name.c_str());
GELOGD("handle tvm op %s", current_op_name.c_str());
const bool is_unknown_custom_op = move_input_vec.empty() ||
(domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) != domi::ImplyType::TVM);
if (is_unknown_custom_op) {
GELOGI("op %s is not TVM op, move input size: %zu", current_op_name.c_str(), move_input_vec.size());
continue;
}
GELOGD("Current op %s is registered for remove input and tvm op", current_op_name.c_str());


// 2.3 copy input to attr // 2.3 copy input to attr
set<uint32_t> unused_inputs; set<uint32_t> unused_inputs;
@@ -3401,7 +3316,8 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap
} }
for (size_t i = 0; i < it.input_order.size(); ++i) { for (size_t i = 0; i < it.input_order.size(); ++i) {
int new_index = it.input_order[i]; int new_index = it.input_order[i];
if (new_index < 0 || new_index >= inputs.size()) {
const bool is_input_invalid = (new_index < 0) || (new_index >= inputs.size());
if (is_input_invalid) {
REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, out of range(0, %d)", REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, out of range(0, %d)",
it_node_map.first.c_str(), new_index, inputs.size()); it_node_map.first.c_str(), new_index, inputs.size());
GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index); GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index);
@@ -3443,7 +3359,10 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::
if (input_node_def->op() == parser::SWITCH || input_node_def->op() == parser::REFSWITCH) { if (input_node_def->op() == parser::SWITCH || input_node_def->op() == parser::REFSWITCH) {
NodeDef *identity_node_def = graph_def->add_node(); NodeDef *identity_node_def = graph_def->add_node();
GE_CHECK_NOTNULL(identity_node_def); GE_CHECK_NOTNULL(identity_node_def);
input_node_name = input_node_name + "identity";
std::string remove_input_name = remove_input;
remove_input_name = remove_input_name.find(":") == std::string::npos ?
input_node_name : (remove_input_name.replace(remove_input_name.find(":"), 1, "_"));
input_node_name = remove_input_name + "_identity";
identity_node_def->set_name(input_node_name); identity_node_def->set_name(input_node_name);
identity_node_def->set_op(parser::IDENTITY); identity_node_def->set_op(parser::IDENTITY);
identity_node_def->add_input(remove_input); identity_node_def->add_input(remove_input);
@@ -3465,7 +3384,7 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::
*/ */
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def, Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def,
const set<uint32_t> &remove_index_set, const set<uint32_t> &remove_index_set,
const map<string, NodeDef *> &all_node_map) {
const map<string, NodeDef *> &all_node_map) const {
GE_CHECK_NOTNULL(node_def); GE_CHECK_NOTNULL(node_def);
if (remove_index_set.empty()) { if (remove_index_set.empty()) {
GELOGI("The size of remove_index_set is zero."); GELOGI("The size of remove_index_set is zero.");
@@ -3662,7 +3581,7 @@ Status TensorFlowModelParser::RecordFusionResult(const std::shared_ptr<ge::Scope
return SUCCESS; return SUCCESS;
} }


Status TensorFlowModelParser::SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context,
Status TensorFlowModelParser::SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context,
const std::vector<std::pair<std::string, int32_t>> &inputs, const std::vector<std::pair<std::string, int32_t>> &inputs,
const std::vector<std::pair<std::string, int32_t>> &outputs) { const std::vector<std::pair<std::string, int32_t>> &outputs) {
int32_t in_index = 0; int32_t in_index = 0;
@@ -3752,7 +3671,7 @@ void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, Op
++iter; ++iter;
} }
} }
op_node_context.input_map.insert(tmp_input_map.begin(), tmp_input_map.end());
op_node_context.input_map.insert(tmp_input_map.cbegin(), tmp_input_map.cend());
// update output map of pre node // update output map of pre node
for (const auto &in_iter : op_node_context.input_map) { for (const auto &in_iter : op_node_context.input_map) {
auto src_iter = op_node_context_map_.find(in_iter.first); auto src_iter = op_node_context_map_.find(in_iter.first);
@@ -3801,7 +3720,7 @@ void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, O
++iter; ++iter;
} }
} }
op_node_context.output_map.insert(tmp_output_map.begin(), tmp_output_map.end());
op_node_context.output_map.insert(tmp_output_map.cbegin(), tmp_output_map.cend());
// update input map of pre node // update input map of pre node
for (const auto &out_iter : op_node_context.output_map) { for (const auto &out_iter : op_node_context.output_map) {
auto dst_iter = op_node_context_map_.find(out_iter.first); auto dst_iter = op_node_context_map_.find(out_iter.first);
@@ -3902,7 +3821,7 @@ Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope
DumpAllNodeContext("BeforeAddFusionNodeDef"); DumpAllNodeContext("BeforeAddFusionNodeDef");
for (size_t i = 0; i < op_node_list_size; ++i) { for (size_t i = 0; i < op_node_list_size; ++i) {
const string op_node_name = node_name_list[i]; const string op_node_name = node_name_list[i];
auto iter = fusion_op_nodedef_map_.find(op_node_name);
std::map<string, vector<const NodeDef *>>::const_iterator iter = fusion_op_nodedef_map_.find(op_node_name);
if (iter != fusion_op_nodedef_map_.end()) { if (iter != fusion_op_nodedef_map_.end()) {
vector<string> fusion_op_info = fusion_op_type_map_[op_node_name]; vector<string> fusion_op_info = fusion_op_type_map_[op_node_name];
if (fusion_op_info[0] != ge::kScopeToMultiNodes) { if (fusion_op_info[0] != ge::kScopeToMultiNodes) {
@@ -3943,7 +3862,8 @@ Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope
} }


Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def) {
std::mutex *const graph_mutex,
const domi::tensorflow::NodeDef *node_def) {
// This is an internal function. The pointer input parameter is not empty when this function is invoked. // This is an internal function. The pointer input parameter is not empty when this function is invoked.
string node_name = node_def->name(); string node_name = node_def->name();
string node_op = node_def->op(); string node_op = node_def->op();
@@ -4059,7 +3979,7 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph
std::string model_data; std::string model_data;
if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) { if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) {
ge::Model model; ge::Model model;
auto load_ret = ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.data()), model_data.size(), model);
auto load_ret = ge::Model::Load(ge::PtrToPtr<char_t, const uint8_t>(model_data.data()), model_data.size(), model);
if (load_ret != GRAPH_SUCCESS) { if (load_ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str()); GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str()); REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str());


+ 32
- 46
parser/tensorflow/tensorflow_parser.h View File

@@ -35,6 +35,7 @@
#include "omg/parser/model_parser.h" #include "omg/parser/model_parser.h"
#include "omg/parser/op_parser.h" #include "omg/parser/op_parser.h"
#include "omg/parser/weights_parser.h" #include "omg/parser/weights_parser.h"
#include "common/pre_checker.h"
#include "parser/tensorflow/tensorflow_fusion_op_parser.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include "parser/tensorflow/tensorflow_util.h" #include "parser/tensorflow/tensorflow_util.h"
#include "proto/om.pb.h" #include "proto/om.pb.h"
@@ -46,15 +47,6 @@
#include "scope/scope_pass_manager.h" #include "scope/scope_pass_manager.h"
#include "common/parser_utils.h" #include "common/parser_utils.h"


using ge::ScopePassManager;
using domi::tensorflow::GraphDef;
using domi::tensorflow::DT_HALF;
using domi::tensorflow::NodeDef;
using domi::tensorflow::GraphDef;
using domi::tensorflow::AttrValue;
using domi::tensorflow::DataType;
using ge::OpParser;

namespace ge { namespace ge {
using std::string; using std::string;
using std::vector; using std::vector;
@@ -130,7 +122,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {


Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto,
domi::GetGraphCallback callback, domi::GetGraphCallback callback,
ge::ComputeGraphPtr &graph) override;
ge::ComputeGraphPtr &root_graph) override;


/* /*
* @ingroup domi_omg * @ingroup domi_omg
@@ -163,6 +155,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
*/ */
Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback,
ge::ComputeGraphPtr &root_graph) override; ge::ComputeGraphPtr &root_graph) override;

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
private: private:
Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph); Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph);


@@ -241,15 +245,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {


/** /**
* @ingroup domi_omg * @ingroup domi_omg
* @brief Verifying the validity of graphdef object parsed by pb
* @param [in] graph_def Parsed tensorflow:: graphdef object
* @return SUCCESS check successfully
* @return FAILED check failed
*/
Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const;

/**
* @ingroup domi_omg
* @brief whether const OP need to update context * @brief whether const OP need to update context
* @param const op name * @param const op name
* @return true or false * @return true or false
@@ -433,28 +428,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
* @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef * @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef
*/ */
Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def); Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def);
/**
* @ingroup domi_omg
* @brief Optimize for Identity/ReadVariableOp operator
* @param [in] graph_def GraphDef to be optimized
* @param [in] nodedef_map Map of all nodes in graph
* @param [in] nodedef_to_optimize vector of NodeDef to be optimized
* @return SUCCESS optimize successfully
* @return others failed
*/
Status GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
const vector<NodeDef *> &nodedef_to_optimize);
/**
* @ingroup domi_omg
* @brief For the identity operator whose output is "_retval", optimize it.
* @param [in] nodedef_map Map of all nodes in graph
* @param [in] curr_node_name Name of node to be optimized
* @param [in] clear_input_flag Flag of whether to clear the input of the current node
* @return SUCCESS optimize successfully
* @return others failed
*/
Status OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, const string &curr_node_name,
bool &clear_input_flag);

Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
const vector<NodeDef *> &nodedef_to_optimize); const vector<NodeDef *> &nodedef_to_optimize);
Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def,
@@ -469,7 +443,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def, void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def,
domi::tensorflow::NodeDef *const nodeCurrent, bool &clearInputFlag) const; domi::tensorflow::NodeDef *const nodeCurrent, bool &clearInputFlag) const;
static void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo); static void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo);
static void SoftmaxAddAttr(GraphDef *const graph_def);
static void SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def);


/** /**
* @ingroup domi_omg * @ingroup domi_omg
@@ -551,7 +525,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
* @return false optimize failed * @return false optimize failed
* *
*/ */
Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def);
Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const;


/** /**
* @ingroup domi_omg * @ingroup domi_omg
@@ -565,7 +539,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
Status RemoveInputs(domi::tensorflow::GraphDef *graph_def, Status RemoveInputs(domi::tensorflow::GraphDef *graph_def,
domi::tensorflow::NodeDef *node_def, domi::tensorflow::NodeDef *node_def,
const set<uint32_t> &remove_index_set, const set<uint32_t> &remove_index_set,
const map<string, NodeDef *> &all_node_map);
const map<string, NodeDef *> &all_node_map) const;


Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def,
domi::tensorflow::NodeDef *node_def, domi::tensorflow::NodeDef *node_def,
@@ -611,7 +585,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {


static Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library); static Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library);


Status SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context,
Status SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context,
const std::vector<std::pair<std::string, int32_t>> &inputs, const std::vector<std::pair<std::string, int32_t>> &inputs,
const std::vector<std::pair<std::string, int32_t>> &outputs); const std::vector<std::pair<std::string, int32_t>> &outputs);


@@ -642,7 +616,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
Status AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &node_name_list); Status AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &node_name_list);


static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def);
std::mutex *const graph_mutex, const domi::tensorflow::NodeDef *node_def);


static void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase); static void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase);
void DumpAllNodeContext(const string &phase) const; void DumpAllNodeContext(const string &phase) const;
@@ -725,6 +699,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowWeightsParser : public domi::WeightsParse
Status Parse(const char *file, ge::Graph &graph) override; Status Parse(const char *file, ge::Graph &graph) override;


Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;

bool HasError() override {
return PreChecker::Instance().HasError();
}

Status Save(const string &file) override {
return PreChecker::Instance().Save(file);
}

void Clear() override {
PreChecker::Instance().Clear();
}
}; };
} // namespace domi } // namespace domi
#endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ #endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_

+ 0
- 2
parser/tensorflow/tensorflow_parser_register.h View File

@@ -30,8 +30,6 @@
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"
#include "proto/tensorflow/node_def.pb.h" #include "proto/tensorflow/node_def.pb.h"


using domi::tensorflow::NodeDef;

namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY TensorflowFinalizeable { class PARSER_FUNC_VISIBILITY TensorflowFinalizeable {
public: public:


+ 0
- 2
parser/tensorflow/tensorflow_ref_switch_parser.h View File

@@ -20,8 +20,6 @@
#include "common/op_def/ref_switch_op.h" #include "common/op_def/ref_switch_op.h"
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"


using domi::tensorflow::NodeDef;

namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpParser { class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpParser {
// AUTO GEN PLEASE DO NOT MODIFY IT // AUTO GEN PLEASE DO NOT MODIFY IT


+ 1
- 1
parser/tensorflow/tensorflow_reshape_parser.cc View File

@@ -61,7 +61,7 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);


const NodeDef *node_src = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node_src); GE_CHECK_NOTNULL(node_src);
GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str());
domi::tensorflow::AttrValue input_attr_value; domi::tensorflow::AttrValue input_attr_value;


+ 1
- 1
parser/tensorflow/tensorflow_reshape_parser.h View File

@@ -34,7 +34,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser
* @return FAILED parse failed * @return FAILED parse failed
* @author * @author
*/ */
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;
}; };
} // namespace ge } // namespace ge




+ 3
- 3
parser/tensorflow/tensorflow_shape_n_parser.cc View File

@@ -94,7 +94,7 @@ Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, Sha


Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
GE_CHECK_NOTNULL(op_dest); GE_CHECK_NOTNULL(op_dest);
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
ShapeNOperator op; ShapeNOperator op;
op.Name(node->name()); op.Name(node->name());
@@ -154,13 +154,13 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
} }


// AUTO GEN PLEASE DO NOT MODIFY IT // AUTO GEN PLEASE DO NOT MODIFY IT
Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) {
Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) {
(void)node; (void)node;
(void)op; (void)op;
return SUCCESS; return SUCCESS;
} }


Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) {
Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) {
(void)node; (void)node;
(void)op; (void)op;
return SUCCESS; return SUCCESS;


+ 2
- 4
parser/tensorflow/tensorflow_shape_n_parser.h View File

@@ -20,8 +20,6 @@
#include "common/op_def/shape_n_op.h" #include "common/op_def/shape_n_op.h"
#include "parser/tensorflow/tensorflow_op_parser.h" #include "parser/tensorflow/tensorflow_op_parser.h"


using domi::tensorflow::NodeDef;

namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser { class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser {
// AUTO GEN PLEASE DO NOT MODIFY IT // AUTO GEN PLEASE DO NOT MODIFY IT
@@ -29,8 +27,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;


protected: protected:
Status PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op);
Status PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op);
Status PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op);
Status PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op);


static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op);
static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op);


+ 1
- 1
parser/tensorflow/tensorflow_squeeze_parser.cc View File

@@ -66,7 +66,7 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op); GE_CHECK_NOTNULL(op);


const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
bool has_axis = true; bool has_axis = true;


+ 1
- 1
parser/tensorflow/tensorflow_squeeze_parser.h View File

@@ -22,7 +22,7 @@
namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser { class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser {
public: public:
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;


private: private:
static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc);


+ 5
- 2
parser/tensorflow/tensorflow_util.cc View File

@@ -207,7 +207,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType(
const NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) {
const domi::tensorflow::NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) {
GE_CHECK_NOTNULL(node_src); GE_CHECK_NOTNULL(node_src);


std::string node_name = node_src->name(); std::string node_name = node_src->name();
@@ -301,7 +301,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr
} }
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr(
const std::string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *const node_def) { const std::string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *const node_def) {
GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null.");
if (node_def == nullptr) {
GELOGI("input parameter is null.");
return;
}
node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value));
} }
} // namespace ge } // namespace ge

+ 1
- 9
parser/tensorflow/tensorflow_util.h View File

@@ -18,14 +18,11 @@
#define OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_ #define OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_


#include <map> #include <map>
#include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "parser/common/op_def/operator.h" #include "parser/common/op_def/operator.h"
#include "external/graph/attr_value.h" #include "external/graph/attr_value.h"
#include "external/graph/graph.h" #include "external/graph/graph.h"
#include "external/graph/operator.h"
#include "framework/omg/parser/parser_types.h" #include "framework/omg/parser/parser_types.h"
#include "framework/omg/omg_inner_types.h" #include "framework/omg/omg_inner_types.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
@@ -37,11 +34,6 @@
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "proto/tensorflow/graph.pb.h" #include "proto/tensorflow/graph.pb.h"


using domi::tensorflow::NodeDef;
using domi::tensorflow::FunctionDef;
using domi::tensorflow::AttrValue_ListValue;
using domi::tensorflow::FunctionDefLibrary;

namespace ge { namespace ge {
/***************************TensorFlow attribute type, constant definition*******************************************/ /***************************TensorFlow attribute type, constant definition*******************************************/
extern const std::string TENSORFLOW_ATTR_TYPE_STRING; extern const std::string TENSORFLOW_ATTR_TYPE_STRING;
@@ -167,7 +159,7 @@ class TensorFlowUtil {
* @return FAILED parsing failed * @return FAILED parsing failed
* *
*/ */
static domi::Status ParseDataType(const NodeDef *node_src,
static domi::Status ParseDataType(const domi::tensorflow::NodeDef *node_src,
const std::string &attr_src, const std::string &attr_src,
domi::tensorflow::DataType &data_type); domi::tensorflow::DataType &data_type);




+ 1
- 1
parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc View File

@@ -25,7 +25,7 @@ using namespace ge::parser;
namespace ge { namespace ge {
Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *const op) { Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *const op) {
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = ge::PtrToPtr<Message, domi::tensorflow::NodeDef>(op_src);
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
op->Name(node->name()); op->Name(node->name());


+ 1
- 2
parser/tensorflow/tensorflow_variable_v2_parser.cc View File

@@ -19,7 +19,6 @@
#include "graph/ge_attr_value.h" #include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h" #include "graph/ge_tensor.h"
#include "graph/op_desc.h" #include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/utils/attr_utils.h" #include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "parser/common/op_def/variable_op.h" #include "parser/common/op_def/variable_op.h"
@@ -253,7 +252,7 @@ static void ParseMemType(const domi::tensorflow::NodeDef *node, VariableOperator


Status ParseParams(const Message *op_src, VariableOperator *op) { Status ParseParams(const Message *op_src, VariableOperator *op) {
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
const NodeDef *node = ge::PtrToPtr<Message, NodeDef>(op_src);
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
string node_op = node->op(); string node_op = node->op();


+ 2
- 2
tests/depends/graph/src/attr_util_stub.cc View File

@@ -121,7 +121,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(
GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def),
REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed"); REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed");
return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed");
op_desc->extAttrs_ = org_op_desc->extAttrs_;
op_desc->ext_attrs_ = org_op_desc->ext_attrs_;


// This function may be called by some passes of fusion engine, in this condition, do not need these attribute // This function may be called by some passes of fusion engine, in this condition, do not need these attribute
if (op_desc->impl_ == nullptr) { if (op_desc->impl_ == nullptr) {
@@ -164,7 +164,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c
return nullptr; return nullptr;
} }


op_desc->extAttrs_ = org_op_desc->extAttrs_;
op_desc->ext_attrs_ = org_op_desc->ext_attrs_;


if (op_desc->impl_ == nullptr) { if (op_desc->impl_ == nullptr) {
REPORT_INNER_ERROR("E19999", "op desc impl is nullptr, check invalid"); REPORT_INNER_ERROR("E19999", "op desc impl is nullptr, check invalid");


+ 11
- 0
tests/depends/mmpa/src/mmpa_stub.cc View File

@@ -158,6 +158,17 @@ mmTimespec mmGetTickCount() {
return rts; return rts;
} }


INT32 mmGetSystemTime(mmSystemTime_t *sysTime) {
// Beijing olympics
sysTime->wYear = 2008;
sysTime->wMonth = 8;
sysTime->wDay = 8;
sysTime->wHour = 20;
sysTime->wMinute = 8;
sysTime->wSecond = 0;
return 0;
}

INT32 mmGetTid() { INT32 mmGetTid() {
INT32 ret = (INT32)syscall(SYS_gettid); INT32 ret = (INT32)syscall(SYS_gettid);




+ 6
- 1
tests/st/CMakeLists.txt View File

@@ -61,7 +61,7 @@ target_link_libraries(st_parser_proto PRIVATE


################################################################################ ################################################################################
set(DUPLICATE_PROTO_LIST set(DUPLICATE_PROTO_LIST
"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto"
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto"
) )


protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST})
@@ -118,7 +118,10 @@ set(MATEDEF_SRC_FILES
"${PARSER_DIR}/metadef/graph/resource_context_mgr.cc" "${PARSER_DIR}/metadef/graph/resource_context_mgr.cc"
"${PARSER_DIR}/metadef/graph/utils/constant_utils.cc" "${PARSER_DIR}/metadef/graph/utils/constant_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" "${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/file_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" "${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/connection_matrix.cc"
"${PARSER_DIR}/metadef/graph/utils/cycle_detector.cc"
"${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" "${PARSER_DIR}/metadef/graph/utils/graph_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/node_utils.cc" "${PARSER_DIR}/metadef/graph/utils/node_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" "${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc"
@@ -126,6 +129,8 @@ set(MATEDEF_SRC_FILES
"${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" "${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" "${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/type_utils.cc" "${PARSER_DIR}/metadef/graph/utils/type_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/trace/trace_manager.cc"
"${PARSER_DIR}/metadef/graph/common/large_bm.cc"
"${PARSER_DIR}/metadef/ops/op_imp.cpp" "${PARSER_DIR}/metadef/ops/op_imp.cpp"
"${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" "${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc"
"${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" "${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc"


+ 13
- 1
tests/st/parser_st_utils.cc View File

@@ -21,7 +21,9 @@
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <fstream> #include <fstream>

#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>


namespace ge { namespace ge {
void ParerSTestsUtils::ClearParserInnerCtx() { void ParerSTestsUtils::ClearParserInnerCtx() {
@@ -131,4 +133,14 @@ void ParerSTestsUtils::WriteProtoToBinaryFile(const google::protobuf::Message &p
out.close(); out.close();
delete[] buf; delete[] buf;
} }

void ParerSTestsUtils::WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename) {
const int32_t fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 384U);
if (fd >= 0) {
google::protobuf::io::FileOutputStream output(fd);
google::protobuf::TextFormat::Print(proto, &output);
output.Close();
close(fd);
}
}
} // namespace ge } // namespace ge

+ 1
- 0
tests/st/parser_st_utils.h View File

@@ -31,6 +31,7 @@ class ParerSTestsUtils {
static MemBuffer* MemBufferFromFile(const char *path); static MemBuffer* MemBufferFromFile(const char *path);
static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); static bool ReadProtoFromText(const char *file, google::protobuf::Message *message);
static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename);
static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename);
}; };
} // namespace ge } // namespace ge




+ 25
- 1
tests/st/testcase/test_caffe_parser.cc View File

@@ -36,6 +36,7 @@
#include "parser/caffe/caffe_op_parser.h" #include "parser/caffe/caffe_op_parser.h"
#include "graph/operator_reg.h" #include "graph/operator_reg.h"
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#include "common/op_map.h"
#undef protected #undef protected
#undef private #undef private


@@ -173,7 +174,7 @@ void STestCaffeParser::RegisterCustomOp() {


std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) { for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data); domi::OpRegistry::Instance()->Register(reg_data);
} }
domi::OpRegistry::Instance()->registrationDatas.clear(); domi::OpRegistry::Instance()->registrationDatas.clear();
@@ -223,6 +224,29 @@ TEST_F(STestCaffeParser, acl_caffe_parser) {
EXPECT_EQ(ret, GRAPH_FAILED); EXPECT_EQ(ret, GRAPH_FAILED);
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph);
EXPECT_EQ(ret, GRAPH_FAILED); EXPECT_EQ(ret, GRAPH_FAILED);

caffe_op_map.clear();
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_FAILED);

{
proto.set_name("empty_layer");
auto &layers = *proto.add_layers();
layers.set_name("layers");

proto.clear_layer();
const std::string empty_layer = case_dir + "/origin_models/empty_layer.pbtxt";
ParerSTestsUtils::WriteProtoToTextFile(proto, empty_layer.c_str());
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layer.c_str(), weight_file.c_str(), parser_params, graph), FAILED);

proto.clear_layers();
const std::string empty_layers = case_dir + "/origin_models/empty_layers.pbtxt";
ParerSTestsUtils::WriteProtoToTextFile(proto, empty_layers.c_str());
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layers.c_str(), weight_file.c_str(), parser_params, graph), FAILED);

unlink(empty_layer.c_str());
unlink(empty_layers.c_str());
}
} }


TEST_F(STestCaffeParser, modelparser_parsefrommemory_success) TEST_F(STestCaffeParser, modelparser_parsefrommemory_success)


+ 2
- 1
tests/st/testcase/test_onnx_parser.cc View File

@@ -24,6 +24,7 @@
#include "st/parser_st_utils.h" #include "st/parser_st_utils.h"
#include "external/ge/ge_api_types.h" #include "external/ge/ge_api_types.h"
#include "tests/depends/ops_stub/ops_stub.h" #include "tests/depends/ops_stub/ops_stub.h"
#include "framework/omg/parser/parser_factory.h"
#include "parser/onnx/onnx_parser.h" #include "parser/onnx/onnx_parser.h"


namespace ge { namespace ge {
@@ -96,7 +97,7 @@ void STestOnnxParser::RegisterCustomOp() {


std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) { for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data); domi::OpRegistry::Instance()->Register(reg_data);
} }
domi::OpRegistry::Instance()->registrationDatas.clear(); domi::OpRegistry::Instance()->registrationDatas.clear();


+ 17
- 47
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -64,6 +64,7 @@
#include "parser/common/data_op_parser.h" #include "parser/common/data_op_parser.h"
#include "parser/common/model_saver.h" #include "parser/common/model_saver.h"
#include "framework/omg/parser/parser_api.h" #include "framework/omg/parser/parser_api.h"
#include "framework/omg/parser/parser_factory.h"
#include "parser/common/parser_fp16_t.h" #include "parser/common/parser_fp16_t.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/common/prototype_pass_manager.h" #include "parser/common/prototype_pass_manager.h"
@@ -151,7 +152,7 @@ void STestTensorflowParser::RegisterCustomOp() {
.ParseParamsFn(ParseParams); .ParseParamsFn(ParseParams);
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) { for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data); domi::OpRegistry::Instance()->Register(reg_data);
} }
domi::OpRegistry::Instance()->registrationDatas.clear(); domi::OpRegistry::Instance()->registrationDatas.clear();
@@ -584,7 +585,7 @@ namespace {
void register_tbe_op() { void register_tbe_op() {
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
for (OpRegistrationData reg_data : registrationDatas) { for (OpRegistrationData reg_data : registrationDatas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
OpRegistry::Instance()->Register(reg_data); OpRegistry::Instance()->Register(reg_data);
} }
OpRegistry::Instance()->registrationDatas.clear(); OpRegistry::Instance()->registrationDatas.clear();
@@ -1124,7 +1125,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserfrommemory_failed)
ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
ret = modelParser.ParseFromMemory(data, size, compute_graph); ret = modelParser.ParseFromMemory(data, size, compute_graph);
EXPECT_EQ(ret, INTERNAL_ERROR);
EXPECT_NE(ret, SUCCESS);
} }


TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success) TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success)
@@ -1259,7 +1260,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed)
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
TensorFlowModelParser tensorflow_parser; TensorFlowModelParser tensorflow_parser;
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(INTERNAL_ERROR, ret);
ASSERT_NE(ret, SUCCESS);
} }


TEST_F(STestTensorflowParser, test_parse_acl_output_nodes) TEST_F(STestTensorflowParser, test_parse_acl_output_nodes)
@@ -1913,6 +1914,7 @@ TEST_F(STestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test)
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


op_dest->SetType(ge::parser::SHAPE); op_dest->SetType(ge::parser::SHAPE);
op_dest->AddOutputDesc(GeTensorDesc());
ret = autoMappingParser.ParseParams(node_def, op_dest); ret = autoMappingParser.ParseParams(node_def, op_dest);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }
@@ -2648,29 +2650,6 @@ TEST_F(STestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test)
model_parser.UpdateEdgesControlInfo(info); model_parser.UpdateEdgesControlInfo(info);
} }


TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test)
{
TensorFlowModelParser model_parser;
NodeDef *node_def = new NodeDef();
node_def->set_name("Placeholder");
node_def->set_op("Placeholder_0");
std::map<string, NodeDef *> nodedef_map;
nodedef_map.emplace("Placeholder", node_def);
std::string curr_node_name = "Placeholder";
bool clear_input_flag = true;
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, INTERNAL_ERROR);

GraphDef graph;
curr_node_name = "pre_node_a";
nodedef_map.emplace("pre_node_a", node_def);
node_def->set_op("pre_node_a");
GenOriginContext(&model_parser, curr_node_name);
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, SUCCESS);
delete node_def;
}

TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test) TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test)
{ {
TensorFlowModelParser model_parser; TensorFlowModelParser model_parser;
@@ -2831,27 +2810,17 @@ TEST_F(STestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test)
removed_inputs_vec.emplace_back("Add0"); removed_inputs_vec.emplace_back("Add0");
Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
tensorflow::NodeDef *node_swith = initNodeDef();
node_swith->set_name("switch_op");
node_swith->set_op(parser::SWITCH);
all_node_map.emplace("switch_op", node_swith);
removed_inputs_vec.clear();
removed_inputs_vec.emplace_back("switch_op");
ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_swith, all_node_map, removed_inputs_vec);
EXPECT_EQ(ret, SUCCESS);
} }


TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test)
{
tensorflow::GraphDef graph_def;
TensorFlowModelParser tensorflow_parser;
tensorflow::NodeDef *node_def = initNodeDef();
node_def->set_name("post_node_d");

std::map<string, NodeDef *> nodedef_map;
nodedef_map.emplace("post_node_d", node_def);
nodedef_map.emplace("post_node_a", node_def);
nodedef_map.emplace("post_node_b", node_def);
std::vector<NodeDef *> nodedef_to_optimize;
nodedef_to_optimize.emplace_back(node_def);

std::string curr_node_name = "post_node_b";
GenOriginContext(&tensorflow_parser, curr_node_name);
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}
TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) {
std::string caseDir = __FILE__; std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/"); std::size_t idx = caseDir.find_last_of("/");
@@ -3534,7 +3503,8 @@ TEST_F(STestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test)
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM);
mess2Op.ParseField(reflection, node_def, field, depth, ops); mess2Op.ParseField(reflection, node_def, field, depth, ops);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5);
delete field; delete field;
} }




+ 6
- 1
tests/ut/parser/CMakeLists.txt View File

@@ -62,7 +62,7 @@ target_link_libraries(ut_parser_proto PRIVATE


################################################################################ ################################################################################
set(DUPLICATE_PROTO_LIST set(DUPLICATE_PROTO_LIST
"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto"
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto"
) )


protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST})
@@ -119,14 +119,19 @@ set(MATEDEF_SRC_FILES
"${PARSER_DIR}/metadef/graph/tensor.cc" "${PARSER_DIR}/metadef/graph/tensor.cc"
"${PARSER_DIR}/metadef/graph/types.cc" "${PARSER_DIR}/metadef/graph/types.cc"
"${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" "${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/file_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" "${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" "${PARSER_DIR}/metadef/graph/utils/graph_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/connection_matrix.cc"
"${PARSER_DIR}/metadef/graph/utils/cycle_detector.cc"
"${PARSER_DIR}/metadef/graph/utils/node_utils.cc" "${PARSER_DIR}/metadef/graph/utils/node_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" "${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc" "${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" "${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" "${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/type_utils.cc" "${PARSER_DIR}/metadef/graph/utils/type_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/trace/trace_manager.cc"
"${PARSER_DIR}/metadef/graph/common/large_bm.cc"
"${PARSER_DIR}/metadef/ops/op_imp.cpp" "${PARSER_DIR}/metadef/ops/op_imp.cpp"
"${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" "${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc"
"${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" "${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc"


+ 13
- 0
tests/ut/parser/parser_ut_utils.cc View File

@@ -21,6 +21,9 @@
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <limits.h> #include <limits.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>


namespace ge { namespace ge {
void ParerUTestsUtils::ClearParserInnerCtx() { void ParerUTestsUtils::ClearParserInnerCtx() {
@@ -131,6 +134,16 @@ void ParerUTestsUtils::WriteProtoToBinaryFile(const google::protobuf::Message &p
delete[] buf; delete[] buf;
} }


void ParerUTestsUtils::WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename) {
const int32_t fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 384U);
if (fd >= 0) {
google::protobuf::io::FileOutputStream output(fd);
google::protobuf::TextFormat::Print(proto, &output);
output.Close();
close(fd);
}
}

namespace ut { namespace ut {
NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format,
DataType data_type, std::vector<int64_t> shape) { DataType data_type, std::vector<int64_t> shape) {


+ 1
- 0
tests/ut/parser/parser_ut_utils.h View File

@@ -32,6 +32,7 @@ class ParerUTestsUtils {
static MemBuffer* MemBufferFromFile(const char *path); static MemBuffer* MemBufferFromFile(const char *path);
static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); static bool ReadProtoFromText(const char *file, google::protobuf::Message *message);
static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename);
static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename);
}; };


namespace ut { namespace ut {


+ 25
- 1
tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc View File

@@ -39,6 +39,7 @@
#include "graph/operator_reg.h" #include "graph/operator_reg.h"
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#include "parser/caffe/caffe_reshape_parser.h" #include "parser/caffe/caffe_reshape_parser.h"
#include "common/op_map.h"
#undef protected #undef protected
#undef private #undef private


@@ -162,7 +163,7 @@ static ge::NodePtr GenNodeFromOpDesc(ge::OpDescPtr opDesc){
void UtestCaffeParser::RegisterCustomOp() { void UtestCaffeParser::RegisterCustomOp() {
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) { for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data); domi::OpRegistry::Instance()->Register(reg_data);
} }
domi::OpRegistry::Instance()->registrationDatas.clear(); domi::OpRegistry::Instance()->registrationDatas.clear();
@@ -266,6 +267,29 @@ TEST_F(UtestCaffeParser, acl_caffe_parser) {
EXPECT_EQ(ret, GRAPH_FAILED); EXPECT_EQ(ret, GRAPH_FAILED);
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph);
EXPECT_EQ(ret, GRAPH_FAILED); EXPECT_EQ(ret, GRAPH_FAILED);

caffe_op_map.clear();
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_FAILED);

{
proto.set_name("empty_layer");
auto &layers = *proto.add_layers();
layers.set_name("layers");

proto.clear_layer();
const std::string empty_layer = case_dir + "/caffe_model/empty_layer.pbtxt";
ParerUTestsUtils::WriteProtoToTextFile(proto, empty_layer.c_str());
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layer.c_str(), weight_file.c_str(), parser_params, graph), FAILED);

proto.clear_layers();
const std::string empty_layers = case_dir + "/caffe_model/empty_layers.pbtxt";
ParerUTestsUtils::WriteProtoToTextFile(proto, empty_layers.c_str());
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layers.c_str(), weight_file.c_str(), parser_params, graph), FAILED);

unlink(empty_layer.c_str());
unlink(empty_layers.c_str());
}
} }


TEST_F(UtestCaffeParser, ParseFromMemory_success) TEST_F(UtestCaffeParser, ParseFromMemory_success)


+ 12
- 0
tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc View File

@@ -34,6 +34,7 @@
#include "parser/common/pass_manager.h" #include "parser/common/pass_manager.h"
#include "parser/common/tbe_plugin_loader.h" #include "parser/common/tbe_plugin_loader.h"
#include "parser/common/parser_fp16_t.h" #include "parser/common/parser_fp16_t.h"
#include "parser/common/pre_checker.h"
#undef protected #undef protected
#undef private #undef private


@@ -342,4 +343,15 @@ TEST_F(UtestAclGraphParser, test_operatoreq)
int8 = fp16; int8 = fp16;
} }


TEST_F(UtestAclGraphParser, test_pre_checker) {
PreChecker::Instance().fmk_op_types_ = nullptr;
const char* str = "iiii";
PreChecker::OpId id = str;
std::string type("ddd");
std::string name("lll");
Status ret = PreChecker::Instance().CheckTypeSupported(id, type, name, false);
EXPECT_EQ(ret, FAILED);
ret = PreChecker::Instance().CheckTypeSupported(id, type, name, true);
EXPECT_EQ(ret, FAILED);
}
} // namespace ge } // namespace ge

+ 2
- 1
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -24,6 +24,7 @@
#include "external/parser/onnx_parser.h" #include "external/parser/onnx_parser.h"
#include "ut/parser/parser_ut_utils.h" #include "ut/parser/parser_ut_utils.h"
#include "external/ge/ge_api_types.h" #include "external/ge/ge_api_types.h"
#include "framework/omg/parser/parser_factory.h"
#include "tests/depends/ops_stub/ops_stub.h" #include "tests/depends/ops_stub/ops_stub.h"


#define protected public #define protected public
@@ -103,7 +104,7 @@ void UtestOnnxParser::RegisterCustomOp() {
.ParseParamsFn(ParseParams); .ParseParamsFn(ParseParams);
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) { for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data); domi::OpRegistry::Instance()->Register(reg_data);
} }
domi::OpRegistry::Instance()->registrationDatas.clear(); domi::OpRegistry::Instance()->registrationDatas.clear();


+ 16
- 46
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -176,7 +176,7 @@ void UtestTensorflowParser::RegisterCustomOp() {
.ParseParamsFn(ParseParams); .ParseParamsFn(ParseParams);
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) { for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data); domi::OpRegistry::Instance()->Register(reg_data);
} }
domi::OpRegistry::Instance()->registrationDatas.clear(); domi::OpRegistry::Instance()->registrationDatas.clear();
@@ -599,7 +599,7 @@ namespace {
void register_tbe_op() { void register_tbe_op() {
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
for (OpRegistrationData reg_data : registrationDatas) { for (OpRegistrationData reg_data : registrationDatas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
OpRegistry::Instance()->Register(reg_data); OpRegistry::Instance()->Register(reg_data);
} }
OpRegistry::Instance()->registrationDatas.clear(); OpRegistry::Instance()->registrationDatas.clear();
@@ -1288,7 +1288,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserfrommemory_failed)
ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
ret = modelParser.ParseFromMemory(data, size, compute_graph); ret = modelParser.ParseFromMemory(data, size, compute_graph);
EXPECT_EQ(ret, INTERNAL_ERROR);
EXPECT_NE(ret, SUCCESS);
} }


TEST_F(UtestTensorflowParser, modelparser_parsefrommemory_success) TEST_F(UtestTensorflowParser, modelparser_parsefrommemory_success)
@@ -1419,7 +1419,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed)
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
TensorFlowModelParser tensorflow_parser; TensorFlowModelParser tensorflow_parser;
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(INTERNAL_ERROR, ret);
ASSERT_NE(ret, SUCCESS);
} }


TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes) TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes)
@@ -2082,6 +2082,7 @@ TEST_F(UtestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test)
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


op_dest->SetType(ge::parser::SHAPE); op_dest->SetType(ge::parser::SHAPE);
op_dest->AddOutputDesc(GeTensorDesc());
ret = autoMappingParser.ParseParams(node_def, op_dest); ret = autoMappingParser.ParseParams(node_def, op_dest);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }
@@ -2824,29 +2825,6 @@ TEST_F(UtestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test)
model_parser.UpdateEdgesControlInfo(info); model_parser.UpdateEdgesControlInfo(info);
} }


TEST_F(UtestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test)
{
TensorFlowModelParser model_parser;
NodeDef *node_def = new NodeDef();
node_def->set_name("Placeholder");
node_def->set_op("Placeholder_0");
std::map<string, NodeDef *> nodedef_map;
nodedef_map.emplace("Placeholder", node_def);
std::string curr_node_name = "Placeholder";
bool clear_input_flag = true;
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, INTERNAL_ERROR);

GraphDef graph;
curr_node_name = "pre_node_a";
nodedef_map.emplace("pre_node_a", node_def);
node_def->set_op("pre_node_a");
GenOriginContext(&model_parser, curr_node_name);
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, SUCCESS);
delete node_def;
}

TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test) TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test)
{ {
TensorFlowModelParser model_parser; TensorFlowModelParser model_parser;
@@ -3007,27 +2985,18 @@ TEST_F(UtestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test)
removed_inputs_vec.emplace_back("Add0"); removed_inputs_vec.emplace_back("Add0");
Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
}


TEST_F(UtestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test)
{
tensorflow::GraphDef graph_def;
TensorFlowModelParser tensorflow_parser;
tensorflow::NodeDef *node_def = initNodeDef();
node_def->set_name("post_node_d");
tensorflow::NodeDef *node_swith = initNodeDef();
node_swith->set_name("switch_op");
node_swith->set_op(parser::SWITCH);
all_node_map.emplace("switch_op", node_swith);
removed_inputs_vec.clear();
removed_inputs_vec.emplace_back("switch_op");
ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_swith, all_node_map, removed_inputs_vec);
EXPECT_EQ(ret, SUCCESS);
}


std::map<string, NodeDef *> nodedef_map;
nodedef_map.emplace("post_node_d", node_def);
nodedef_map.emplace("post_node_a", node_def);
nodedef_map.emplace("post_node_b", node_def);
std::vector<NodeDef *> nodedef_to_optimize;
nodedef_to_optimize.emplace_back(node_def);


std::string curr_node_name = "post_node_b";
GenOriginContext(&tensorflow_parser, curr_node_name);
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}
TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) {
std::string caseDir = __FILE__; std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/"); std::size_t idx = caseDir.find_last_of("/");
@@ -3696,7 +3665,8 @@ TEST_F(UtestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test)
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM);
mess2Op.ParseField(reflection, node_def, field, depth, ops); mess2Op.ParseField(reflection, node_def, field, depth, ops);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5);
delete field; delete field;
} }




Loading…
Cancel
Save