@@ -3,6 +3,11 @@ approvers: | |||||
- wqtshg | - wqtshg | ||||
- ljl0711 | - ljl0711 | ||||
- liu-jisheng | - liu-jisheng | ||||
- zhangfan_hq | |||||
- lipeiyang3699 | |||||
reviewers: | reviewers: | ||||
- xchu42 | - xchu42 | ||||
- sheng-nan | - sheng-nan | ||||
- tangqunzhang | |||||
- wangxiaotian22 | |||||
- stevenaw |
@@ -1 +1 @@ | |||||
Subproject commit 0a2335712484f85cd44a0f2402eac6932b22b40a | |||||
Subproject commit 8fb59a00c6291207f3491fee0c4064efff94d79f |
@@ -94,7 +94,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
const ge::ParserContext &ctx = GetParserContext(); | const ge::ParserContext &ctx = GetParserContext(); | ||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
string name = layer->name(); | string name = layer->name(); | ||||
auto search = input_dims.find(name); | |||||
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | |||||
if (search == input_dims.end()) { | if (search == input_dims.end()) { | ||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | ||||
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | ||||
@@ -139,7 +139,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
const ge::ParserContext &ctx = GetParserContext(); | const ge::ParserContext &ctx = GetParserContext(); | ||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
string name = layer->name(); | string name = layer->name(); | ||||
auto search = input_dims.find(name); | |||||
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | |||||
if (search == input_dims.end()) { | if (search == input_dims.end()) { | ||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | ||||
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | ||||
@@ -19,6 +19,7 @@ | |||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
#include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
#include "graph/def_types.h" | |||||
using namespace ge::parser; | using namespace ge::parser; | ||||
using domi::caffe::BlobProto; | using domi::caffe::BlobProto; | ||||
@@ -107,7 +108,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
buf[i] = proto.double_data(i); | buf[i] = proto.double_data(i); | ||||
} | } | ||||
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<float, uint8_t>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | GELOGW("SetData failed for GeTensor.");); // no need to return | ||||
} else if (proto.int8_data().length() > 0) { | } else if (proto.int8_data().length() > 0) { | ||||
if (size != static_cast<int>(proto.int8_data().length())) { | if (size != static_cast<int>(proto.int8_data().length())) { | ||||
@@ -121,7 +122,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
const char *data_ptr = proto.int8_data().data(); | const char *data_ptr = proto.int8_data().data(); | ||||
GE_CHECK_NOTNULL(data_ptr); | GE_CHECK_NOTNULL(data_ptr); | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS, | |||||
weight->SetData(PtrToPtr<const char, const uint8_t>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | GELOGW("SetData failed for GeTensor.");); // no need to return | ||||
dtype = ge::DT_INT8; | dtype = ge::DT_INT8; | ||||
} else if (proto.int32_data_size() > 0) { | } else if (proto.int32_data_size() > 0) { | ||||
@@ -139,7 +140,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
int32_weight_buf[i] = proto.int32_data(i); | int32_weight_buf[i] = proto.int32_data(i); | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
weight->SetData(reinterpret_cast<uint8_t *>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS, | |||||
weight->SetData(PtrToPtr<int32_t, uint8_t>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | GELOGW("SetData failed for GeTensor.");); // no need to return | ||||
dtype = ge::DT_INT32; | dtype = ge::DT_INT32; | ||||
} else if (proto.uint64_data_size() > 0) { | } else if (proto.uint64_data_size() > 0) { | ||||
@@ -156,7 +157,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
uint64_weight_buf[i] = proto.uint64_data(i); | uint64_weight_buf[i] = proto.uint64_data(i); | ||||
} | } | ||||
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(uint64_weight_buf.get()), size * sizeof(uint64_t)) != | |||||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<uint64_t, uint8_t>(uint64_weight_buf.get()), size * sizeof(uint64_t)) != | |||||
ge::GRAPH_SUCCESS, | ge::GRAPH_SUCCESS, | ||||
GELOGW("SetData failed for GeTensor.");); // no need to return | GELOGW("SetData failed for GeTensor.");); // no need to return | ||||
dtype = ge::DT_UINT64; | dtype = ge::DT_UINT64; | ||||
@@ -173,7 +174,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
const float *data_ptr = proto.data().data(); | const float *data_ptr = proto.data().data(); | ||||
GE_CHECK_NOTNULL(data_ptr); | GE_CHECK_NOTNULL(data_ptr); | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||||
weight->SetData(PtrToPtr<const float, const uint8_t>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | GELOGW("SetData failed for GeTensor.");); // no need to return | ||||
} | } | ||||
ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); | ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); | ||||
@@ -45,7 +45,6 @@ | |||||
#include "parser/caffe/caffe_custom_parser_adapter.h" | #include "parser/caffe/caffe_custom_parser_adapter.h" | ||||
#include "parser/caffe/caffe_op_parser.h" | #include "parser/caffe/caffe_op_parser.h" | ||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "parser/common/pre_checker.h" | |||||
#include "parser/common/prototype_pass_manager.h" | #include "parser/common/prototype_pass_manager.h" | ||||
#include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
#include "parser/common/model_saver.h" | #include "parser/common/model_saver.h" | ||||
@@ -61,13 +60,7 @@ using domi::caffe::InnerProductParameter; | |||||
using domi::caffe::LayerParameter; | using domi::caffe::LayerParameter; | ||||
using domi::caffe::NetParameter; | using domi::caffe::NetParameter; | ||||
using domi::ParseParamByOpFunc; | using domi::ParseParamByOpFunc; | ||||
using ge::caffe_op_map; | |||||
using ge::CaffeOpParser; | |||||
using ge::parser::ModelSaver; | using ge::parser::ModelSaver; | ||||
using ge::OpParser; | |||||
using ge::OpParserFactory; | |||||
using ge::Pb2Json; | |||||
using ge::PreChecker; | |||||
using std::ifstream; | using std::ifstream; | ||||
#define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \ | #define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \ | ||||
@@ -299,16 +292,17 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||||
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
int input_dim_size = proto_message.input_dim_size(); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((input_dim_size / proto_message.input_size() != parser::DIM_DEFAULT_SIZE || | |||||
input_dim_size % proto_message.input_size() != 0), | |||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E11003", {"input_dim_size", "input_size"}, | |||||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||||
return FAILED, | |||||
"[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||||
input_dim_size, proto_message.input_size()) | |||||
const int32_t input_dim_size = proto_message.input_dim_size(); | |||||
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||||
((input_dim_size % proto_message.input_size()) != 0)); | |||||
if (is_input_invalid) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||||
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||||
input_dim_size, proto_message.input_size()); | |||||
return FAILED; | |||||
} | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | for (int i = 0; i < proto_message.input_size(); i++) { | ||||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | domi::caffe::LayerParameter *layer = proto_message.add_layer(); | ||||
@@ -329,12 +323,14 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||||
input_data_flag = true; | input_data_flag = true; | ||||
} | } | ||||
} else if (proto_message.input_shape_size() > 0) { | } else if (proto_message.input_shape_size() > 0) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto_message.input_shape_size() != proto_message.input_size(), | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||||
{std::to_string(proto_message.input_shape_size()), | |||||
std::to_string(proto_message.input_size())}); | |||||
return FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||||
proto_message.input_shape_size(), proto_message.input_size()); | |||||
if (proto_message.input_shape_size() != proto_message.input_size()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||||
{std::to_string(proto_message.input_shape_size()), | |||||
std::to_string(proto_message.input_size())}); | |||||
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||||
proto_message.input_shape_size(), proto_message.input_size()); | |||||
return FAILED; | |||||
} | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | for (int i = 0; i < proto_message.input_size(); i++) { | ||||
int dim_size = proto_message.input_shape(i).dim_size(); | int dim_size = proto_message.input_shape(i).dim_size(); | ||||
@@ -755,7 +751,8 @@ Status CaffeModelParser::GetCustomOp(const domi::caffe::LayerParameter &layer, v | |||||
} | } | ||||
if (is_search_built_in_layer) { | if (is_search_built_in_layer) { | ||||
const google::protobuf::Message *layer_message = reinterpret_cast<const google::protobuf::Message *>(&layer); | |||||
const google::protobuf::Message *layer_message = PtrToPtr<const domi::caffe::LayerParameter, | |||||
const google::protobuf::Message>(&layer); | |||||
Status status = CreateCustomOperator(op_name, op_type, layer_message, 0, operators); | Status status = CreateCustomOperator(op_name, op_type, layer_message, 0, operators); | ||||
if (status != SUCCESS || operators.empty()) { | if (status != SUCCESS || operators.empty()) { | ||||
GELOGE(status, "[Create][CustomOperator] failed, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); | GELOGE(status, "[Create][CustomOperator] failed, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); | ||||
@@ -838,11 +835,11 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C | |||||
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE); | std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE); | ||||
GE_CHECK_NOTNULL(factory); | GE_CHECK_NOTNULL(factory); | ||||
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_parser == nullptr, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11009", {"opname", "optype"}, | |||||
{layer.name(), op_type}); | |||||
return FAILED, "op_parser is null, op_type: %s.", | |||||
op_type.c_str()); | |||||
if (op_parser == nullptr) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11009", {"opname", "optype"}, {layer.name(), op_type}); | |||||
GELOGE(FAILED, "op_parser is null, op_type: %s.", op_type.c_str()); | |||||
return FAILED; | |||||
} | |||||
ge::OpDescPtr op; | ge::OpDescPtr op; | ||||
// Process change of tensordesc initialization of opdesc, | // Process change of tensordesc initialization of opdesc, | ||||
@@ -994,7 +991,7 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const | |||||
GELOGI("op [%s], type[%s], update output(%d) with name %s %s", | GELOGI("op [%s], type[%s], update output(%d) with name %s %s", | ||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), | op_desc->GetName().c_str(), op_desc->GetType().c_str(), | ||||
i, op_desc->GetOutputNameByIndex(i).c_str(), | i, op_desc->GetOutputNameByIndex(i).c_str(), | ||||
ret == ge::GRAPH_SUCCESS ? "success" : "failed"); | |||||
ret == ge::GRAPH_SUCCESS ? "success" : "not success"); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -1025,7 +1022,8 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||||
// Find the layer for this output | // Find the layer for this output | ||||
auto top_node_iter = node_map.find(top_blob_layer_pair.first); | auto top_node_iter = node_map.find(top_blob_layer_pair.first); | ||||
// Find the layer for this input | // Find the layer for this input | ||||
auto bottom_node_iter = node_map.find(bottom_blob_layer_pair.first); | |||||
std::map<std::string, ge::NodePtr>::const_iterator bottom_node_iter = | |||||
node_map.find(bottom_blob_layer_pair.first); | |||||
if (top_node_iter != node_map.end() && bottom_node_iter != node_map.end()) { | if (top_node_iter != node_map.end() && bottom_node_iter != node_map.end()) { | ||||
// Output node top_node_iter->second, | // Output node top_node_iter->second, | ||||
// Output index top_blob_layer_pair.second | // Output index top_blob_layer_pair.second | ||||
@@ -1057,7 +1055,7 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||||
{top_blob_layer_pair.first}); | {top_blob_layer_pair.first}); | ||||
GELOGE(INTERNAL_ERROR, "[Find][TopLayer] %s failed.", top_blob_layer_pair.first.c_str()); | GELOGE(INTERNAL_ERROR, "[Find][TopLayer] %s failed.", top_blob_layer_pair.first.c_str()); | ||||
return ge::FAILED;) | return ge::FAILED;) | ||||
GE_IF_BOOL_EXEC(top_node_iter == node_map.end(), | |||||
GE_IF_BOOL_EXEC(bottom_node_iter == node_map.end(), | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11015", {"opname"}, | ErrorManager::GetInstance().ATCReportErrMessage("E11015", {"opname"}, | ||||
{bottom_blob_layer_pair.first}); | {bottom_blob_layer_pair.first}); | ||||
GELOGE(INTERNAL_ERROR, "[Find][BottomLayer] %s failed.", bottom_blob_layer_pair.first.c_str()); | GELOGE(INTERNAL_ERROR, "[Find][BottomLayer] %s failed.", bottom_blob_layer_pair.first.c_str()); | ||||
@@ -1095,7 +1093,7 @@ Status CaffeModelParser::AddUserOutNodesTop() { | |||||
const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes; | const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes; | ||||
int net_output_num = user_out_nodes.size(); | int net_output_num = user_out_nodes.size(); | ||||
for (const auto &out_pair : user_out_nodes) { | for (const auto &out_pair : user_out_nodes) { | ||||
auto layer_iter = layer_tops_map_.find(out_pair.first); | |||||
std::map<std::string, std::vector<std::string>>::const_iterator layer_iter = layer_tops_map_.find(out_pair.first); | |||||
GELOGI("Add to output, node name: %s", out_pair.first.c_str()); | GELOGI("Add to output, node name: %s", out_pair.first.c_str()); | ||||
if (layer_iter != layer_tops_map_.end()) { | if (layer_iter != layer_tops_map_.end()) { | ||||
if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) { | if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) { | ||||
@@ -1110,7 +1108,7 @@ Status CaffeModelParser::AddUserOutNodesTop() { | |||||
} | } | ||||
string top_name = layer_iter->second[out_pair.second]; | string top_name = layer_iter->second[out_pair.second]; | ||||
auto top_node_iter = node_map.find(out_pair.first); | |||||
std::map<std::string, ge::NodePtr>::const_iterator top_node_iter = node_map.find(out_pair.first); | |||||
if (top_node_iter != node_map.end()) { | if (top_node_iter != node_map.end()) { | ||||
ge::GetParserContext().out_tensor_names.push_back(top_name); | ge::GetParserContext().out_tensor_names.push_back(top_name); | ||||
GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); | GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); | ||||
@@ -1142,7 +1140,8 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes | |||||
top = RemapTopNameByLayer(layer, top, i); | top = RemapTopNameByLayer(layer, top, i); | ||||
} | } | ||||
auto t_iter = top_blobs_map_.find(top); | |||||
std::map<std::string, std::vector<std::pair<std::string, int32_t>>>::const_iterator t_iter = | |||||
top_blobs_map_.find(top); | |||||
GE_RETURN_WITH_LOG_IF_FALSE(t_iter != top_blobs_map_.end(), | GE_RETURN_WITH_LOG_IF_FALSE(t_iter != top_blobs_map_.end(), | ||||
"[Check][Param]Failed to find top: %s, layer name:%s", top.c_str(), | "[Check][Param]Failed to find top: %s, layer name:%s", top.c_str(), | ||||
@@ -1156,7 +1155,7 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes | |||||
// If not found, add to the output side of the output | // If not found, add to the output side of the output | ||||
// Find the layer for this output | // Find the layer for this output | ||||
auto top_node_iter = node_map.find(layer.name()); | |||||
std::map<std::string, ge::NodePtr>::const_iterator top_node_iter = node_map.find(layer.name()); | |||||
GELOGI("output in top_blob: %s", layer.name().c_str()); | GELOGI("output in top_blob: %s", layer.name().c_str()); | ||||
if (top_node_iter != node_map.end()) { | if (top_node_iter != node_map.end()) { | ||||
ge::GetParserContext().out_tensor_names.push_back(top_origin); | ge::GetParserContext().out_tensor_names.push_back(top_origin); | ||||
@@ -1226,12 +1225,14 @@ Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) { | |||||
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&layer, layer.name(), layer.type()), | GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&layer, layer.name(), layer.type()), | ||||
"[Invoke][AddOp]Add layer to PreChecker failed, layer name: %s.", | "[Invoke][AddOp]Add layer to PreChecker failed, layer name: %s.", | ||||
layer.name().c_str()); | layer.name().c_str()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&layer) != SUCCESS, return FAILED, | |||||
"[Invoke][CheckName]Check op[%s] failed, name repeat in caffe prototxt.", | |||||
layer.name().c_str()); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&layer) != SUCCESS, return FAILED, | |||||
"[Invoke][CheckType]Check op[%s]'s optype failed, type is not supported.", | |||||
layer.name().c_str()); | |||||
if (PreChecker::Instance().CheckName(&layer) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][CheckName]Check op[%s] failed, name repeat in caffe prototxt.", layer.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (PreChecker::Instance().CheckType(&layer) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][CheckType]Check op[%s]'s optype failed, type is not supported.", layer.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -1290,9 +1291,11 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||||
for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { | for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { | ||||
domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); | domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); | ||||
GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, | |||||
"[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||||
layer.name().c_str(), layer.type().c_str()); | |||||
if (!CheckValidLayer(layer)) { | |||||
GELOGI("[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||||
layer.name().c_str(), layer.type().c_str()); | |||||
continue; | |||||
} | |||||
CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; | CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; | ||||
REPORT_INNER_ERROR("E19999", "net %s has input and data layer simultaneously, check invalid." | REPORT_INNER_ERROR("E19999", "net %s has input and data layer simultaneously, check invalid." | ||||
@@ -1392,7 +1395,7 @@ void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer) | |||||
for (auto top : layer.top()) { | for (auto top : layer.top()) { | ||||
tops.push_back(top); | tops.push_back(top); | ||||
} | } | ||||
auto it = layer_tops_map_.find(name); | |||||
std::map<std::string, std::vector<std::string>>::const_iterator it = layer_tops_map_.find(name); | |||||
if (it == layer_tops_map_.end()) { | if (it == layer_tops_map_.end()) { | ||||
layer_tops_map_[name] = tops; | layer_tops_map_[name] = tops; | ||||
} | } | ||||
@@ -1431,11 +1434,23 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeModelParser::ReportLayerInvalid(const domi::caffe::NetParameter &proto, const std::string &path) const { | |||||
if (proto.layers_size() > 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11021", {"realpath"}, {path}); | |||||
GELOGE(FAILED, "[Check][Size]The model file[%s] is consisted of layers-structure which is deprecated in Caffe " | |||||
"and unsupported in ATC. The \"layers\" should be changed to \"layer\".", path.c_str()); | |||||
} else { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11022"); | |||||
GELOGE(FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); | |||||
} | |||||
return FAILED; | |||||
} | |||||
Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) { | Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) { | ||||
bool has_error = false; | bool has_error = false; | ||||
GE_CHECK_NOTNULL(model_path); | GE_CHECK_NOTNULL(model_path); | ||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
GELOGI("Caffe Parse model file %s", model_path); | |||||
GELOGI("Caffe Parse model file [%s]", model_path); | |||||
PreChecker::Instance().Clear(); | PreChecker::Instance().Clear(); | ||||
@@ -1450,22 +1465,12 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
// parse network model by custom proto and get custom operators | // parse network model by custom proto and get custom operators | ||||
string custom_proto_path = ge::GetParserContext().custom_proto_path + "custom.proto"; | string custom_proto_path = ge::GetParserContext().custom_proto_path + "custom.proto"; | ||||
string caffe_proto_path = ge::GetParserContext().caffe_proto_path + "caffe.proto"; | string caffe_proto_path = ge::GetParserContext().caffe_proto_path + "caffe.proto"; | ||||
Status result = CustomProtoParse(model_path, custom_proto_path, caffe_proto_path, custom_operator_); | |||||
if (result != SUCCESS) { | |||||
GELOGE(FAILED, "[Parse][Model] by custom proto failed, model path: %s.", model_path); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS(CustomProtoParse(model_path, custom_proto_path, caffe_proto_path, custom_operator_), | |||||
"[Parse][Model] by custom proto failed, model path: %s.", model_path); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
proto_message.layer_size() == 0 && proto_message.layers_size() > 0, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11021", {"realpath"}, {model_path}); | |||||
return FAILED, | |||||
"[Check][Size]The model file[%s] is consisted of layers-structure which is deprecated in Caffe " | |||||
"and unsupported in ATC. The \"layers\" should be changed to \"layer\".", | |||||
model_path); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto_message.layer_size() == 0), | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11022"); | |||||
return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); | |||||
if (proto_message.layer_size() == 0) { | |||||
return ReportLayerInvalid(proto_message, model_path); | |||||
} | |||||
GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), | GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), | ||||
"Run ProtoType Pass Failed"); | "Run ProtoType Pass Failed"); | ||||
@@ -1476,8 +1481,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
GE_RETURN_IF_ERROR(PreCheck(proto_message)); | GE_RETURN_IF_ERROR(PreCheck(proto_message)); | ||||
if (PreChecker::Instance().HasError()) { | if (PreChecker::Instance().HasError()) { | ||||
REPORT_INNER_ERROR("E19999", "Precheck failed. Please read check report."); | |||||
GELOGE(INTERNAL_ERROR, "[Has][Error]Precheck failed. Please read check report."); | |||||
REPORT_INNER_ERROR("E19999", "Precheck failed. a report of json format will be create, Please read it."); | |||||
GELOGE(INTERNAL_ERROR, "[Has][Error]Precheck failed. a report of json format will be create, Please read it."); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1512,9 +1517,11 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { | for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { | ||||
domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); | domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); | ||||
SaveOrigionLayerTops(layer); | SaveOrigionLayerTops(layer); | ||||
GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, | |||||
"[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||||
layer.name().c_str(), layer.type().c_str()); | |||||
if (!CheckValidLayer(layer)) { | |||||
GELOGI("[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||||
layer.name().c_str(), layer.type().c_str()); | |||||
continue; | |||||
} | |||||
CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; | CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; | ||||
GELOGE(FAILED, "[Check][Layer]net %s has input and data layer simultaneously, check invalid." | GELOGE(FAILED, "[Check][Layer]net %s has input and data layer simultaneously, check invalid." | ||||
@@ -1679,7 +1686,7 @@ Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge:: | |||||
// Resolve proto file to netparameter | // Resolve proto file to netparameter | ||||
NetParameter proto; | NetParameter proto; | ||||
bool success = ge::parser::ReadProtoFromArray(reinterpret_cast<const char *>(data), static_cast<int>(size), &proto); | |||||
bool success = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &proto); | |||||
if (!success) { | if (!success) { | ||||
REPORT_CALL_ERROR("E19999", "ReadProtoFromArray failed."); | REPORT_CALL_ERROR("E19999", "ReadProtoFromArray failed."); | ||||
GELOGE(domi::PARSE_WEIGHTS_FAILED, "[Read][Proto] from Memory fail"); | GELOGE(domi::PARSE_WEIGHTS_FAILED, "[Read][Proto] from Memory fail"); | ||||
@@ -1920,7 +1927,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r | |||||
const google::protobuf::FieldDescriptor *field, | const google::protobuf::FieldDescriptor *field, | ||||
google::protobuf::Message *layer) { | google::protobuf::Message *layer) { | ||||
GELOGD("Start to parse field: %s.", field->name().c_str()); | GELOGD("Start to parse field: %s.", field->name().c_str()); | ||||
domi::caffe::LayerParameter *layer_proto = reinterpret_cast<domi::caffe::LayerParameter *>(layer); | |||||
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | |||||
string filed_name = field->name(); | string filed_name = field->name(); | ||||
#define CASE_FIELD_NAME(kName, method) \ | #define CASE_FIELD_NAME(kName, method) \ | ||||
if (filed_name == kField##kName) { \ | if (filed_name == kField##kName) { \ | ||||
@@ -1975,8 +1982,7 @@ Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *me | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
blobs_reflection->ListFields(*message, &field_desc); | blobs_reflection->ListFields(*message, &field_desc); | ||||
domi::caffe::BlobProto *blobs_proto = reinterpret_cast<domi::caffe::BlobProto *>(blobs); | |||||
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); | |||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
GE_CHECK_NOTNULL(field); | GE_CHECK_NOTNULL(field); | ||||
@@ -2025,7 +2031,7 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message | |||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
reflection->ListFields(*message, &field_desc); | reflection->ListFields(*message, &field_desc); | ||||
domi::caffe::BlobShape *shape_proto = reinterpret_cast<domi::caffe::BlobShape *>(dest_message); | |||||
domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); | |||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
if (field->name() != kFieldDim) { | if (field->name() != kFieldDim) { | ||||
@@ -2048,7 +2054,7 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message | |||||
reflection->ListFields(*message, &field_desc); | reflection->ListFields(*message, &field_desc); | ||||
domi::caffe::ConvolutionParameter *conv_param_proto = | domi::caffe::ConvolutionParameter *conv_param_proto = | ||||
reinterpret_cast<domi::caffe::ConvolutionParameter *>(dest_message); | |||||
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); | |||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
if (field->name() != kFieldBiasTerm) { | if (field->name() != kFieldBiasTerm) { | ||||
@@ -2068,7 +2074,7 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess | |||||
reflection->ListFields(*message, &field_desc); | reflection->ListFields(*message, &field_desc); | ||||
domi::caffe::InnerProductParameter *inner_product_proto = | domi::caffe::InnerProductParameter *inner_product_proto = | ||||
reinterpret_cast<domi::caffe::InnerProductParameter *>(dest_message); | |||||
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); | |||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
if (field->name() != kFieldBiasTerm) { | if (field->name() != kFieldBiasTerm) { | ||||
@@ -2110,22 +2116,25 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess | |||||
} | } | ||||
} | } | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||||
return FAILED, | |||||
"[Check][Param]The weight file is consisted of layers-structure which is deprecated " | |||||
"in Caffe and unsupported in ATC. The \"layers\" should be changed to \"layer\"."); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||||
return FAILED, | |||||
"[Check][Param] Weight layer num is zero, weight file may be invalid."); | |||||
if (num_layer == 0 && num_layers > 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||||
GELOGE(FAILED, "[Check][Param]The weight file is consisted of layers-structure which is deprecated " | |||||
"in Caffe and unsupported in ATC. The \"layers\" should be changed to \"layer\"."); | |||||
return FAILED; | |||||
} | |||||
if (num_layer == 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||||
GELOGE(FAILED, "[Check][Param] Weight layer num is zero, weight file may be invalid."); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message *layer_message, | Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message *layer_message, | ||||
ge::ComputeGraphPtr &graph) { | ge::ComputeGraphPtr &graph) { | ||||
vector<string> need_share_layers; | vector<string> need_share_layers; | ||||
const domi::caffe::LayerParameter *layer = reinterpret_cast<const domi::caffe::LayerParameter *>(layer_message); | |||||
const domi::caffe::LayerParameter *layer = | |||||
PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer_message); | |||||
const string &shared_layer_name = layer->name(); | const string &shared_layer_name = layer->name(); | ||||
const string &layer_type = layer->type(); | const string &layer_type = layer->type(); | ||||
for (auto p_iter = params_share_map.begin(); p_iter != params_share_map.end(); ++p_iter) { | for (auto p_iter = params_share_map.begin(); p_iter != params_share_map.end(); ++p_iter) { | ||||
@@ -2159,7 +2168,7 @@ Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message | |||||
} | } | ||||
// The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. | // The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. | ||||
auto iter = caffe_op_map.find(layer_type); | |||||
std::map<std::string, std::string>::const_iterator iter = caffe_op_map.find(layer_type); | |||||
if (iter == caffe_op_map.end()) { | if (iter == caffe_op_map.end()) { | ||||
GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer_type.c_str(), layer_name.c_str()); | GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer_type.c_str(), layer_name.c_str()); | ||||
continue; | continue; | ||||
@@ -2172,20 +2181,20 @@ Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message | |||||
GE_CHECK_NOTNULL(factory); | GE_CHECK_NOTNULL(factory); | ||||
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(op_parser.get() == nullptr), | |||||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||||
std::vector<std::string>({layer_name, op_type})); | |||||
return FAILED, | |||||
"[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||||
if (op_parser.get() == nullptr) { | |||||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||||
std::vector<std::string>({layer_name, op_type})); | |||||
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||||
return FAILED; | |||||
} | |||||
// Parsing weight information through op parser | // Parsing weight information through op parser | ||||
Status status = op_parser->ParseWeights(layer_message, node); | Status status = op_parser->ParseWeights(layer_message, node); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(status != SUCCESS), | |||||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||||
return status, | |||||
"[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||||
if (status != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||||
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||||
return status; | |||||
} | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -2233,13 +2242,18 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||||
// Operator name and occurrence map, handle duplicate operators | // Operator name and occurrence map, handle duplicate operators | ||||
std::map<std::string, int32_t> layer_name_map; | std::map<std::string, int32_t> layer_name_map; | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||||
return FAILED, "[Check][Param] The weight file is consisted of layers-structure " | |||||
"which is deprecated in Caffe and unsupported in ATC. " | |||||
"The \"layers\" should be changed to \"layer\"."); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||||
return FAILED, "weight layer num is zero, weight file may be invalid."); | |||||
if (num_layer == 0 && num_layers > 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||||
GELOGE(FAILED, "[Check][Param] The weight file is consisted of layers-structure " | |||||
"which is deprecated in Caffe and unsupported in ATC. " | |||||
"The \"layers\" should be changed to \"layer\"."); | |||||
return FAILED; | |||||
} | |||||
if (num_layer == 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||||
GELOGE(FAILED, "weight layer num is zero, weight file may be invalid."); | |||||
return FAILED; | |||||
} | |||||
for (int i = 0; i < num_layer; ++i) { | for (int i = 0; i < num_layer; ++i) { | ||||
const LayerParameter &layer = param.layer(i); | const LayerParameter &layer = param.layer(i); | ||||
@@ -2285,7 +2299,7 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||||
} | } | ||||
// The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. | // The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. | ||||
auto iter = caffe_op_map.find(layer.type()); | |||||
std::map<std::string, std::string>::const_iterator iter = caffe_op_map.find(layer.type()); | |||||
if (iter == caffe_op_map.end()) { | if (iter == caffe_op_map.end()) { | ||||
GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer.type().c_str(), layer_name.c_str()); | GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer.type().c_str(), layer_name.c_str()); | ||||
continue; | continue; | ||||
@@ -2298,18 +2312,20 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||||
GE_CHECK_NOTNULL(factory); | GE_CHECK_NOTNULL(factory); | ||||
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(op_parser.get() == nullptr), | |||||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||||
std::vector<std::string>({layer_name, op_type})); | |||||
return FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||||
if (op_parser.get() == nullptr) { | |||||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||||
std::vector<std::string>({layer_name, op_type})); | |||||
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||||
return FAILED; | |||||
} | |||||
// Parsing weight information through op parser | // Parsing weight information through op parser | ||||
Status status = op_parser->ParseWeights(&layer, node); | Status status = op_parser->ParseWeights(&layer, node); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(status != SUCCESS), | |||||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||||
return status, "[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||||
if (status != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||||
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||||
return status; | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -40,6 +40,7 @@ | |||||
#include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
#include "omg/parser/model_parser.h" | #include "omg/parser/model_parser.h" | ||||
#include "omg/parser/weights_parser.h" | #include "omg/parser/weights_parser.h" | ||||
#include "common/pre_checker.h" | |||||
#include "proto/caffe/caffe.pb.h" | #include "proto/caffe/caffe.pb.h" | ||||
#include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
@@ -123,6 +124,17 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
bool HasError() override { | |||||
return PreChecker::Instance().HasError(); | |||||
} | |||||
Status Save(const string &file) override { | |||||
return PreChecker::Instance().Save(file); | |||||
} | |||||
void Clear() override { | |||||
PreChecker::Instance().Clear(); | |||||
} | |||||
private: | private: | ||||
Status Parse(const char *model_path, ge::ComputeGraphPtr &graph); | Status Parse(const char *model_path, ge::ComputeGraphPtr &graph); | ||||
@@ -313,6 +325,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer); | Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer); | ||||
Status ReportLayerInvalid(const domi::caffe::NetParameter &proto, const std::string &path) const; | |||||
std::map<std::string, ge::NodePtr> node_map; | std::map<std::string, ge::NodePtr> node_map; | ||||
// key: blob name, value: layer name and index | // key: blob name, value: layer name and index | ||||
@@ -344,6 +358,18 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | ||||
bool HasError() override { | |||||
return PreChecker::Instance().HasError(); | |||||
} | |||||
Status Save(const string &file) override { | |||||
return PreChecker::Instance().Save(file); | |||||
} | |||||
void Clear() override { | |||||
PreChecker::Instance().Clear(); | |||||
} | |||||
private: | private: | ||||
Status CheckNodes(ge::ComputeGraphPtr &graph); | Status CheckNodes(ge::ComputeGraphPtr &graph); | ||||
/** | /** | ||||
@@ -128,7 +128,7 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) { | |||||
data[i] = attr_shape[i]; | data[i] = attr_shape[i]; | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
constTensor->SetData(reinterpret_cast<uint8_t *>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS, | |||||
constTensor->SetData(PtrToPtr<int64_t, uint8_t>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | GELOGW("SetData failed for GeTensor.");); // no need to return | ||||
// construct const node and add edge | // construct const node and add edge | ||||
@@ -492,7 +492,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||||
} | } | ||||
domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | ||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||||
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | ||||
if (!default_out_nodes.empty()) { | if (!default_out_nodes.empty()) { | ||||
for (size_t i = 0; i < default_out_nodes.size(); ++i) { | for (size_t i = 0; i < default_out_nodes.size(); ++i) { | ||||
@@ -613,24 +613,27 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||||
ge::GetParserContext().out_tensor_names.clear(); | ge::GetParserContext().out_tensor_names.clear(); | ||||
ge::GetParserContext().data_tensor_names.clear(); | ge::GetParserContext().data_tensor_names.clear(); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, | |||||
return PARAM_INVALID, "[Check][Options] Parse paragrams invalid, graph:%s.", | |||||
graph_name.c_str()); | |||||
if (CheckOptions(parser_params) != SUCCESS) { | |||||
GELOGE(FAILED, "[Check][Options] Parse paragrams invalid, graph:%s.", graph_name.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
// support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes | // support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes | ||||
SetDefaultFormat(); | SetDefaultFormat(); | ||||
string out_nodes; | string out_nodes; | ||||
GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); | GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputNodes(out_nodes) != SUCCESS, | |||||
return PARAM_INVALID, | |||||
"[Invoke][ParseAclOutputNodes] Parse out_nodes failed, graph:%s.", graph_name.c_str()); | |||||
if (ParseAclOutputNodes(out_nodes) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ParseAclOutputNodes] Parse out_nodes failed, graph:%s.", graph_name.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
string is_output_adjust_hw_layout; | string is_output_adjust_hw_layout; | ||||
GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout); | GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS, | |||||
return PARAM_INVALID, | |||||
"[Invoke][ParseAclOutputFp16NodesFormat] Parse is_output_adjust_hw_layout failed, " | |||||
"graph:%s.", graph_name.c_str()); | |||||
if (ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ParseAclOutputFp16NodesFormat] Parse is_output_adjust_hw_layout failed, graph:%s.", | |||||
graph_name.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
string tmp_name; | string tmp_name; | ||||
GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); | GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); | ||||
@@ -638,10 +641,11 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||||
string enable_scope_fusion_passes; | string enable_scope_fusion_passes; | ||||
GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes); | GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS, | |||||
return PARAM_INVALID, | |||||
"[Invoke][ParseAclEnableScope] Parse enable_scope_fusion_passes failed, graph:%s.", | |||||
graph_name.c_str()); | |||||
if (ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ParseAclEnableScope] Parse enable_scope_fusion_passes failed, graph:%s.", | |||||
graph_name.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -657,10 +661,11 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||||
string is_input_adjust_hw_layout; | string is_input_adjust_hw_layout; | ||||
GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout); | GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, | |||||
return PARAM_INVALID, "[Invoke][ParseAclInputFp16Nodes] Parse input_fp16_nodes failed, graph:%s", | |||||
compute_graph->GetName().c_str()); | |||||
if (ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ParseAclInputFp16Nodes] Parse input_fp16_nodes failed, graph:%s", | |||||
compute_graph->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -689,30 +694,35 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char | |||||
// Get file length | // Get file length | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), | |||||
REPORT_INNER_ERROR("E19999", "input_file path is null, check invalid."); | |||||
return -1, "[Check][Param] input_file path is null."); | |||||
if (input_file.empty()) { | |||||
REPORT_INNER_ERROR("E19999", "input_file path is null, check invalid."); | |||||
GELOGE(FAILED, "[Check][Param] input_file path is null."); | |||||
return -1; | |||||
} | |||||
std::string real_path = RealPath(input_file.c_str()); | std::string real_path = RealPath(input_file.c_str()); | ||||
char_t err_buf[kMaxErrStrLen + 1U] = {}; | char_t err_buf[kMaxErrStrLen + 1U] = {}; | ||||
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
REPORT_INPUT_ERROR("E19000", std::vector<std::string>({"path", "errmsg"}), | |||||
std::vector<std::string>({real_path, err_msg})); | |||||
return -1, "[Get][Path] input_file path '%s' not valid", input_file.c_str()); | |||||
if (real_path.empty()) { | |||||
REPORT_INPUT_ERROR("E19000", std::vector<std::string>({"path", "errmsg"}), | |||||
std::vector<std::string>({real_path, err_msg})); | |||||
GELOGE(FAILED, "[Get][Path] input_file path '%s' not valid", input_file.c_str()); | |||||
return -1; | |||||
} | |||||
unsigned long long file_length = 0; | unsigned long long file_length = 0; | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, | |||||
{input_file, err_msg}); | |||||
return -1, "[Open][File] [%s] failed. %s", input_file.c_str(), err_msg); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0 || file_length > kMaxFileSizeLimit), | |||||
REPORT_INPUT_ERROR( | |||||
"E19015", std::vector<std::string>({"file", "size", "maxsize"}), | |||||
std::vector<std::string>({input_file, std::to_string(file_length), | |||||
std::to_string(kMaxFileSizeLimit)})); | |||||
return -1, "[Check][Param] File[%s] size %lld is out of range(0,%d).", | |||||
input_file.c_str(), file_length, kMaxFileSizeLimit); | |||||
if (mmGetFileSize(input_file.c_str(), &file_length) != EN_OK) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, err_msg}); | |||||
GELOGE(FAILED, "[Open][File] [%s] failed. %s", input_file.c_str(), err_msg); | |||||
return -1; | |||||
} | |||||
if ((file_length == 0) || (file_length > kMaxFileSizeLimit)) { | |||||
REPORT_INPUT_ERROR("E19015", std::vector<std::string>({ "file", "size", "maxsize" }), | |||||
std::vector<std::string>({ input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit) })); | |||||
GELOGE(FAILED, "[Check][Param] File[%s] size %lld is out of range(0,%d).", | |||||
input_file.c_str(), file_length, kMaxFileSizeLimit); | |||||
return -1; | |||||
} | |||||
return static_cast<long>(file_length); | return static_cast<long>(file_length); | ||||
} | } | ||||
@@ -725,9 +735,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() | |||||
} | } | ||||
static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { | static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, | |||||
REPORT_INNER_ERROR("E19999", "param proto is nullptr, check invalid"); | |||||
return false, "[Check][Param] incorrect parameter. nullptr == proto"); | |||||
if (proto == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "param proto is nullptr, check invalid"); | |||||
GELOGE(FAILED, "[Check][Param] incorrect parameter. nullptr == proto"); | |||||
return false; | |||||
} | |||||
coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit); | coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit); | ||||
return proto->ParseFromCodedStream(&coded_stream); | return proto->ParseFromCodedStream(&coded_stream); | ||||
@@ -743,17 +755,23 @@ static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Messag | |||||
*/ | */ | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, | ||||
int &length) { | int &length) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), | |||||
REPORT_INNER_ERROR("E19999", "param file_name is nullptr, check invalid"); | |||||
return false, "[Check][Param] incorrect parameter. file is nullptr"); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), | |||||
REPORT_INNER_ERROR("E19999", "param buffer is nullptr, check invalid"); | |||||
return false, "[Check][Param] incorrect parameter. buffer is nullptr"); | |||||
if (file_name == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "param file_name is nullptr, check invalid"); | |||||
GELOGE(FAILED, "[Check][Param] incorrect parameter. file is nullptr"); | |||||
return false; | |||||
} | |||||
if (buffer == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "param buffer is nullptr, check invalid"); | |||||
GELOGE(FAILED, "[Check][Param] incorrect parameter. buffer is nullptr"); | |||||
return false; | |||||
} | |||||
std::string real_path = RealPath(file_name); | std::string real_path = RealPath(file_name); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file_name); | |||||
return false, "[Check][Param]file path '%s' not valid, realpath failed", file_name); | |||||
if (real_path.empty()) { | |||||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file_name); | |||||
GELOGE(FAILED, "[Check][Param]file path '%s' not valid, realpath failed", file_name); | |||||
return false; | |||||
} | |||||
std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); | std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); | ||||
if (!file.is_open()) { | if (!file.is_open()) { | ||||
@@ -763,16 +781,22 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co | |||||
} | } | ||||
length = static_cast<int>(file.tellg()); | length = static_cast<int>(file.tellg()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); REPORT_INNER_ERROR("E19999", "file length <= 0"); | |||||
return false, "[Check][Param] file length <= 0"); | |||||
if ((length <= 0)) { | |||||
file.close(); | |||||
REPORT_INNER_ERROR("E19999", "file length <= 0"); | |||||
GELOGE(FAILED, "[Check][Param] file length <= 0"); | |||||
return false; | |||||
} | |||||
file.seekg(0, std::ios::beg); | file.seekg(0, std::ios::beg); | ||||
*buffer = new(std::nothrow) char[length](); | *buffer = new(std::nothrow) char[length](); | ||||
GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(); | |||||
REPORT_CALL_ERROR("E19999", "new an object failed."), | |||||
"[Create][Buffer] new an object failed."); | |||||
if (*buffer == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "[Create][Buffer] new an object failed, length=%d.", length); | |||||
GELOGE(FAILED, "[Create][Buffer] new an object failed, length=%d.", length); | |||||
file.close(); | |||||
return false; | |||||
} | |||||
file.read(*buffer, length); | file.read(*buffer, length); | ||||
file.close(); | file.close(); | ||||
@@ -780,16 +804,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), | |||||
REPORT_INNER_ERROR("E19999", "param file or proto is nullptr, check invalid"); | |||||
return false, "[Check][Param] Input parameter file or proto is nullptr!"); | |||||
if ((file == nullptr) || (proto == nullptr)) { | |||||
REPORT_INNER_ERROR("E19999", "param file or proto is nullptr, check invalid"); | |||||
GELOGE(FAILED, "[Check][Param] Input parameter file or proto is nullptr!"); | |||||
return false; | |||||
} | |||||
std::string real_path = RealPath(file); | std::string real_path = RealPath(file); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file); | |||||
return false, "[Check][Param]pb file path '%s' not valid, realpath failed", file); | |||||
if (real_path.empty()) { | |||||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file); | |||||
GELOGE(FAILED, "[Check][Param]pb file path '%s' not valid, realpath failed", file); | |||||
return false; | |||||
} | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "[Get][FileLength]file size not valid."); | |||||
if (GetFileLength(real_path) == -1) { | |||||
GELOGE(FAILED, "[Get][FileLength]file size not valid."); | |||||
return false; | |||||
} | |||||
std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | ||||
if (!fs.is_open()) { | if (!fs.is_open()) { | ||||
@@ -815,32 +846,37 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), | |||||
REPORT_INNER_ERROR("E19999", "param proto or data is nullptr " | |||||
"or size is 0, check invalid"); return false, | |||||
"[Check][Param]incorrect parameter. proto is nullptr || data is nullptr || size is 0"); | |||||
if ((proto == nullptr) || (data == nullptr) || (size == 0)) { | |||||
REPORT_INNER_ERROR("E19999", "param proto or data is nullptr or size is 0, check invalid"); | |||||
GELOGE(FAILED, "[Check][Param]incorrect parameter. proto is nullptr || data is nullptr || size is 0"); | |||||
return false; | |||||
} | |||||
google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size); | |||||
google::protobuf::io::CodedInputStream coded_stream(PtrToPtr<void, uint8_t>(const_cast<void *>(data)), size); | |||||
return ReadProtoFromCodedInputStream(coded_stream, proto); | return ReadProtoFromCodedInputStream(coded_stream, proto); | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, | ||||
google::protobuf::Message *message) { | google::protobuf::Message *message) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), | |||||
REPORT_INNER_ERROR("E19999", "param file or message is nullptr, check invalid"); | |||||
return false, | |||||
"[Check][Param]incorrect parameter. nullptr == file || nullptr == message"); | |||||
if ((file == nullptr) || (message == nullptr)) { | |||||
REPORT_INNER_ERROR("E19999", "param file or message is nullptr, check invalid"); | |||||
GELOGE(FAILED, "[Check][Param]incorrect parameter. nullptr == file || nullptr == message"); | |||||
return false; | |||||
} | |||||
std::string real_path = RealPath(file); | std::string real_path = RealPath(file); | ||||
char_t err_buf[kMaxErrStrLen + 1U] = {}; | char_t err_buf[kMaxErrStrLen + 1U] = {}; | ||||
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, | |||||
{file, err_msg}); | |||||
return false, "[Check][Param]Path[%s]'s realpath is empty, errmsg[%s]", file, | |||||
err_msg); | |||||
if (real_path.empty()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file, err_msg}); | |||||
GELOGE(FAILED, "[Check][Param]Path[%s]'s realpath is empty, errmsg[%s]", file, err_msg); | |||||
return false; | |||||
} | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "[Check][Param] file size not valid."); | |||||
if (GetFileLength(real_path) == -1) { | |||||
GELOGE(FAILED, "[Check][Param] file size not valid."); | |||||
return false; | |||||
} | |||||
std::ifstream fs(real_path.c_str(), std::ifstream::in); | std::ifstream fs(real_path.c_str(), std::ifstream::in); | ||||
@@ -863,10 +899,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, | ||||
google::protobuf::Message *message) { | google::protobuf::Message *message) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), | |||||
REPORT_INNER_ERROR("E19999", "param data or message is nullptr,check invalid"); | |||||
return false, | |||||
"[Check][Param] incorrect parameter. data is nullptr || message is nullptr"); | |||||
if ((data == nullptr) || (message == nullptr)) { | |||||
REPORT_INNER_ERROR("E19999", "param data or message is nullptr,check invalid"); | |||||
GELOGE(FAILED, "[Check][Param] incorrect parameter. data is nullptr || message is nullptr"); | |||||
return false; | |||||
} | |||||
std::string str(data, static_cast<size_t>(size)); | std::string str(data, static_cast<size_t>(size)); | ||||
std::istringstream fs(str); | std::istringstream fs(str); | ||||
@@ -901,7 +938,7 @@ Status GetOriginalType(const ge::NodePtr &node, string &type) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) { | |||||
FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &filePath, const std::string &mode) { | |||||
char ebuff[kMaxBuffSize]; | char ebuff[kMaxBuffSize]; | ||||
regex_t reg; | regex_t reg; | ||||
int cflags = REG_EXTENDED | REG_NOSUB; | int cflags = REG_EXTENDED | REG_NOSUB; | ||||
@@ -913,7 +950,7 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::str | |||||
return true; | return true; | ||||
} | } | ||||
ret = regexec(®, str.c_str(), 0, nullptr, 0); | |||||
ret = regexec(®, filePath.c_str(), 0, nullptr, 0); | |||||
if (ret) { | if (ret) { | ||||
regerror(ret, ®, ebuff, kMaxBuffSize); | regerror(ret, ®, ebuff, kMaxBuffSize); | ||||
GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff); | GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff); | ||||
@@ -21,11 +21,11 @@ | |||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/debug/ge_util.h" | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/util.h" | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
@@ -31,11 +31,17 @@ using std::string; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const int kSignificantDigits = 10; | const int kSignificantDigits = 10; | ||||
const int kMaxParseDepth = 20; | |||||
} | } | ||||
// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields | // JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, | ||||
const set<string> &black_fields, Json &json, | const set<string> &black_fields, Json &json, | ||||
bool enum2str) { | |||||
bool enum2str, int depth) { | |||||
if (depth > kMaxParseDepth) { | |||||
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); | |||||
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); | |||||
return; | |||||
} | |||||
auto descriptor = message.GetDescriptor(); | auto descriptor = message.GetDescriptor(); | ||||
auto reflection = message.GetReflection(); | auto reflection = message.GetReflection(); | ||||
if (descriptor == nullptr || reflection == nullptr) { | if (descriptor == nullptr || reflection == nullptr) { | ||||
@@ -57,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons | |||||
if (field->is_repeated()) { | if (field->is_repeated()) { | ||||
if (reflection->FieldSize(message, field) > 0) { | if (reflection->FieldSize(message, field) > 0) { | ||||
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); | |||||
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str, depth); | |||||
} | } | ||||
continue; | continue; | ||||
} | } | ||||
@@ -66,18 +72,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons | |||||
continue; | continue; | ||||
} | } | ||||
OneField2Json(message, field, reflection, black_fields, json, enum2str); | |||||
OneField2Json(message, field, reflection, black_fields, json, enum2str, depth); | |||||
} | } | ||||
} | } | ||||
void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | ||||
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | ||||
bool enum2str) { | |||||
bool enum2str, int depth) { | |||||
switch (field->type()) { | switch (field->type()) { | ||||
case ProtobufFieldDescriptor::TYPE_MESSAGE: { | case ProtobufFieldDescriptor::TYPE_MESSAGE: { | ||||
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | ||||
if (0UL != tmp_message.ByteSizeLong()) { | if (0UL != tmp_message.ByteSizeLong()) { | ||||
Message2Json(tmp_message, black_fields, json[field->name()], enum2str); | |||||
Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1); | |||||
} | } | ||||
break; | break; | ||||
} | } | ||||
@@ -163,9 +169,9 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||||
void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | ||||
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | ||||
bool enum2str) { | |||||
bool enum2str, int depth) { | |||||
if ((field == nullptr) || (reflection == nullptr)) { | if ((field == nullptr) || (reflection == nullptr)) { | ||||
Message2Json(message, black_fields, json, enum2str); | |||||
Message2Json(message, black_fields, json, enum2str, depth + 1); | |||||
return; | return; | ||||
} | } | ||||
@@ -175,7 +181,7 @@ void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFie | |||||
case ProtobufFieldDescriptor::TYPE_MESSAGE: { | case ProtobufFieldDescriptor::TYPE_MESSAGE: { | ||||
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); | const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); | ||||
if (0UL != tmp_message.ByteSizeLong()) { | if (0UL != tmp_message.ByteSizeLong()) { | ||||
Message2Json(tmp_message, black_fields, tmp_json, enum2str); | |||||
Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1); | |||||
} | } | ||||
} break; | } break; | ||||
@@ -45,11 +45,11 @@ class Pb2Json { | |||||
* @author | * @author | ||||
*/ | */ | ||||
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | ||||
bool enum2str = false); | |||||
bool enum2str = false, int depth = 0); | |||||
static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | ||||
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | ||||
Json &json, bool enum2str); | |||||
Json &json, bool enum2str, int depth = 0); | |||||
protected: | protected: | ||||
static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | ||||
@@ -59,7 +59,7 @@ class Pb2Json { | |||||
static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | ||||
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json, | const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json, | ||||
bool enum2str); | |||||
bool enum2str, int depth); | |||||
static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); | static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); | ||||
}; | }; | ||||
@@ -60,7 +60,7 @@ class DataOpParser { | |||||
* @param [in] 4D shape information (dimensions) | * @param [in] 4D shape information (dimensions) | ||||
* @param [out] Save converted shap information | * @param [out] Save converted shap information | ||||
*/ | */ | ||||
static Status Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensorDesc); | |||||
static Status Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensor_desc); | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -98,7 +98,7 @@ class DataOpParser { | |||||
* @return SUCCESS Convert success | * @return SUCCESS Convert success | ||||
* @return FAILED Convert failed | * @return FAILED Convert failed | ||||
*/ | */ | ||||
static Status InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &desc); | |||||
static Status InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &tensor_desc); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -55,9 +55,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi | |||||
} | } | ||||
char real_path[PATH_MAX] = {0}; | char real_path[PATH_MAX] = {0}; | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, | |||||
REPORT_INNER_ERROR("E19999", "file path %s is too long!", file_path); | |||||
return FAILED, "[Check][Param] file path %s is too long!", file_path); | |||||
if (strlen(file_path) >= PATH_MAX) { | |||||
REPORT_INNER_ERROR("E19999", "file path %s is too long!", file_path); | |||||
GELOGE(FAILED, "[Check][Param] file path %s is too long!", file_path); | |||||
return FAILED; | |||||
} | |||||
if (realpath(file_path, real_path) == nullptr) { | if (realpath(file_path, real_path) == nullptr) { | ||||
GELOGI("File %s does not exit, it will be created.", file_path); | GELOGI("File %s does not exit, it will be created.", file_path); | ||||
} | } | ||||
@@ -32,9 +32,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOpera | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) { | ||||
Attr(VAR_ATTR_DTYPE, (int64_t)t); | |||||
Attr(VAR_ATTR_DTYPE, static_cast<int64_t>(t)); | |||||
return *this; | return *this; | ||||
} | } | ||||
ge::DataType ConstantOperator::GetDType() const { return (ge::DataType)GetIntAttr(VAR_ATTR_DTYPE); } | |||||
ge::DataType ConstantOperator::GetDType() const { return static_cast<ge::DataType>(GetIntAttr(VAR_ATTR_DTYPE)); } | |||||
} // namespace ge | } // namespace ge |
@@ -32,7 +32,7 @@ static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair, | |||||
vector<int64_t> v_i; | vector<int64_t> v_i; | ||||
for (int32_t i = 0; i < a_list.i_size(); i++) { | for (int32_t i = 0; i < a_list.i_size(); i++) { | ||||
v_i.push_back((int64_t)a_list.i(i)); | |||||
v_i.push_back(static_cast<int64_t>(a_list.i(i))); | |||||
} | } | ||||
if (v_i.size() > 0) { | if (v_i.size() > 0) { | ||||
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); | (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); | ||||
@@ -56,7 +56,7 @@ static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair, | |||||
} | } | ||||
vector<int32_t> v_u; | vector<int32_t> v_u; | ||||
for (int32_t i = 0; i < a_list.u_size(); i++) { | for (int32_t i = 0; i < a_list.u_size(); i++) { | ||||
v_u.push_back((int32_t)a_list.u(i)); | |||||
v_u.push_back(static_cast<int32_t>(a_list.u(i))); | |||||
} | } | ||||
if (v_u.size() > 0) { | if (v_u.size() > 0) { | ||||
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u); | (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u); | ||||
@@ -114,7 +114,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(co | |||||
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) { | if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) { | ||||
auto &buffer = op_attr_pair.second.value_.bt(); | auto &buffer = op_attr_pair.second.value_.bt(); | ||||
(void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first, | (void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first, | ||||
ge::Buffer::CopyFrom(reinterpret_cast<uint8_t *>(const_cast<char *>(buffer.data())), buffer.size())); | |||||
ge::Buffer::CopyFrom(PtrToPtr<void, uint8_t>(const_cast<char *>(buffer.data())), buffer.size())); | |||||
} | } | ||||
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) { | if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) { | ||||
@@ -23,7 +23,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOpe | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {} | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {} | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) { | ||||
Attr("T", (int64_t)t); | |||||
Attr("T", static_cast<int64_t>(t)); | |||||
return *this; | return *this; | ||||
} | } | ||||
} // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT | } // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT |
@@ -32,20 +32,20 @@ FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::N(int64_t n) { | |||||
FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); } | FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); } | ||||
FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) { | FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) { | ||||
Attr(SHAPEN_ATTR_IN_TYPE, (int64_t)t); | |||||
Attr(SHAPEN_ATTR_IN_TYPE, static_cast<int64_t>(t)); | |||||
return *this; | return *this; | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const { | FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const { | ||||
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_IN_TYPE); | |||||
return static_cast<ge::DataType>(GetIntAttr(SHAPEN_ATTR_IN_TYPE)); | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) { | FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) { | ||||
Attr(SHAPEN_ATTR_OUT_TYPE, (int64_t)t); | |||||
Attr(SHAPEN_ATTR_OUT_TYPE, static_cast<int64_t>(t)); | |||||
return *this; | return *this; | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const { | FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const { | ||||
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_OUT_TYPE); | |||||
return static_cast<ge::DataType>(GetIntAttr(SHAPEN_ATTR_OUT_TYPE)); | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -50,7 +50,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Insta | |||||
// Instances cannot be a member of a class because they may be used before initialization, resulting in a run error. | // Instances cannot be a member of a class because they may be used before initialization, resulting in a run error. | ||||
static std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>> instances; | static std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>> instances; | ||||
auto iter = instances.find(framework); | |||||
std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>>::const_iterator iter = instances.find(framework); | |||||
if (iter == instances.end()) { | if (iter == instances.end()) { | ||||
std::shared_ptr<OpParserFactory> instance(new (std::nothrow) OpParserFactory()); | std::shared_ptr<OpParserFactory> instance(new (std::nothrow) OpParserFactory()); | ||||
if (instance == nullptr) { | if (instance == nullptr) { | ||||
@@ -67,7 +67,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Insta | |||||
FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpParser(const std::string &op_type) { | FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpParser(const std::string &op_type) { | ||||
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | ||||
auto iter = op_parser_creator_map_.find(op_type); | |||||
std::map<std::string, CREATOR_FUN>::const_iterator iter = op_parser_creator_map_.find(op_type); | |||||
if (iter != op_parser_creator_map_.end()) { | if (iter != op_parser_creator_map_.end()) { | ||||
return iter->second(); | return iter->second(); | ||||
} | } | ||||
@@ -78,7 +78,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpPars | |||||
FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateFusionOpParser(const std::string &op_type) { | FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateFusionOpParser(const std::string &op_type) { | ||||
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | ||||
auto iter = fusion_op_parser_creator_map_.find(op_type); | |||||
std::map<std::string, CREATOR_FUN>::const_iterator iter = fusion_op_parser_creator_map_.find(op_type); | |||||
if (iter != fusion_op_parser_creator_map_.end()) { | if (iter != fusion_op_parser_creator_map_.end()) { | ||||
return iter->second(); | return iter->second(); | ||||
} | } | ||||
@@ -102,12 +102,12 @@ FMK_FUNC_HOST_VISIBILITY void OpParserFactory::RegisterCreator(const std::string | |||||
FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) { | FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) { | ||||
if (is_fusion_op) { | if (is_fusion_op) { | ||||
auto iter = fusion_op_parser_creator_map_.find(op_type); | |||||
std::map<std::string, CREATOR_FUN>::const_iterator iter = fusion_op_parser_creator_map_.find(op_type); | |||||
if (iter != fusion_op_parser_creator_map_.end()) { | if (iter != fusion_op_parser_creator_map_.end()) { | ||||
return true; | return true; | ||||
} | } | ||||
} else { | } else { | ||||
auto iter = op_parser_creator_map_.find(op_type); | |||||
std::map<std::string, CREATOR_FUN>::const_iterator iter = op_parser_creator_map_.find(op_type); | |||||
if (iter != op_parser_creator_map_.end()) { | if (iter != op_parser_creator_map_.end()) { | ||||
return true; | return true; | ||||
} | } | ||||
@@ -1,61 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef PARSER_COMMON_OP_TYPES_H_ | |||||
#define PARSER_COMMON_OP_TYPES_H_ | |||||
#include <set> | |||||
#include <string> | |||||
namespace ge { | |||||
class GE_FUNC_VISIBILITY OpTypeContainer { | |||||
public: | |||||
static OpTypeContainer *Instance() { | |||||
static OpTypeContainer instance; | |||||
return &instance; | |||||
} | |||||
~OpTypeContainer() = default; | |||||
void Register(const std::string &op_type) { op_type_list_.insert(op_type); } | |||||
bool IsExisting(const std::string &op_type) { | |||||
return op_type_list_.count(op_type) > 0UL; | |||||
} | |||||
protected: | |||||
OpTypeContainer() {} | |||||
private: | |||||
std::set<std::string> op_type_list_; | |||||
}; | |||||
class GE_FUNC_VISIBILITY OpTypeRegistrar { | |||||
public: | |||||
explicit OpTypeRegistrar(const std::string &op_type) { OpTypeContainer::Instance()->Register(op_type); } | |||||
~OpTypeRegistrar() {} | |||||
}; | |||||
#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \ | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *var_name; | |||||
#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \ | |||||
const char *var_name = str_name; \ | |||||
const OpTypeRegistrar g_##var_name##_reg(str_name); | |||||
#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | |||||
} // namespace ge | |||||
#endif // PARSER_COMMON_OP_TYPES_H_ |
@@ -16,6 +16,7 @@ | |||||
#include "omg/parser/parser_factory.h" | #include "omg/parser/parser_factory.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "common/register_tbe.h" | |||||
namespace domi { | namespace domi { | ||||
FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() { | FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() { | ||||
@@ -77,4 +78,13 @@ FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::Fr | |||||
ModelParserFactory::~ModelParserFactory() { | ModelParserFactory::~ModelParserFactory() { | ||||
creator_map_.clear(); | creator_map_.clear(); | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY OpRegTbeParserFactory *OpRegTbeParserFactory::Instance() { | |||||
static OpRegTbeParserFactory instance; | |||||
return &instance; | |||||
} | |||||
void OpRegTbeParserFactory::Finalize(const domi::OpRegistrationData ®_data) { | |||||
(void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
} | |||||
} // namespace domi | } // namespace domi |
@@ -17,15 +17,16 @@ | |||||
#include "parser/common/parser_fp16_t.h" | #include "parser/common/parser_fp16_t.h" | ||||
#include "external/register/register_types.h" | #include "external/register/register_types.h" | ||||
#include "graph/def_types.h" | |||||
namespace { | namespace { | ||||
constexpr uint16_t kManBitLength = 11; | |||||
constexpr uint16_t kManBitLength = 11U; | |||||
} | } | ||||
namespace ge { | namespace ge { | ||||
namespace parser { | namespace parser { | ||||
/// @ingroup fp16_t global filed | /// @ingroup fp16_t global filed | ||||
/// @brief round mode of last valid digital | /// @brief round mode of last valid digital | ||||
enum TagFp16RoundMode g_round_mode = TagFp16RoundMode::kRoundToNearest; | |||||
const TagFp16RoundMode g_round_mode = TagFp16RoundMode::kRoundToNearest; | |||||
void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { | void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { | ||||
// 1.Extract | // 1.Extract | ||||
@@ -99,12 +100,12 @@ static float Fp16ToFloat(const uint16_t &fp_val) { | |||||
e_ret = 0; | e_ret = 0; | ||||
m_ret = 0; | m_ret = 0; | ||||
} else { | } else { | ||||
e_ret = hf_exp - kFp16ExpBias + kFp32ExpBias; | |||||
e_ret = (static_cast<uint32_t>(hf_exp) - static_cast<uint32_t>(kFp16ExpBias)) + static_cast<uint32_t>(kFp32ExpBias); | |||||
m_ret = hf_man & kFp16ManMask; | m_ret = hf_man & kFp16ManMask; | ||||
m_ret = m_ret << (kFp32ManLen - kFp16ManLen); | m_ret = m_ret << (kFp32ManLen - kFp16ManLen); | ||||
} | } | ||||
uint32_t f_val = FP32_CONSTRUCTOR(s_ret, e_ret, m_ret); | uint32_t f_val = FP32_CONSTRUCTOR(s_ret, e_ret, m_ret); | ||||
auto p_ret_v = reinterpret_cast<float *>(&f_val); | |||||
auto p_ret_v = ge::PtrToPtr<uint32_t, float>(&f_val); | |||||
return *p_ret_v; | return *p_ret_v; | ||||
} | } | ||||
@@ -131,12 +132,12 @@ static double Fp16ToDouble(const uint16_t &fp_val) { | |||||
e_ret = 0; | e_ret = 0; | ||||
m_ret = 0; | m_ret = 0; | ||||
} else { | } else { | ||||
e_ret = hf_exp - kFp16ExpBias + kFp64ExpBias; | |||||
e_ret = (static_cast<uint64_t>(hf_exp) - static_cast<uint64_t>(kFp16ExpBias)) + static_cast<uint64_t>(kFp64ExpBias); | |||||
m_ret = hf_man & kFp16ManMask; | m_ret = hf_man & kFp16ManMask; | ||||
m_ret = m_ret << (kFp64ManLen - kFp16ManLen); | m_ret = m_ret << (kFp64ManLen - kFp16ManLen); | ||||
} | } | ||||
uint64_t f_val = (s_ret << kFp64SignIndex) | (e_ret << kFp64ManLen) | (m_ret); | uint64_t f_val = (s_ret << kFp64SignIndex) | (e_ret << kFp64ManLen) | (m_ret); | ||||
auto p_ret_v = reinterpret_cast<double *>(&f_val); | |||||
auto p_ret_v = ge::PtrToPtr<uint64_t, double>(&f_val); | |||||
return *p_ret_v; | return *p_ret_v; | ||||
} | } | ||||
@@ -154,13 +155,13 @@ static uint8_t GetUint8ValByMan(uint8_t s_ret, const uint64_t &long_int_m, const | |||||
if (need_round) { | if (need_round) { | ||||
m_ret++; | m_ret++; | ||||
} | } | ||||
if (s_ret) { | |||||
m_ret = (~m_ret) + 1; | |||||
if (static_cast<bool>(s_ret)) { | |||||
m_ret = (~m_ret) + 1U; | |||||
} | } | ||||
if (m_ret == 0) { | if (m_ret == 0) { | ||||
s_ret = 0; | s_ret = 0; | ||||
} | } | ||||
return static_cast<uint8_t>((s_ret << kBitShift7) | (m_ret)); | |||||
return static_cast<uint8_t>((s_ret << static_cast<uint8_t>(kBitShift7)) | (m_ret)); | |||||
} | } | ||||
/// @ingroup fp16_t math conversion static method | /// @ingroup fp16_t math conversion static method | ||||
@@ -178,7 +179,7 @@ static int8_t Fp16ToInt8(const uint16_t &fp_val) { | |||||
if (FP16_IS_DENORM(fp_val)) { // Denormalized number | if (FP16_IS_DENORM(fp_val)) { // Denormalized number | ||||
ret_v = 0; | ret_v = 0; | ||||
ret = *(reinterpret_cast<uint8_t *>(&ret_v)); | |||||
ret = *(ge::PtrToPtr<uint8_t, uint8_t>(&ret_v)); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -207,14 +208,14 @@ static int8_t Fp16ToInt8(const uint16_t &fp_val) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (overflow_flag) { | |||||
if (static_cast<bool>(overflow_flag)) { | |||||
ret_v = kInt8Max + s_ret; | ret_v = kInt8Max + s_ret; | ||||
} else { | } else { | ||||
// Generate final result | // Generate final result | ||||
ret_v = GetUint8ValByMan(s_ret, long_int_m, shift_out); | ret_v = GetUint8ValByMan(s_ret, long_int_m, shift_out); | ||||
} | } | ||||
ret = *(reinterpret_cast<uint8_t *>(&ret_v)); | |||||
ret = *(ge::PtrToPtr<uint8_t, uint8_t>(&ret_v)); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -283,8 +284,8 @@ static uint16_t GetUint16ValByMan(uint16_t s_ret, const uint64_t &long_int_m, co | |||||
if (need_round && m_ret < kInt16Max) { | if (need_round && m_ret < kInt16Max) { | ||||
m_ret++; | m_ret++; | ||||
} | } | ||||
if (s_ret) { | |||||
m_ret = (~m_ret) + 1; | |||||
if (static_cast<bool>(s_ret)) { | |||||
m_ret = (~m_ret) + 1U; | |||||
} | } | ||||
if (m_ret == 0) { | if (m_ret == 0) { | ||||
s_ret = 0; | s_ret = 0; | ||||
@@ -307,7 +308,7 @@ static int16_t Fp16ToInt16(const uint16_t &fp_val) { | |||||
if (FP16_IS_DENORM(fp_val)) { // Denormalized number | if (FP16_IS_DENORM(fp_val)) { // Denormalized number | ||||
ret_v = 0; | ret_v = 0; | ||||
ret = *(reinterpret_cast<uint8_t *>(&ret_v)); | |||||
ret = *(ge::PtrToPtr<uint16_t, uint8_t>(&ret_v)); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -336,13 +337,13 @@ static int16_t Fp16ToInt16(const uint16_t &fp_val) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (overflow_flag) { | |||||
if (static_cast<bool>(overflow_flag)) { | |||||
ret_v = kInt16Max + s_ret; | ret_v = kInt16Max + s_ret; | ||||
} else { | } else { | ||||
// Generate final result | // Generate final result | ||||
ret_v = GetUint16ValByMan(s_ret, long_int_m, shift_out); | ret_v = GetUint16ValByMan(s_ret, long_int_m, shift_out); | ||||
} | } | ||||
ret = *(reinterpret_cast<int16_t *>(&ret_v)); | |||||
ret = *(ge::PtrToPtr<uint16_t, uint16_t>(&ret_v)); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -433,7 +434,7 @@ static int32_t Fp16ToInt32(const uint16_t &fp_val) { | |||||
ret_v = (s_ret << kBitShift31) | (m_ret); | ret_v = (s_ret << kBitShift31) | (m_ret); | ||||
} | } | ||||
return *(reinterpret_cast<int32_t *>(&ret_v)); | |||||
return *(ge::PtrToPtr<uint32_t, uint32_t>(&ret_v)); | |||||
} | } | ||||
/// @ingroup fp16_t math conversion static method | /// @ingroup fp16_t math conversion static method | ||||
@@ -498,8 +499,8 @@ static uint16_t Fp16AddCalVal(uint16_t s_ret, int16_t e_ret, uint16_t m_ret, uin | |||||
} | } | ||||
bool b_last_bit = ((m_ret & 1) > 0); | bool b_last_bit = ((m_ret & 1) > 0); | ||||
bool b_trunc_high = 0; | |||||
bool b_trunc_left = 0; | |||||
bool b_trunc_high = false; | |||||
bool b_trunc_left = false; | |||||
b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); | ||||
b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); | ||||
m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); | m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); | ||||
@@ -561,7 +562,7 @@ static uint16_t Fp16Add(uint16_t v_1, uint16_t v_2) { | |||||
int16_t e_ret = std::max(e_a, e_b); | int16_t e_ret = std::max(e_a, e_b); | ||||
int16_t e_tmp = std::abs(e_a - e_b); | int16_t e_tmp = std::abs(e_a - e_b); | ||||
if (e_a > e_b) { | if (e_a > e_b) { | ||||
m_trunc = (m_b << (kBitShift32 - static_cast<uint16_t>(e_tmp))); | |||||
m_trunc = (m_b << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp))); | |||||
m_b = RightShift(m_b, e_tmp); | m_b = RightShift(m_b, e_tmp); | ||||
} else if (e_a < e_b) { | } else if (e_a < e_b) { | ||||
m_trunc = (m_a << (kBitShift32 - static_cast<uint16_t>(e_tmp))); | m_trunc = (m_a << (kBitShift32 - static_cast<uint16_t>(e_tmp))); | ||||
@@ -602,7 +603,7 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { | |||||
m_a = m_a_tmp; | m_a = m_a_tmp; | ||||
m_b = m_b_tmp; | m_b = m_b_tmp; | ||||
e_ret = e_a + e_b - kFp16ExpBias - kDim10; | |||||
e_ret = ((e_a + e_b) - kFp16ExpBias) - kDim10; | |||||
mul_m = m_a * m_b; | mul_m = m_a * m_b; | ||||
s_ret = s_a ^ s_b; | s_ret = s_a ^ s_b; | ||||
@@ -621,8 +622,8 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { | |||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
} | } | ||||
bool b_last_bit = ((mul_m & 1) > 0); | bool b_last_bit = ((mul_m & 1) > 0); | ||||
bool b_trunc_high = 0; | |||||
bool b_trunc_left = 0; | |||||
bool b_trunc_high = false; | |||||
bool b_trunc_left = false; | |||||
b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); | ||||
b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); | ||||
mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); | mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); | ||||
@@ -675,14 +676,14 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) { | |||||
uint64_t m_tmp; | uint64_t m_tmp; | ||||
if (e_a > e_b) { | if (e_a > e_b) { | ||||
m_tmp = m_a; | m_tmp = m_a; | ||||
uint16_t tmp = e_a - e_b; | |||||
uint16_t tmp = static_cast<uint16_t>(e_a - e_b); | |||||
for (int i = 0; i < tmp; i++) { | for (int i = 0; i < tmp; i++) { | ||||
m_tmp = m_tmp << 1; | m_tmp = m_tmp << 1; | ||||
} | } | ||||
m_a = m_tmp; | m_a = m_tmp; | ||||
} else if (e_a < e_b) { | } else if (e_a < e_b) { | ||||
m_tmp = m_b; | m_tmp = m_b; | ||||
uint16_t tmp = e_b - e_a; | |||||
uint16_t tmp = static_cast<uint16_t>(e_b - e_a); | |||||
for (int i = 0; i < tmp; i++) { | for (int i = 0; i < tmp; i++) { | ||||
m_tmp = m_tmp << 1; | m_tmp = m_tmp << 1; | ||||
} | } | ||||
@@ -853,7 +854,7 @@ fp16_t &fp16_t::operator=(const float &f_val) { | |||||
uint16_t s_ret, m_ret; | uint16_t s_ret, m_ret; | ||||
int16_t e_ret; | int16_t e_ret; | ||||
uint32_t e_f, m_f; | uint32_t e_f, m_f; | ||||
const uint32_t ui32_v = *(reinterpret_cast<const uint32_t *>(&f_val)); // 1:8:23bit sign:exp:man | |||||
const uint32_t ui32_v = *(ge::PtrToPtr<const float, const uint32_t>(&f_val)); // 1:8:23bit sign:exp:man | |||||
uint32_t m_len_delta; | uint32_t m_len_delta; | ||||
s_ret = static_cast<uint16_t>((ui32_v & kFp32SignMask) >> kFp32SignIndex); // 4Byte->2Byte | s_ret = static_cast<uint16_t>((ui32_v & kFp32SignMask) >> kFp32SignIndex); // 4Byte->2Byte | ||||
@@ -891,7 +892,7 @@ fp16_t &fp16_t::operator=(const float &f_val) { | |||||
if (need_round) { | if (need_round) { | ||||
m_ret++; | m_ret++; | ||||
} | } | ||||
if (m_ret & kFp16ManHideBit) { | |||||
if (static_cast<bool>(m_ret & kFp16ManHideBit)) { | |||||
e_ret++; | e_ret++; | ||||
} | } | ||||
} | } | ||||
@@ -910,14 +911,14 @@ fp16_t &fp16_t::operator=(const int8_t &i_val) { | |||||
if (m_ret == 0) { | if (m_ret == 0) { | ||||
e_ret = 0; | e_ret = 0; | ||||
} else { | } else { | ||||
if (s_ret) { // negative number(<0) | |||||
if (static_cast<bool>(s_ret)) { // negative number(<0) | |||||
m_ret = static_cast<uint16_t>(std::abs(i_val)); // complement | m_ret = static_cast<uint16_t>(std::abs(i_val)); // complement | ||||
} | } | ||||
e_ret = kFp16ManLen; | e_ret = kFp16ManLen; | ||||
while ((m_ret & kFp16ManHideBit) == 0) { | while ((m_ret & kFp16ManHideBit) == 0) { | ||||
m_ret = m_ret << 1; | m_ret = m_ret << 1; | ||||
e_ret = e_ret - 1; | |||||
e_ret = e_ret - 1U; | |||||
} | } | ||||
e_ret = e_ret + kFp16ExpBias; | e_ret = e_ret + kFp16ExpBias; | ||||
} | } | ||||
@@ -931,11 +932,11 @@ fp16_t &fp16_t::operator=(const uint8_t &ui_val) { | |||||
s_ret = 0; | s_ret = 0; | ||||
e_ret = 0; | e_ret = 0; | ||||
m_ret = ui_val; | m_ret = ui_val; | ||||
if (m_ret) { | |||||
if (static_cast<bool>(m_ret)) { | |||||
e_ret = kFp16ManLen; | e_ret = kFp16ManLen; | ||||
while ((m_ret & kFp16ManHideBit) == 0) { | while ((m_ret & kFp16ManHideBit) == 0) { | ||||
m_ret = m_ret << 1; | m_ret = m_ret << 1; | ||||
e_ret = e_ret - 1; | |||||
e_ret = e_ret - 1U; | |||||
} | } | ||||
e_ret = e_ret + kFp16ExpBias; | e_ret = e_ret + kFp16ExpBias; | ||||
} | } | ||||
@@ -949,11 +950,11 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||||
uint16_t m_min = kFp16ManHideBit; | uint16_t m_min = kFp16ManHideBit; | ||||
uint16_t m_max = m_min << 1; | uint16_t m_max = m_min << 1; | ||||
uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp)); | uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp)); | ||||
if (m_tmp) { | |||||
if (static_cast<bool>(m_tmp)) { | |||||
int16_t e_ret; | int16_t e_ret; | ||||
if (len > kDim11) { | if (len > kDim11) { | ||||
e_ret = kFp16ExpBias + kFp16ManLen; | e_ret = kFp16ExpBias + kFp16ManLen; | ||||
uint16_t e_tmp = len - kDim11; | |||||
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | |||||
uint32_t trunc_mask = 1; | uint32_t trunc_mask = 1; | ||||
for (int i = 1; i < e_tmp; i++) { | for (int i = 1; i < e_tmp; i++) { | ||||
trunc_mask = (trunc_mask << 1) + 1; | trunc_mask = (trunc_mask << 1) + 1; | ||||
@@ -964,8 +965,8 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
} | } | ||||
bool b_last_bit = ((m_tmp & 1) > 0); | bool b_last_bit = ((m_tmp & 1) > 0); | ||||
bool b_trunc_high = 0; | |||||
bool b_trunc_left = 0; | |||||
bool b_trunc_high = false; | |||||
bool b_trunc_left = false; | |||||
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | ||||
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | ||||
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | ||||
@@ -976,7 +977,7 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
} | } | ||||
} else { | } else { | ||||
e_ret = kFp16ExpBias; | |||||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||||
m_tmp = m_tmp << (kManBitLength - len); | m_tmp = m_tmp << (kManBitLength - len); | ||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
@@ -989,11 +990,11 @@ fp16_t &fp16_t::operator=(const int16_t &i_val) { | |||||
if (i_val == 0) { | if (i_val == 0) { | ||||
val = 0; | val = 0; | ||||
} else { | } else { | ||||
uint16_t ui_val = *(reinterpret_cast<const uint16_t *>(&i_val)); | |||||
uint16_t ui_val = *(ge::PtrToPtr<const int16_t, const int16_t>(&i_val)); | |||||
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift15); | auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift15); | ||||
if (s_ret) { | |||||
if (static_cast<bool>(s_ret)) { | |||||
int16_t iValM = -i_val; | int16_t iValM = -i_val; | ||||
ui_val = *(reinterpret_cast<uint16_t *>(&iValM)); | |||||
ui_val = *(ge::PtrToPtr<int16_t, uint16_t>(&iValM)); | |||||
} | } | ||||
SetValByUint16Val(ui_val, s_ret, val); | SetValByUint16Val(ui_val, s_ret, val); | ||||
} | } | ||||
@@ -1023,8 +1024,8 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) { | |||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
} | } | ||||
bool b_last_bit = ((m_ret & 1) > 0); | bool b_last_bit = ((m_ret & 1) > 0); | ||||
bool b_trunc_high = 0; | |||||
bool b_trunc_left = 0; | |||||
bool b_trunc_high = false; | |||||
bool b_trunc_left = false; | |||||
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | ||||
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | ||||
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | ||||
@@ -1038,7 +1039,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) { | |||||
val = kFp16Max; | val = kFp16Max; | ||||
} | } | ||||
} else { | } else { | ||||
e_ret = kFp16ExpBias; | |||||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||||
m_ret = m_ret << (kDim11 - len); | m_ret = m_ret << (kDim11 - len); | ||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
@@ -1057,7 +1058,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||||
e_ret = kFp16ExpBias + kFp16ManLen; | e_ret = kFp16ExpBias + kFp16ManLen; | ||||
uint32_t m_trunc = 0; | uint32_t m_trunc = 0; | ||||
uint32_t trunc_mask = 1; | uint32_t trunc_mask = 1; | ||||
uint16_t e_tmp = len - kDim11; | |||||
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | |||||
for (int i = 1; i < e_tmp; i++) { | for (int i = 1; i < e_tmp; i++) { | ||||
trunc_mask = (trunc_mask << 1) + 1; | trunc_mask = (trunc_mask << 1) + 1; | ||||
} | } | ||||
@@ -1067,8 +1068,8 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
} | } | ||||
bool b_last_bit = ((m_tmp & 1) > 0); | bool b_last_bit = ((m_tmp & 1) > 0); | ||||
bool b_trunc_high = 0; | |||||
bool b_trunc_left = 0; | |||||
bool b_trunc_high = false; | |||||
bool b_trunc_left = false; | |||||
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | ||||
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | ||||
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | ||||
@@ -1083,7 +1084,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||||
m_tmp = kFp16MaxMan; | m_tmp = kFp16MaxMan; | ||||
} | } | ||||
} else { | } else { | ||||
e_ret = kFp16ExpBias; | |||||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||||
m_tmp = m_tmp << (kDim11 - len); | m_tmp = m_tmp << (kDim11 - len); | ||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
@@ -1095,11 +1096,11 @@ fp16_t &fp16_t::operator=(const int32_t &i_val) { | |||||
if (i_val == 0) { | if (i_val == 0) { | ||||
val = 0; | val = 0; | ||||
} else { | } else { | ||||
uint32_t ui_val = *(reinterpret_cast<const uint32_t *>(&i_val)); | |||||
uint32_t ui_val = *(ge::PtrToPtr<const int32_t, const uint32_t>(&i_val)); | |||||
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift31); | auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift31); | ||||
if (s_ret) { | |||||
if (static_cast<bool>(s_ret)) { | |||||
int32_t iValM = -i_val; | int32_t iValM = -i_val; | ||||
ui_val = *(reinterpret_cast<uint32_t *>(&iValM)); | |||||
ui_val = *(ge::PtrToPtr<int32_t, uint32_t>(&iValM)); | |||||
} | } | ||||
SetValByUint32Val(ui_val, s_ret, val); | SetValByUint32Val(ui_val, s_ret, val); | ||||
} | } | ||||
@@ -1119,7 +1120,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) { | |||||
e_ret = kFp16ExpBias + kFp16ManLen; | e_ret = kFp16ExpBias + kFp16ManLen; | ||||
uint32_t m_trunc = 0; | uint32_t m_trunc = 0; | ||||
uint32_t trunc_mask = 1; | uint32_t trunc_mask = 1; | ||||
uint16_t e_tmp = len - kDim11; | |||||
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | |||||
for (int i = 1; i < e_tmp; i++) { | for (int i = 1; i < e_tmp; i++) { | ||||
trunc_mask = (trunc_mask << 1) + 1; | trunc_mask = (trunc_mask << 1) + 1; | ||||
} | } | ||||
@@ -1145,7 +1146,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) { | |||||
m_tmp = kFp16MaxMan; | m_tmp = kFp16MaxMan; | ||||
} | } | ||||
} else { | } else { | ||||
e_ret = kFp16ExpBias; | |||||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||||
m_tmp = m_tmp << (kDim11 - len); | m_tmp = m_tmp << (kDim11 - len); | ||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
@@ -1161,7 +1162,7 @@ fp16_t &fp16_t::operator=(const double &d_val) { | |||||
int16_t e_ret; | int16_t e_ret; | ||||
uint64_t e_d; | uint64_t e_d; | ||||
uint64_t m_d; | uint64_t m_d; | ||||
uint64_t ui64_v = *(reinterpret_cast<const uint64_t *>(&d_val)); // 1:11:52bit sign:exp:man | |||||
uint64_t ui64_v = *(ge::PtrToPtr<const double, const uint64_t>(&d_val)); // 1:11:52bit sign:exp:man | |||||
uint32_t m_len_delta; | uint32_t m_len_delta; | ||||
s_ret = static_cast<uint16_t>((ui64_v & kFp64SignMask) >> kFp64SignIndex); // 4Byte | s_ret = static_cast<uint16_t>((ui64_v & kFp64SignMask) >> kFp64SignIndex); // 4Byte | ||||
@@ -1204,7 +1205,7 @@ fp16_t &fp16_t::operator=(const double &d_val) { | |||||
if (need_round) { | if (need_round) { | ||||
m_ret++; | m_ret++; | ||||
} | } | ||||
if (m_ret & kFp16ManHideBit) { | |||||
if (static_cast<bool>(m_ret & kFp16ManHideBit)) { | |||||
e_ret++; | e_ret++; | ||||
} | } | ||||
} | } | ||||
@@ -1239,7 +1240,7 @@ fp16_t::operator uint64_t() const { return 0; } | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() const { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() const { | ||||
if ((val & kFp16AbsMax) == kFp16ExpMask) { | if ((val & kFp16AbsMax) == kFp16ExpMask) { | ||||
if (val & kFp16SignMask) { | |||||
if (static_cast<bool>(val & kFp16SignMask)) { | |||||
return -1; | return -1; | ||||
} else { | } else { | ||||
return 1; | return 1; | ||||
@@ -91,16 +91,16 @@ using BitShift = enum { | |||||
}; | }; | ||||
/// @ingroup fp16 basic parameter | /// @ingroup fp16 basic parameter | ||||
/// @brief fp16 exponent bias | /// @brief fp16 exponent bias | ||||
constexpr uint16_t kFp16ExpBias = 15; | |||||
constexpr uint16_t kFp16ExpBias = 15U; | |||||
/// @ingroup fp16 basic parameter | /// @ingroup fp16 basic parameter | ||||
/// @brief the exponent bit length of fp16 is 5 | /// @brief the exponent bit length of fp16 is 5 | ||||
constexpr uint16_t kFp16ExpLen = 5; | |||||
constexpr uint16_t kFp16ExpLen = 5U; | |||||
/// @ingroup fp16 basic parameter | /// @ingroup fp16 basic parameter | ||||
/// @brief the mantissa bit length of fp16 is 10 | /// @brief the mantissa bit length of fp16 is 10 | ||||
constexpr uint16_t kFp16ManLen = 10; | |||||
constexpr uint16_t kFp16ManLen = 10U; | |||||
/// @ingroup fp16 basic parameter | /// @ingroup fp16 basic parameter | ||||
/// @brief bit index of sign in fp16 | /// @brief bit index of sign in fp16 | ||||
constexpr uint16_t kFp16SignIndex = 15; | |||||
constexpr uint16_t kFp16SignIndex = 15U; | |||||
/// @ingroup fp16 basic parameter | /// @ingroup fp16 basic parameter | ||||
/// @brief sign mask of fp16 (1 00000 00000 00000) | /// @brief sign mask of fp16 (1 00000 00000 00000) | ||||
constexpr uint16_t kFp16SignMask = 0x8000; | constexpr uint16_t kFp16SignMask = 0x8000; | ||||
@@ -164,16 +164,16 @@ constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14); | |||||
#define FP16_IS_INVALID(x) (((x) & kFp16ExpMask) == kFp16ExpMask) | #define FP16_IS_INVALID(x) (((x) & kFp16ExpMask) == kFp16ExpMask) | ||||
/// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
/// @brief fp32 exponent bias | /// @brief fp32 exponent bias | ||||
constexpr uint16_t kFp32ExpBias = 127; | |||||
constexpr uint16_t kFp32ExpBias = 127U; | |||||
/// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
/// @brief the exponent bit length of float/fp32 is 8 | /// @brief the exponent bit length of float/fp32 is 8 | ||||
constexpr uint16_t kFp32ExpLen = 8; | |||||
constexpr uint16_t kFp32ExpLen = 8U; | |||||
/// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
/// @brief the mantissa bit length of float/fp32 is 23 | /// @brief the mantissa bit length of float/fp32 is 23 | ||||
constexpr uint16_t kFp32ManLen = 23; | |||||
constexpr uint16_t kFp32ManLen = 23U; | |||||
/// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
/// @brief bit index of sign in float/fp32 | /// @brief bit index of sign in float/fp32 | ||||
constexpr uint16_t kFp32SignIndex = 31; | |||||
constexpr uint16_t kFp32SignIndex = 31U; | |||||
/// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
/// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) | /// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) | ||||
constexpr uint32_t kFp32SignMask = 0x80000000u; | constexpr uint32_t kFp32SignMask = 0x80000000u; | ||||
@@ -191,10 +191,10 @@ constexpr uint32_t kFp32ManHideBit = 0x00800000u; | |||||
constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; | constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; | ||||
/// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
/// @brief maximum exponent value of fp32 is 255(1111 1111) | /// @brief maximum exponent value of fp32 is 255(1111 1111) | ||||
constexpr uint32_t kFp32MaxExp = 0xFF; | |||||
constexpr uint32_t kFp32MaxExp = 0xFFU; | |||||
/// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
/// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) | /// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) | ||||
constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | |||||
constexpr uint32_t kFp32MaxMan = 0x7FFFFFU; | |||||
/// @ingroup fp32 special value judgment | /// @ingroup fp32 special value judgment | ||||
/// @brief whether a fp32 is NaN | /// @brief whether a fp32 is NaN | ||||
#define FP32_IS_NAN(x) ((((x) & kFp32ExpMask) == kFp32ExpMask) && ((x) & kFp32ManMask)) | #define FP32_IS_NAN(x) ((((x) & kFp32ExpMask) == kFp32ExpMask) && ((x) & kFp32ManMask)) | ||||
@@ -218,16 +218,16 @@ constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | |||||
#define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m) & kFp32MaxMan)) | #define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m) & kFp32MaxMan)) | ||||
/// @ingroup fp64 basic parameter | /// @ingroup fp64 basic parameter | ||||
/// @brief fp64 exponent bias | /// @brief fp64 exponent bias | ||||
constexpr uint16_t kFp64ExpBias = 1023; | |||||
constexpr uint16_t kFp64ExpBias = 1023U; | |||||
/// @ingroup fp64 basic parameter | /// @ingroup fp64 basic parameter | ||||
/// @brief the exponent bit length of double/fp64 is 11 | /// @brief the exponent bit length of double/fp64 is 11 | ||||
constexpr uint16_t kFp64ExpLen = 11; | |||||
constexpr uint16_t kFp64ExpLen = 11U; | |||||
/// @ingroup fp64 basic parameter | /// @ingroup fp64 basic parameter | ||||
/// @brief the mantissa bit length of double/fp64 is 52 | /// @brief the mantissa bit length of double/fp64 is 52 | ||||
constexpr uint16_t kFp64ManLen = 52; | |||||
constexpr uint16_t kFp64ManLen = 52U; | |||||
/// @ingroup fp64 basic parameter | /// @ingroup fp64 basic parameter | ||||
/// @brief bit index of sign in double/fp64 is 63 | /// @brief bit index of sign in double/fp64 is 63 | ||||
constexpr uint16_t kFp64SignIndex = 63; | |||||
constexpr uint16_t kFp64SignIndex = 63U; | |||||
/// @ingroup fp64 basic parameter | /// @ingroup fp64 basic parameter | ||||
/// @brief sign mask of fp64 (1 000 (total 63bits 0)) | /// @brief sign mask of fp64 (1 000 (total 63bits 0)) | ||||
constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; | constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; | ||||
@@ -269,14 +269,14 @@ constexpr int16_t kInt16Max = 0x7FFF; | |||||
constexpr uint16_t kBitLen16Max = 0xFFFF; | constexpr uint16_t kBitLen16Max = 0xFFFF; | ||||
/// @ingroup integer special value judgment | /// @ingroup integer special value judgment | ||||
/// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) | /// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) | ||||
constexpr int32_t kInt32Max = 0x7FFFFFFFu; | |||||
constexpr int32_t kInt32Max = 0x7FFFFFFF; | |||||
/// @ingroup integer special value judgment | /// @ingroup integer special value judgment | ||||
/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) | /// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) | ||||
constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; | constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; | ||||
/// @ingroup integer special value judgment | /// @ingroup integer special value judgment | ||||
/// @brief maximum positive value of int64_t | /// @brief maximum positive value of int64_t | ||||
/// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | /// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | ||||
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu; | |||||
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFF; | |||||
/// @ingroup integer special value judgment | /// @ingroup integer special value judgment | ||||
/// @brief maximum value of a data with 64 bits length | /// @brief maximum value of a data with 64 bits length | ||||
/// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | /// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | ||||
@@ -62,7 +62,7 @@ public: | |||||
/// @return others optimized failed | /// @return others optimized failed | ||||
/// @author | /// @author | ||||
/// | /// | ||||
static Status Run(const ge::ComputeGraphPtr &graph, std::vector<std::pair<std::string, GraphPass *>> &passes); | |||||
static Status Run(const ge::ComputeGraphPtr &graph, std::vector<std::pair<std::string, GraphPass *>> &names_to_passes); | |||||
~PassManager(); | ~PassManager(); | ||||
@@ -99,9 +99,9 @@ Status PreChecker::CheckName(OpId id) { | |||||
if (id != v.first && info.name == v.second.name) { | if (id != v.first && info.name == v.second.name) { | ||||
Cause cause; | Cause cause; | ||||
cause.code = ErrorCode::NAME_REPEATED; | cause.code = ErrorCode::NAME_REPEATED; | ||||
cause.message = "The name is repeated."; | |||||
cause.message = "The name is repeated in the graph."; | |||||
GELOGI("Name %s repeated.", info.name.c_str()); | |||||
GELOGE(FAILED, "opname %s repeated, same name op in the graph", info.name.c_str()); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name}); | ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name}); | ||||
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); | GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); | ||||
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "[Add][Cause] failed."); | GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "[Add][Cause] failed."); | ||||
@@ -200,7 +200,7 @@ FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() { | |||||
return false; | return false; | ||||
} | } | ||||
Status PreChecker::Save(string file) { | |||||
Status PreChecker::Save(const string &file) { | |||||
uint32_t fail_num = 0; | uint32_t fail_num = 0; | ||||
for (auto id : ops_) { | for (auto id : ops_) { | ||||
if (HasError(id)) { | if (HasError(id)) { | ||||
@@ -250,7 +250,7 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string | |||||
Cause cause; | Cause cause; | ||||
cause.code = ErrorCode::TYPE_UNSUPPORTED; | cause.code = ErrorCode::TYPE_UNSUPPORTED; | ||||
cause.message = "The type is not supported."; | cause.message = "The type is not supported."; | ||||
GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||||
GELOGI("Check op[%s]'s type[%s], The type is not supported.", name.c_str(), type.c_str()); | |||||
if (!is_tensorflow) { | if (!is_tensorflow) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | ||||
} | } | ||||
@@ -265,9 +265,11 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string | |||||
cause.code = ErrorCode::TYPE_UNSUPPORTED; | cause.code = ErrorCode::TYPE_UNSUPPORTED; | ||||
cause.message = "The type is not supported."; | cause.message = "The type is not supported."; | ||||
GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||||
if (!is_tensorflow) { | if (!is_tensorflow) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | ||||
GELOGE(FAILED, "Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||||
} else { | |||||
GELOGI("Check op[%s]'s type[%s] is not supported.", name.c_str(), type.c_str()); | |||||
} | } | ||||
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); | GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); | ||||
} | } | ||||
@@ -142,7 +142,7 @@ class PreChecker { | |||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief Save inspection results(JSON) | * @brief Save inspection results(JSON) | ||||
*/ | */ | ||||
Status Save(string file); | |||||
Status Save(const string &file); | |||||
private: | private: | ||||
/** | /** | ||||
@@ -425,7 +425,8 @@ Status ProtoFileParser::FindConflictLine(const char *proto_file, int identifier, | |||||
void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, | void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, | ||||
std::map<std::string, std::pair<int, string>> &caffe_op_identifier_map, | std::map<std::string, std::pair<int, string>> &caffe_op_identifier_map, | ||||
std::map<std::string, std::pair<int, string>> &custom_op_identifier_map) { | std::map<std::string, std::pair<int, string>> &custom_op_identifier_map) { | ||||
for (auto iter = custom_op_identifier_map.begin(); iter != custom_op_identifier_map.end(); ++iter) { | |||||
std::map<std::string, std::pair<int, string>>::const_iterator iter = custom_op_identifier_map.begin(); | |||||
for (; iter != custom_op_identifier_map.end(); ++iter) { | |||||
if (caffe_op_identifier_map.count(iter->first) > 0) { | if (caffe_op_identifier_map.count(iter->first) > 0) { | ||||
string message_name = iter->first; | string message_name = iter->first; | ||||
auto caffe_pair = caffe_op_identifier_map[iter->first]; | auto caffe_pair = caffe_op_identifier_map[iter->first]; | ||||
@@ -452,7 +453,8 @@ void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char * | |||||
void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, | void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, | ||||
std::map<int, std::pair<string, string>> caffe_identifier_op_map, | std::map<int, std::pair<string, string>> caffe_identifier_op_map, | ||||
std::map<int, std::pair<string, string>> custom_identifier_op_map) { | std::map<int, std::pair<string, string>> custom_identifier_op_map) { | ||||
for (auto iter = custom_identifier_op_map.begin(); iter != custom_identifier_op_map.end(); ++iter) { | |||||
std::map<int, std::pair<string, string>>::const_iterator iter = custom_identifier_op_map.begin(); | |||||
for (; iter != custom_identifier_op_map.end(); ++iter) { | |||||
if (caffe_identifier_op_map.count(iter->first) > 0) { | if (caffe_identifier_op_map.count(iter->first) > 0) { | ||||
int identifier = iter->first; | int identifier = iter->first; | ||||
auto caffe_pair = caffe_identifier_op_map[iter->first]; | auto caffe_pair = caffe_identifier_op_map[iter->first]; | ||||
@@ -97,7 +97,8 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||||
return false; | return false; | ||||
} | } | ||||
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | ||||
domi::TENSORFLOW, GetOmOptype(reg_data), [=]() -> std::shared_ptr<OpParser> { return tf_parser_adapter; }); | |||||
domi::TENSORFLOW, GetOmOptype(reg_data), [tf_parser_adapter]() -> std::shared_ptr<OpParser> | |||||
{ return tf_parser_adapter; }); | |||||
} | } | ||||
if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) { | if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) { | ||||
bool is_registed = factory->OpParserIsRegistered(GetOmOptype(reg_data), true); | bool is_registed = factory->OpParserIsRegistered(GetOmOptype(reg_data), true); | ||||
@@ -115,7 +116,7 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||||
} | } | ||||
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | ||||
domi::TENSORFLOW, GetOmOptype(reg_data), | domi::TENSORFLOW, GetOmOptype(reg_data), | ||||
[=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true); | |||||
[tf_fusion_parser_adapter]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true); | |||||
} | } | ||||
} else { | } else { | ||||
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); | std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); | ||||
@@ -105,7 +105,7 @@ void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { | |||||
GELOGI("Enter get custom op path schedule"); | GELOGI("Enter get custom op path schedule"); | ||||
std::string fmk_type; | std::string fmk_type; | ||||
domi::FrameworkType type = domi::TENSORFLOW; | domi::FrameworkType type = domi::TENSORFLOW; | ||||
auto it = options_.find(FRAMEWORK_TYPE); | |||||
std::map<string, string>::const_iterator it = options_.find(FRAMEWORK_TYPE); | |||||
if (it != options_.end()) { | if (it != options_.end()) { | ||||
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10)); | type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10)); | ||||
} | } | ||||
@@ -227,9 +227,9 @@ def convert_subgraphs(graph_def, filename): | |||||
print(graph_def_library.graph_def[i]) | print(graph_def_library.graph_def[i]) | ||||
# Write to prototxt | # Write to prototxt | ||||
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename))) | |||||
print("graph_def_file: ", graph_def_file) | |||||
try: | try: | ||||
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename))) | |||||
print("graph_def_file: ", graph_def_file) | |||||
with open(graph_def_file, "w") as f: | with open(graph_def_file, "w") as f: | ||||
print(graph_def_library, file=f) | print(graph_def_library, file=f) | ||||
except IOError: | except IOError: | ||||
@@ -261,18 +261,18 @@ if __name__ == '__main__': | |||||
model = '' | model = '' | ||||
try: | try: | ||||
opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model=']) | opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model=']) | ||||
for opt_name, opt_value in opts: | |||||
if opt_name in ('-m', '--model'): | |||||
model = opt_value | |||||
print("INFO: Input model file is", model) | |||||
convert_graphs(model) | |||||
elif opt_name in ('-h', '--help'): | |||||
usage() | |||||
break | |||||
elif opt_name in ('-v', '--version'): | |||||
print("version 1.0.0") | |||||
break | |||||
except getopt.GetoptError: | except getopt.GetoptError: | ||||
print("ERROR: Input parameters is invalid, use '--help' to view the help.") | print("ERROR: Input parameters is invalid, use '--help' to view the help.") | ||||
for opt_name, opt_value in opts: | |||||
if opt_name in ('-m', '--model'): | |||||
model = opt_value | |||||
print("INFO: Input model file is", model) | |||||
convert_graphs(model) | |||||
elif opt_name in ('-h', '--help'): | |||||
usage() | |||||
break | |||||
elif opt_name in ('-v', '--version'): | |||||
print("version 1.0.0") | |||||
break | |||||
if len(sys.argv) == 1: | if len(sys.argv) == 1: | ||||
print("INFO: Please specify the input parameters, and use '--help' to view the help.") | print("INFO: Please specify the input parameters, and use '--help' to view the help.") |
@@ -71,7 +71,7 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_ | |||||
}; | }; | ||||
int32_t datatype_val_size = 0; | int32_t datatype_val_size = 0; | ||||
auto iter = datatype_val_size_map.find(data_type); | |||||
std::map<uint32_t, int32_t>::const_iterator iter = datatype_val_size_map.find(data_type); | |||||
if (iter != datatype_val_size_map.end()) { | if (iter != datatype_val_size_map.end()) { | ||||
datatype_val_size = iter->second; | datatype_val_size = iter->second; | ||||
} else { | } else { | ||||
@@ -91,7 +91,7 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_ | |||||
if (data_type == OnnxDataType::STRING) { | if (data_type == OnnxDataType::STRING) { | ||||
tensor.SetData(tensor_proto.raw_data().c_str()); | tensor.SetData(tensor_proto.raw_data().c_str()); | ||||
} else { | } else { | ||||
tensor.SetData(reinterpret_cast<const uint8_t *>(tensor_proto.raw_data().c_str()), | |||||
tensor.SetData(PtrToPtr<const char_t, const uint8_t>(tensor_proto.raw_data().c_str()), | |||||
tensor_proto.raw_data().size()); | tensor_proto.raw_data().size()); | ||||
} | } | ||||
GELOGD("Raw data size is : %zu", tensor_proto.raw_data().size()); | GELOGD("Raw data size is : %zu", tensor_proto.raw_data().size()); | ||||
@@ -65,7 +65,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||||
*(addr_trans.get() + i) = static_cast<bool>( | *(addr_trans.get() + i) = static_cast<bool>( | ||||
std::fabs(*((addr).get() + i)) > std::numeric_limits<T>::epsilon()); | std::fabs(*((addr).get() + i)) > std::numeric_limits<T>::epsilon()); | ||||
} | } | ||||
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(bool)); | |||||
(tensor).SetData(PtrToPtr<bool, uint8_t>(addr_trans.get()), (count) * sizeof(bool)); | |||||
break; | break; | ||||
} | } | ||||
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ | #define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ | ||||
@@ -32,7 +32,6 @@ | |||||
#include "onnx_op_parser.h" | #include "onnx_op_parser.h" | ||||
#include "onnx_util.h" | #include "onnx_util.h" | ||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "parser/common/pre_checker.h" | |||||
#include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
#include "parser/common/model_saver.h" | #include "parser/common/model_saver.h" | ||||
#include "parser/common/parser_utils.h" | #include "parser/common/parser_utils.h" | ||||
@@ -240,7 +239,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra | |||||
if (node->GetOpDesc() == nullptr) { | if (node->GetOpDesc() == nullptr) { | ||||
continue; | continue; | ||||
} | } | ||||
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName()); | |||||
node->GetOpDesc()->SetName(OnnxUtil::GenUniqueNodeName(sub_graph->GetName(), node->GetName())); | |||||
} | } | ||||
auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); | auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); | ||||
@@ -384,7 +383,7 @@ Status OnnxModelParser::ConstructOriType(const ge::onnx::NodeProto *node_proto, | |||||
std::string domain = node_proto->domain(); | std::string domain = node_proto->domain(); | ||||
int64_t version = 0; | int64_t version = 0; | ||||
if (!domain.empty()) { | if (!domain.empty()) { | ||||
auto it = domain_verseion_.find(domain); | |||||
std::map<std::string, int64_t>::const_iterator it = domain_verseion_.find(domain); | |||||
if (it != domain_verseion_.end()) { | if (it != domain_verseion_.end()) { | ||||
version = it->second; | version = it->second; | ||||
} else { | } else { | ||||
@@ -493,14 +492,14 @@ Status OnnxModelParser::SetOperatorInputs() { | |||||
std::vector<std::pair<std::string, int>> &output_node_indexs = out_iter->second; | std::vector<std::pair<std::string, int>> &output_node_indexs = out_iter->second; | ||||
for (auto input_node_index : input_node_indexs) { | for (auto input_node_index : input_node_indexs) { | ||||
for (auto out_node_index : output_node_indexs) { | for (auto out_node_index : output_node_indexs) { | ||||
auto input_op_iter = name_operator_.find(input_node_index.first); | |||||
std::map<std::string, ge::Operator>::const_iterator input_op_iter = name_operator_.find(input_node_index.first); | |||||
if (input_op_iter == name_operator_.end()) { | if (input_op_iter == name_operator_.end()) { | ||||
REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", input_node_index.first.c_str()); | REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", input_node_index.first.c_str()); | ||||
GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", | GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", | ||||
input_node_index.first.c_str()); | input_node_index.first.c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
auto output_op_iter = name_operator_.find(out_node_index.first); | |||||
std::map<std::string, ge::Operator>::const_iterator output_op_iter = name_operator_.find(out_node_index.first); | |||||
if (output_op_iter == name_operator_.end()) { | if (output_op_iter == name_operator_.end()) { | ||||
REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", out_node_index.first.c_str()); | REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", out_node_index.first.c_str()); | ||||
GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", | GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", | ||||
@@ -594,6 +593,7 @@ Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge:: | |||||
} | } | ||||
Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | ||||
bool has_error = false; | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ||||
std::string node_name = node_proto->name(); | std::string node_name = node_proto->name(); | ||||
@@ -605,7 +605,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str()); | GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str()); | ||||
REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str()); | REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str()); | ||||
return status; | |||||
has_error = true; | |||||
continue; | |||||
} | } | ||||
node_proto->set_op_type(ori_type); | node_proto->set_op_type(ori_type); | ||||
@@ -616,7 +617,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | ||||
REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | ||||
return status; | |||||
has_error = true; | |||||
continue; | |||||
} | } | ||||
// 7. op parser | // 7. op parser | ||||
@@ -627,7 +629,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
status = ParseOpParam(node_proto, op, op_parser); | status = ParseOpParam(node_proto, op, op_parser); | ||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | ||||
return status; | |||||
has_error = true; | |||||
continue; | |||||
} | } | ||||
GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", | GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", | ||||
@@ -638,7 +641,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
if (graph_status != ge::GRAPH_SUCCESS) { | if (graph_status != ge::GRAPH_SUCCESS) { | ||||
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); | GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); | ||||
REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); | REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); | ||||
return FAILED; | |||||
has_error = true; | |||||
continue; | |||||
} | } | ||||
name_operator_[ParserUtils::GetOperatorName(op)] = op; | name_operator_[ParserUtils::GetOperatorName(op)] = op; | ||||
@@ -647,11 +651,12 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
REPORT_INNER_ERROR("E19999", "ConstructInputOutputContext failed."); | REPORT_INNER_ERROR("E19999", "ConstructInputOutputContext failed."); | ||||
GELOGE(status, "[Construct][RelationMap] to input and output failed."); | GELOGE(status, "[Construct][RelationMap] to input and output failed."); | ||||
return status; | |||||
has_error = true; | |||||
continue; | |||||
} | } | ||||
} | } | ||||
GELOGI("Parse all node proto success."); | |||||
return SUCCESS; | |||||
GELOGI("Parse all node proto end."); | |||||
return has_error ? FAILED : SUCCESS; | |||||
} | } | ||||
Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | ||||
@@ -665,7 +670,7 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||||
} | } | ||||
} | } | ||||
for (auto in_name : input_node_names_) { | for (auto in_name : input_node_names_) { | ||||
auto in_op = name_operator_.find(in_name); | |||||
std::map<std::string, ge::Operator>::const_iterator in_op = name_operator_.find(in_name); | |||||
if (in_op == name_operator_.end()) { | if (in_op == name_operator_.end()) { | ||||
GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.", | GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.", | ||||
in_name.c_str()); | in_name.c_str()); | ||||
@@ -682,7 +687,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||||
Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops, | Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops, | ||||
ParserUtils::OutputMapping &out_tensor_to_nodes) { | ParserUtils::OutputMapping &out_tensor_to_nodes) { | ||||
for (auto output_name : output_node_names_) { | for (auto output_name : output_node_names_) { | ||||
auto itr = outputs_map_.find(output_name); | |||||
std::map<std::string, std::vector<std::pair<std::string, int>>>::const_iterator itr = | |||||
outputs_map_.find(output_name); | |||||
if (itr == outputs_map_.end()) { | if (itr == outputs_map_.end()) { | ||||
GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | ||||
REPORT_INNER_ERROR("E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | REPORT_INNER_ERROR("E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | ||||
@@ -692,7 +698,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec | |||||
std::vector<std::pair<std::string, int>> node_names_indexes = itr->second; | std::vector<std::pair<std::string, int>> node_names_indexes = itr->second; | ||||
for (const auto &node_name_index : node_names_indexes) { | for (const auto &node_name_index : node_names_indexes) { | ||||
auto node_name = node_name_index.first; | auto node_name = node_name_index.first; | ||||
auto out_op_itr = name_operator_.find(node_name); | |||||
std::map<std::string, ge::Operator>::const_iterator out_op_itr = name_operator_.find(node_name); | |||||
if (out_op_itr == name_operator_.end()) { | if (out_op_itr == name_operator_.end()) { | ||||
GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str()); | GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str()); | ||||
REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str()); | REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str()); | ||||
@@ -749,6 +755,14 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph( | |||||
while (!onnx_graph_tasks.empty()) { | while (!onnx_graph_tasks.empty()) { | ||||
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); | ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); | ||||
onnx_graph_tasks.pop(); | onnx_graph_tasks.pop(); | ||||
std::string graph_name; | |||||
for (const auto &graph_iter : name_to_onnx_graph) { | |||||
if (graph_iter.second == onnx_graph) { | |||||
graph_name = graph_iter.first; | |||||
break; | |||||
} | |||||
} | |||||
for (int i = 0; i < onnx_graph->node_size(); i++) { | for (int i = 0; i < onnx_graph->node_size(); i++) { | ||||
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); | ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); | ||||
if (node_proto->name().empty()) { | if (node_proto->name().empty()) { | ||||
@@ -766,7 +780,8 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph( | |||||
} | } | ||||
std::vector<ge::onnx::GraphProto *> onnx_graphs; | std::vector<ge::onnx::GraphProto *> onnx_graphs; | ||||
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph; | std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph; | ||||
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) { | |||||
if (subgraph_adapter->AdaptAndFindAllSubgraphs( | |||||
node_proto, onnx_graphs, name_to_onnx_subgraph, graph_name) != SUCCESS) { | |||||
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); | GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); | ||||
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); | REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); | ||||
return FAILED; | return FAILED; | ||||
@@ -815,7 +830,7 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
bool is_subgraph = (arg.parent_node != nullptr) ? true : false; | bool is_subgraph = (arg.parent_node != nullptr) ? true : false; | ||||
if (arg.onnx_graph == nullptr) { | if (arg.onnx_graph == nullptr) { | ||||
auto itr = name_to_onnx_graph.find(arg.graph_name); | |||||
std::map<std::string, ge::onnx::GraphProto *>::const_iterator itr = name_to_onnx_graph.find(arg.graph_name); | |||||
if (itr == name_to_onnx_graph.end()) { | if (itr == name_to_onnx_graph.end()) { | ||||
GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | ||||
REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | ||||
@@ -38,6 +38,7 @@ | |||||
#include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
#include "omg/parser/weights_parser.h" | #include "omg/parser/weights_parser.h" | ||||
#include "common/parser_utils.h" | #include "common/parser_utils.h" | ||||
#include "common/pre_checker.h" | |||||
#include "proto/onnx/ge_onnx.pb.h" | #include "proto/onnx/ge_onnx.pb.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -81,6 +82,18 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
bool HasError() override { | |||||
return PreChecker::Instance().HasError(); | |||||
} | |||||
Status Save(const string &file) override { | |||||
return PreChecker::Instance().Save(file); | |||||
} | |||||
void Clear() override { | |||||
PreChecker::Instance().Clear(); | |||||
} | |||||
private: | private: | ||||
Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | ||||
@@ -96,7 +109,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | ||||
Status AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, std::string &om_type); | |||||
Status AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, std::string &op_type); | |||||
Status TransNodeToOperator(const ge::onnx::NodeProto *node_proto, ge::Operator &op, const string &op_type) const; | Status TransNodeToOperator(const ge::onnx::NodeProto *node_proto, ge::Operator &op, const string &op_type) const; | ||||
@@ -106,7 +119,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops); | Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops); | ||||
Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs, | |||||
Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops, | |||||
ParserUtils::OutputMapping &out_tensor_to_nodes); | ParserUtils::OutputMapping &out_tensor_to_nodes); | ||||
Status Prechecker(ge::onnx::GraphProto &onnx_graph); | Status Prechecker(ge::onnx::GraphProto &onnx_graph); | ||||
@@ -115,7 +128,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; | Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; | ||||
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); | |||||
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph); | |||||
Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | ||||
@@ -161,6 +174,18 @@ class PARSER_FUNC_VISIBILITY OnnxWeightsParser : public domi::WeightsParser { | |||||
(void)graph; | (void)graph; | ||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
bool HasError() override { | |||||
return PreChecker::Instance().HasError(); | |||||
} | |||||
Status Save(const string &file) override { | |||||
return PreChecker::Instance().Save(file); | |||||
} | |||||
void Clear() override { | |||||
PreChecker::Instance().Clear(); | |||||
} | |||||
}; | }; | ||||
} // namespace domi | } // namespace domi | ||||
#endif // PARSER_ONNX_ONNX_PARSER_H_ | #endif // PARSER_ONNX_ONNX_PARSER_H_ |
@@ -45,4 +45,8 @@ void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &orig | |||||
const std::string &parent_node_name, std::string &unique_subgraph_name) { | const std::string &parent_node_name, std::string &unique_subgraph_name) { | ||||
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; | unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; | ||||
} | } | ||||
std::string OnnxUtil::GenUniqueNodeName(const std::string &graph_name, const std::string &node_name) { | |||||
return graph_name + "/" + node_name; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -54,6 +54,7 @@ class OnnxUtil { | |||||
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | ||||
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | ||||
const std::string &parent_node_name, std::string &unique_subgraph_name); | const std::string &parent_node_name, std::string &unique_subgraph_name); | ||||
static std::string GenUniqueNodeName(const std::string &graph_name, const std::string &node_name); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -15,6 +15,7 @@ | |||||
*/ | */ | ||||
#include "if_subgraph_adapter.h" | #include "if_subgraph_adapter.h" | ||||
#include <unordered_set> | |||||
#include "subgraph_adapter_factory.h" | #include "subgraph_adapter_factory.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -27,12 +28,12 @@ const int kIfNodeAttrSize = 2; | |||||
} // namespace | } // namespace | ||||
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | ||||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) { | |||||
GE_CHECK_NOTNULL(parent_node); | GE_CHECK_NOTNULL(parent_node); | ||||
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | ||||
parent_node->op_type().c_str()); | parent_node->op_type().c_str()); | ||||
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph); | |||||
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "[Parse][Node] Parse if node failed."); | GELOGE(ret, "[Parse][Node] Parse if node failed."); | ||||
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | ||||
@@ -44,7 +45,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||||
domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | ||||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) { | |||||
if (parent_node->attribute_size() != kIfNodeAttrSize) { | if (parent_node->attribute_size() != kIfNodeAttrSize) { | ||||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | ||||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | ||||
@@ -67,7 +68,11 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
std::string unique_subgraph_name; | std::string unique_subgraph_name; | ||||
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name); | |||||
std::string node_name = parent_node->name(); | |||||
if (!parent_graph_name.empty()) { | |||||
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); | |||||
} | |||||
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, node_name, unique_subgraph_name); | |||||
GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); | GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); | ||||
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); | ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); | ||||
name_to_onnx_graph[unique_subgraph_name] = onnx_graph; | name_to_onnx_graph[unique_subgraph_name] = onnx_graph; | ||||
@@ -91,8 +96,8 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||||
domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | ||||
std::set<std::string> &all_inputs) const { | std::set<std::string> &all_inputs) const { | ||||
std::set<std::string> graph_inputs; | |||||
std::set<std::string> graph_outputs; | |||||
std::unordered_set<std::string> graph_inputs; | |||||
std::unordered_set<std::string> graph_outputs; | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | ||||
for (int j = 0; j < node_proto->input_size(); j++) { | for (int j = 0; j < node_proto->input_size(); j++) { | ||||
@@ -102,10 +107,12 @@ domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx | |||||
graph_outputs.emplace(node_proto->output(j)); | graph_outputs.emplace(node_proto->output(j)); | ||||
} | } | ||||
} | } | ||||
std::unordered_set<std::string> graph_initializer_tensors; | |||||
for (int32_t i = 0; i < onnx_graph.initializer_size(); i++) { | |||||
graph_initializer_tensors.emplace(onnx_graph.initializer(i).name()); | |||||
} | |||||
for (const auto &input : graph_inputs) { | for (const auto &input : graph_inputs) { | ||||
auto out_iter = graph_outputs.find(input); | |||||
if (out_iter == graph_outputs.end()) { | |||||
if (graph_outputs.count(input) == 0 && graph_initializer_tensors.count(input) == 0) { | |||||
// Record input node need to be constructed | // Record input node need to be constructed | ||||
all_inputs.emplace(input); | all_inputs.emplace(input); | ||||
} | } | ||||
@@ -24,13 +24,15 @@ | |||||
namespace ge { | namespace ge { | ||||
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | ||||
public: | public: | ||||
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | |||||
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node, | |||||
std::vector<ge::onnx::GraphProto *> &onnx_graphs, | std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override; | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||||
const std::string &parent_graph_name = "") override; | |||||
private: | private: | ||||
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||||
const std::string &parent_graph_name); | |||||
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | ||||
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const; | void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const; | ||||
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const; | void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const; | ||||
@@ -49,10 +49,12 @@ class PARSER_FUNC_VISIBILITY SubgraphAdapter { | |||||
/// @return FAILED Parse failed | /// @return FAILED Parse failed | ||||
virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | ||||
std::vector<ge::onnx::GraphProto *> &onnx_graphs, | std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||||
const std::string &parent_graph_name = "") { | |||||
(void)parent_op; | (void)parent_op; | ||||
(void)onnx_graphs; | (void)onnx_graphs; | ||||
(void)name_to_onnx_graph; | (void)name_to_onnx_graph; | ||||
(void)parent_graph_name; | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
}; | }; | ||||
@@ -26,7 +26,7 @@ SubgraphAdapterFactory* SubgraphAdapterFactory::Instance() { | |||||
std::shared_ptr<SubgraphAdapter> SubgraphAdapterFactory::CreateSubgraphAdapter( | std::shared_ptr<SubgraphAdapter> SubgraphAdapterFactory::CreateSubgraphAdapter( | ||||
const std::string &op_type) { | const std::string &op_type) { | ||||
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter. | // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter. | ||||
auto iter = subgraph_adapter_creator_map_.find(op_type); | |||||
std::map<std::string, CREATOR_FUN>::const_iterator iter = subgraph_adapter_creator_map_.find(op_type); | |||||
if (iter != subgraph_adapter_creator_map_.end()) { | if (iter != subgraph_adapter_creator_map_.end()) { | ||||
return iter->second(); | return iter->second(); | ||||
} | } | ||||
@@ -161,7 +161,6 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||||
if (node_def->input(i).find("^") != string::npos) { | if (node_def->input(i).find("^") != string::npos) { | ||||
// Control input | // Control input | ||||
const string normalized = node_names.Renormalize(node_def->input(i).substr(1)); | const string normalized = node_names.Renormalize(node_def->input(i).substr(1)); | ||||
GE_IF_BOOL_EXEC(normalized.empty(), | GE_IF_BOOL_EXEC(normalized.empty(), | ||||
REPORT_INNER_ERROR("E19999", "Could not remap control input %s of node %s in function %s", | REPORT_INNER_ERROR("E19999", "Could not remap control input %s of node %s in function %s", | ||||
node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); | node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); | ||||
@@ -172,7 +171,6 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||||
*node_def->mutable_input(i) = "^" + normalized; | *node_def->mutable_input(i) = "^" + normalized; | ||||
} else { | } else { | ||||
const auto iter = tensor_renaming.find(node_def->input(i)); | const auto iter = tensor_renaming.find(node_def->input(i)); | ||||
GE_IF_BOOL_EXEC(iter == tensor_renaming.end(), | GE_IF_BOOL_EXEC(iter == tensor_renaming.end(), | ||||
REPORT_INNER_ERROR("E19999", "Could not remap input %s of node %s in function %s", | REPORT_INNER_ERROR("E19999", "Could not remap input %s of node %s in function %s", | ||||
node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); | node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); | ||||
@@ -188,14 +186,12 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||||
// Remap return values. | // Remap return values. | ||||
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { | for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { | ||||
const string &ret_name = fdef->signature().output_arg(r).name(); | const string &ret_name = fdef->signature().output_arg(r).name(); | ||||
GE_IF_BOOL_EXEC(ret_name.empty(), | GE_IF_BOOL_EXEC(ret_name.empty(), | ||||
REPORT_INNER_ERROR("E19999", "Missing output %d to function %s", r, name.c_str()); | REPORT_INNER_ERROR("E19999", "Missing output %d to function %s", r, name.c_str()); | ||||
GELOGE(domi::INTERNAL_ERROR, "Missing output %d to function %s .", r, name.c_str()); | GELOGE(domi::INTERNAL_ERROR, "Missing output %d to function %s .", r, name.c_str()); | ||||
return domi::INTERNAL_ERROR); | return domi::INTERNAL_ERROR); | ||||
const string &return_value = return_values[ret_name]; | const string &return_value = return_values[ret_name]; | ||||
GE_IF_BOOL_EXEC(return_value.empty(), | GE_IF_BOOL_EXEC(return_value.empty(), | ||||
REPORT_INNER_ERROR("E19999", "Could not remap return value %d ,%s of %s in function %s", r, | REPORT_INNER_ERROR("E19999", "Could not remap return value %d ,%s of %s in function %s", r, | ||||
ret_name.c_str(), return_value.c_str(), name.c_str()); | ret_name.c_str(), return_value.c_str(), name.c_str()); | ||||
@@ -204,12 +200,11 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||||
return domi::INTERNAL_ERROR); | return domi::INTERNAL_ERROR); | ||||
const auto iter = tensor_renaming.find(return_value); | const auto iter = tensor_renaming.find(return_value); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == tensor_renaming.end(), | |||||
REPORT_INNER_ERROR("E19999", "can not find value[%s] in tensor_renaming map", | |||||
return_value.c_str()); | |||||
return domi::INTERNAL_ERROR, | |||||
"can not find value[%s] in tensor_renaming map.", return_value.c_str()); | |||||
if (iter == tensor_renaming.end()) { | |||||
REPORT_INNER_ERROR("E19999", "can not find value[%s] in tensor_renaming map", return_value.c_str()); | |||||
GELOGE(FAILED, "can not find value[%s] in tensor_renaming map.", return_value.c_str()); | |||||
return domi::INTERNAL_ERROR; | |||||
} | |||||
(*fdef->mutable_ret())[ret_name] = iter->second; | (*fdef->mutable_ret())[ret_name] = iter->second; | ||||
} | } | ||||
@@ -227,7 +222,7 @@ domi::Status GraphToFunctionDef::RecordResult(ge::ComputeGraphPtr graph, | |||||
GE_CHECK_NOTNULL(anchor); | GE_CHECK_NOTNULL(anchor); | ||||
GE_CHECK_NOTNULL(anchor->GetOwnerNode()->GetOpDesc()); | GE_CHECK_NOTNULL(anchor->GetOwnerNode()->GetOpDesc()); | ||||
int32_t type = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()).GetDataType(); | int32_t type = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()).GetDataType(); | ||||
auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||||
GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | ||||
REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", | REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", | ||||
type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), | type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), | ||||
@@ -304,7 +299,7 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect | |||||
GE_CHECK_NOTNULL_EXEC(tensor_desc_ptr, return domi::FAILED); | GE_CHECK_NOTNULL_EXEC(tensor_desc_ptr, return domi::FAILED); | ||||
int32_t type = tensor_desc_ptr->GetDataType(); | int32_t type = tensor_desc_ptr->GetDataType(); | ||||
auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||||
GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | ||||
REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", | REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", | ||||
type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), | type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), | ||||
@@ -325,8 +320,8 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
(void)ge::AttrUtils::SetInt(op, "T", (int32_t)dtype); | |||||
(void)ge::AttrUtils::SetInt(op, "arg_index", (int32_t)index); | |||||
(void)ge::AttrUtils::SetInt(op, "T", static_cast<int32_t>(dtype)); | |||||
(void)ge::AttrUtils::SetInt(op, "arg_index", static_cast<int32_t>(index)); | |||||
ge::NodePtr arg_node = graph->AddNode(op); | ge::NodePtr arg_node = graph->AddNode(op); | ||||
GE_CHECK_NOTNULL(arg_node); | GE_CHECK_NOTNULL(arg_node); | ||||
bool node_exists = false; | bool node_exists = false; | ||||
@@ -378,7 +373,6 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
if (node->GetOpDesc()->GetType() == ge::parser::DATA) { | if (node->GetOpDesc()->GetType() == ge::parser::DATA) { | ||||
int64_t index = 0; | int64_t index = 0; | ||||
int64_t type = 1; | int64_t type = 1; | ||||
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, | GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, | ||||
"Get type attr failed"); | "Get type attr failed"); | ||||
@@ -400,7 +394,6 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph | |||||
if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) { | if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) { | ||||
int64_t index = 0; | int64_t index = 0; | ||||
int64_t type = 1; | int64_t type = 1; | ||||
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, | GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, | ||||
"Get type attr failed"); | "Get type attr failed"); | ||||
@@ -589,7 +582,10 @@ bool GraphToFunctionDef::FindAttrValue(const domi::tensorflow::NodeDef *node_def | |||||
void GraphToFunctionDef::AddNodeAttr(const string &attr_name, const domi::tensorflow::AttrValue &value, | void GraphToFunctionDef::AddNodeAttr(const string &attr_name, const domi::tensorflow::AttrValue &value, | ||||
domi::tensorflow::NodeDef *node_def) { | domi::tensorflow::NodeDef *node_def) { | ||||
GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | |||||
if (node_def == nullptr) { | |||||
GELOGI("input parameter is null."); | |||||
return; | |||||
} | |||||
node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -52,7 +52,7 @@ class GraphToFunctionDef { | |||||
const string &name, FunctionDef *fdef); | const string &name, FunctionDef *fdef); | ||||
static domi::Status BuildFunctionDef(ge::ComputeGraphPtr &graph, | static domi::Status BuildFunctionDef(ge::ComputeGraphPtr &graph, | ||||
const string &nme_in, | |||||
const string &name_in, | |||||
FunctionDefLibrary *library, | FunctionDefLibrary *library, | ||||
NodeDef *call_node_def, | NodeDef *call_node_def, | ||||
vector<ge::InDataAnchorPtr> &in_anchor, | vector<ge::InDataAnchorPtr> &in_anchor, | ||||
@@ -15,7 +15,7 @@ | |||||
*/ | */ | ||||
#include "graph_optimizer.h" | #include "graph_optimizer.h" | ||||
#include "common/op_types.h" | |||||
#include "graph/op_types.h" | |||||
#include "common/types_map.h" | #include "common/types_map.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
@@ -93,6 +93,7 @@ Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const boo | |||||
GE_CHECK_NOTNULL(graph_); | GE_CHECK_NOTNULL(graph_); | ||||
for (auto node : graph_->GetDirectNode()) { | for (auto node : graph_->GetDirectNode()) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) | GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) | ||||
string type; | string type; | ||||
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | ||||
@@ -178,7 +179,7 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library) | |||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; | AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; | ||||
GE_CHECK_NOTNULL(funcDefBytes.GetData()); | GE_CHECK_NOTNULL(funcDefBytes.GetData()); | ||||
string str(reinterpret_cast<char *>(funcDefBytes.GetData()), funcDefBytes.GetSize()); | |||||
string str(PtrToPtr<uint8_t, char_t>(funcDefBytes.GetData()), funcDefBytes.GetSize()); | |||||
GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC( | GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC( | ||||
funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib))); | funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib))); | ||||
} | } | ||||
@@ -206,9 +207,11 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||||
std::unique_ptr<FunctionDefLibrary> func_def_lib(new (std::nothrow) FunctionDefLibrary()); | std::unique_ptr<FunctionDefLibrary> func_def_lib(new (std::nothrow) FunctionDefLibrary()); | ||||
GE_CHECK_NOTNULL(func_def_lib); | GE_CHECK_NOTNULL(func_def_lib); | ||||
// convert graph to FunctionDef | // convert graph to FunctionDef | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0, | |||||
REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0"); | |||||
return PARAM_INVALID, "node size must greater than 0 ."); | |||||
if (nodes.size() == 0) { | |||||
REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0"); | |||||
GELOGE(FAILED, "node size must greater than 0 ."); | |||||
return PARAM_INVALID; | |||||
} | |||||
GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed."); | GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed."); | ||||
GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, nodes[0]->GetName(), func_def_lib.get(), | GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, nodes[0]->GetName(), func_def_lib.get(), | ||||
node_def.get(), input_anchors, output_anchors), | node_def.get(), input_anchors, output_anchors), | ||||
@@ -226,7 +229,10 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||||
GELOGE(PARAM_INVALID, "Serialize func_def to string failed."); | GELOGE(PARAM_INVALID, "Serialize func_def to string failed."); | ||||
return PARAM_INVALID); | return PARAM_INVALID); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0, return PARAM_INVALID, "nodes is empty."); | |||||
if (nodes.size() == 0) { | |||||
GELOGE(FAILED, "nodes is empty."); | |||||
return PARAM_INVALID; | |||||
} | |||||
std::string fusion_op_name; | std::string fusion_op_name; | ||||
for (auto node : nodes) { | for (auto node : nodes) { | ||||
@@ -250,10 +256,10 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||||
(void)AttrUtils::SetZeroCopyBytes( | (void)AttrUtils::SetZeroCopyBytes( | ||||
fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, | fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, | ||||
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(funcdefStr.data()), funcdefStr.length())); | |||||
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(funcdefStr.data()), funcdefStr.length())); | |||||
(void)AttrUtils::SetZeroCopyBytes( | (void)AttrUtils::SetZeroCopyBytes( | ||||
fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | ||||
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | |||||
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(nodefStr.data()), nodefStr.length())); | |||||
(void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type); | (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type); | ||||
@@ -284,6 +290,7 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
OpDescPtr op_def = node->GetOpDesc(); | OpDescPtr op_def = node->GetOpDesc(); | ||||
NodePtr new_node = sub_graph->AddNode(op_def); | NodePtr new_node = sub_graph->AddNode(op_def); | ||||
GE_CHECK_NOTNULL(new_node); | |||||
node_map[node->GetName()] = new_node; | node_map[node->GetName()] = new_node; | ||||
// Input | // Input | ||||
@@ -381,7 +388,8 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> & | |||||
GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED); | GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED); | ||||
ge::DataType data_type = src_out_desc.GetDataType(); | ge::DataType data_type = src_out_desc.GetDataType(); | ||||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type); | |||||
const std::map<int32_t, int32_t>::const_iterator iter = | |||||
GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type)); | |||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | ||||
REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", | REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", | ||||
@@ -390,7 +398,7 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> & | |||||
return PARAM_INVALID); | return PARAM_INVALID); | ||||
int32_t dtype = iter->second; | int32_t dtype = iter->second; | ||||
output_list.push_back((int64_t)dtype); | |||||
output_list.push_back(static_cast<int64_t>(dtype)); | |||||
GELOGI("FUNCDEF: output_list push_back %d.", dtype); | GELOGI("FUNCDEF: output_list push_back %d.", dtype); | ||||
} | } | ||||
GE_IF_BOOL_EXEC(!output_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list)); | GE_IF_BOOL_EXEC(!output_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list)); | ||||
@@ -410,14 +418,15 @@ Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &in | |||||
auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx()); | auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx()); | ||||
GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED); | GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((fusion_op_desc->AddInputDesc(*tensorDescPtr)) != GRAPH_SUCCESS, | |||||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||||
fusion_op_desc->GetName().c_str(), | |||||
fusion_op_desc->GetType().c_str()); | |||||
return FAILED, | |||||
"Add fusion_op_desc AddInputDesc failed"); | |||||
if (fusion_op_desc->AddInputDesc(*tensorDescPtr) != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||||
fusion_op_desc->GetName().c_str(), fusion_op_desc->GetType().c_str()); | |||||
GELOGE(FAILED, "Add fusion_op_desc AddInputDesc failed"); | |||||
return FAILED; | |||||
} | |||||
ge::DataType data_type = tensorDescPtr->GetDataType(); | ge::DataType data_type = tensorDescPtr->GetDataType(); | ||||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type); | |||||
const std::map<int32_t, int32_t>::const_iterator iter = | |||||
GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type)); | |||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | ||||
REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", | REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", | ||||
@@ -426,7 +435,7 @@ Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &in | |||||
return PARAM_INVALID); | return PARAM_INVALID); | ||||
int32_t dtype = iter->second; | int32_t dtype = iter->second; | ||||
input_list.push_back((int64_t)dtype); | |||||
input_list.push_back(static_cast<int64_t>(dtype)); | |||||
GELOGI("FUNCDEF: input_list push_back %d.", dtype); | GELOGI("FUNCDEF: input_list push_back %d.", dtype); | ||||
} | } | ||||
GE_IF_BOOL_EXEC(!input_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list)); | GE_IF_BOOL_EXEC(!input_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list)); | ||||
@@ -440,6 +449,7 @@ Status ParserGraphOptimizer::RebuildFusionNode(vector<ge::InDataAnchorPtr> &inpu | |||||
vector<ge::InControlAnchorPtr> &input_control_anchors, | vector<ge::InControlAnchorPtr> &input_control_anchors, | ||||
vector<ge::OutControlAnchorPtr> &output_control_anchors, | vector<ge::OutControlAnchorPtr> &output_control_anchors, | ||||
ge::NodePtr fusion_node) { | ge::NodePtr fusion_node) { | ||||
GE_CHECK_NOTNULL(fusion_node); | |||||
int32_t src_index = 0; | int32_t src_index = 0; | ||||
for (auto out_anchor : output_anchors) { | for (auto out_anchor : output_anchors) { | ||||
@@ -32,7 +32,7 @@ const char *const kSerializeFormat = "serialize_format"; | |||||
Status ParseParams(const Message *op_src, ArgOpOperator *const op) { | Status ParseParams(const Message *op_src, ArgOpOperator *const op) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src); | |||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
domi::tensorflow::AttrValue output_attr_value; | domi::tensorflow::AttrValue output_attr_value; | ||||
if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { | if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { | ||||
@@ -44,7 +44,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||||
GELOGE(PARAM_INVALID, "Op src is null"); | GELOGE(PARAM_INVALID, "Op src is null"); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src); | |||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
if (op_dest == nullptr) { | if (op_dest == nullptr) { | ||||
REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); | ||||
@@ -109,8 +109,11 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
} | } | ||||
const auto out_desc = op_dest->MutableOutputDesc(0); | |||||
GE_CHECK_NOTNULL(out_desc); | |||||
out_desc->SetDataType(out_type); | |||||
std::shared_ptr<NodeDef> pkg_node = ge::parser::MakeShared<NodeDef>(); | |||||
std::shared_ptr<domi::tensorflow::NodeDef> pkg_node = ge::parser::MakeShared<domi::tensorflow::NodeDef>(); | |||||
GE_CHECK_NOTNULL(pkg_node); | GE_CHECK_NOTNULL(pkg_node); | ||||
pkg_node->CopyFrom(*node); | pkg_node->CopyFrom(*node); | ||||
@@ -130,7 +133,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||||
(void)AttrUtils::SetZeroCopyBytes( | (void)AttrUtils::SetZeroCopyBytes( | ||||
op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | ||||
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(serialized_node.data()), serialized_node.length())); | |||||
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(serialized_node.data()), serialized_node.length())); | |||||
GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str()); | GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str()); | ||||
} | } | ||||
@@ -27,7 +27,7 @@ using domi::ParseParamByOpFunc; | |||||
namespace ge { | namespace ge { | ||||
Status TensorFlowCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | Status TensorFlowCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const NodeDef *node_src = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||||
GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | ||||
GE_CHECK_NOTNULL(op_dest); | GE_CHECK_NOTNULL(op_dest); | ||||
@@ -93,7 +93,7 @@ Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, const ge | |||||
GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE), | GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE), | ||||
"check Attr %s failed", TENSORFLOW_ATTR_SHAPE.c_str()); | "check Attr %s failed", TENSORFLOW_ATTR_SHAPE.c_str()); | ||||
const TensorShapeProto &data_shape = attr_value.shape(); | |||||
const domi::tensorflow::TensorShapeProto &data_shape = attr_value.shape(); | |||||
for (auto i = 0; i < data_shape.dim_size(); i++) { | for (auto i = 0; i < data_shape.dim_size(); i++) { | ||||
model_input_dims_v.push_back(data_shape.dim(i).size()); | model_input_dims_v.push_back(data_shape.dim(i).size()); | ||||
} | } | ||||
@@ -110,7 +110,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge: | |||||
std::string name = op_def->GetName(); | std::string name = op_def->GetName(); | ||||
if (input_dims.count(name) == 0) { | if (input_dims.count(name) == 0) { | ||||
GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str()); | GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str()); | ||||
for (uint32_t i = 0; i < model_input_dims_v.size(); i++) { | |||||
for (size_t i = 0; i < model_input_dims_v.size(); ++i) { | |||||
user_input_dims_v.push_back(model_input_dims_v[i]); | user_input_dims_v.push_back(model_input_dims_v[i]); | ||||
} | } | ||||
@@ -138,7 +138,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge: | |||||
} | } | ||||
Status TensorFlowDataParser::CheckInputShape(const std::string &name) { | Status TensorFlowDataParser::CheckInputShape(const std::string &name) { | ||||
for (uint32_t i = 0; i < user_input_dims_v.size(); i++) { | |||||
for (size_t i = 0; i < user_input_dims_v.size(); ++i) { | |||||
// if input_shape has some placeholders, user should designate them. | // if input_shape has some placeholders, user should designate them. | ||||
// dim i = 0, means empty tensor. | // dim i = 0, means empty tensor. | ||||
// dim i = -1 or -2, means unknown shape. | // dim i = -1 or -2, means unknown shape. | ||||
@@ -31,7 +31,7 @@ Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
const std::string name = op_desc->GetName(); | const std::string name = op_desc->GetName(); | ||||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src); | |||||
domi::tensorflow::AttrValue str_attr; | domi::tensorflow::AttrValue str_attr; | ||||
if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { | if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { | ||||
REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", | REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", | ||||
@@ -38,7 +38,7 @@ node { | |||||
} | } | ||||
} | } | ||||
*/ | */ | ||||
domi::Status ParseParams(const NodeDef *node, FillOperator *op) { | |||||
domi::Status ParseParams(const domi::tensorflow::NodeDef *node, FillOperator *op) { | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
op->Name(node->name()); | op->Name(node->name()); | ||||
@@ -31,7 +31,7 @@ namespace ge { | |||||
Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src); | |||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
string type = node->op(); | string type = node->op(); | ||||
@@ -64,7 +64,7 @@ Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | |||||
GE_IF_BOOL_EXEC(((type == "_Retval") && (TensorFlowUtil::FindAttrValue(node, ATTR_NAME_INDEX, index_attr_value))), | GE_IF_BOOL_EXEC(((type == "_Retval") && (TensorFlowUtil::FindAttrValue(node, ATTR_NAME_INDEX, index_attr_value))), | ||||
op->Index(index_attr_value.i())); | op->Index(index_attr_value.i())); | ||||
NodeDef *pkg_node = new (std::nothrow) NodeDef(); | |||||
domi::tensorflow::NodeDef *pkg_node = new (std::nothrow) domi::tensorflow::NodeDef(); | |||||
GE_CHECK_NOTNULL(pkg_node); | GE_CHECK_NOTNULL(pkg_node); | ||||
pkg_node->CopyFrom(*node); | pkg_node->CopyFrom(*node); | ||||
@@ -44,7 +44,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowFusionOpParser : public TensorFlowOpParse | |||||
* @return SUCCESS Parsing success | * @return SUCCESS Parsing success | ||||
* @return FAILED Parsing failed | * @return FAILED Parsing failed | ||||
*/ | */ | ||||
virtual Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &node) const; | |||||
virtual Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &op_dest) const; | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -31,7 +31,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src); | |||||
domi::tensorflow::AttrValue attr_num; | domi::tensorflow::AttrValue attr_num; | ||||
if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { | if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { | ||||
GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); | GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); | ||||
@@ -26,7 +26,7 @@ using namespace ge::parser; | |||||
namespace ge { | namespace ge { | ||||
Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | ||||
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
NoOpOperator op; | NoOpOperator op; | ||||
@@ -42,25 +42,6 @@ | |||||
#include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
#include "proto/tensorflow/node_def.pb.h" | #include "proto/tensorflow/node_def.pb.h" | ||||
using domi::tensorflow::NodeDef; | |||||
using domi::tensorflow::TensorProto; | |||||
using google::protobuf::int32; | |||||
using google::protobuf::int64; | |||||
using google::protobuf::Message; | |||||
using std::string; | |||||
using std::vector; | |||||
using Status = domi::Status; | |||||
using domi::tensorflow::AttrValue; | |||||
using domi::tensorflow::DataType; | |||||
using domi::tensorflow::DT_BOOL; | |||||
using domi::tensorflow::DT_FLOAT; | |||||
using domi::tensorflow::DT_INT32; | |||||
using domi::tensorflow::DT_INT64; | |||||
using domi::tensorflow::DT_INVALID; | |||||
using domi::tensorflow::TensorShapeProto; | |||||
using domi::tensorflow::TensorShapeProto_Dim; | |||||
namespace ge { | namespace ge { | ||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -40,7 +40,6 @@ | |||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "parser/common/parser_fp16_t.h" | #include "parser/common/parser_fp16_t.h" | ||||
#include "parser/common/pass_manager.h" | #include "parser/common/pass_manager.h" | ||||
#include "parser/common/pre_checker.h" | |||||
#include "parser/common/prototype_pass_manager.h" | #include "parser/common/prototype_pass_manager.h" | ||||
#include "parser/common/thread_pool.h" | #include "parser/common/thread_pool.h" | ||||
#include "parser/common/parser_utils.h" | #include "parser/common/parser_utils.h" | ||||
@@ -54,6 +53,7 @@ | |||||
#include "register/register_utils.h" | #include "register/register_utils.h" | ||||
#include "register/scope/scope_pass_registry_impl.h" | #include "register/scope/scope_pass_registry_impl.h" | ||||
#include "parser/common/auto_mapping_subgraph_io_index_func.h" | #include "parser/common/auto_mapping_subgraph_io_index_func.h" | ||||
#include "graph/def_types.h" | |||||
using ge::OpParserFactory; | using ge::OpParserFactory; | ||||
using ge::Pb2Json; | using ge::Pb2Json; | ||||
@@ -507,8 +507,11 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, | |||||
ge::NodePtr node = nullptr; | ge::NodePtr node = nullptr; | ||||
node = graph->AddNode(op_desc); | node = graph->AddNode(op_desc); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed."); | |||||
if (node == nullptr) { | |||||
DeleteFuisonNodeDef(); | |||||
GELOGE(FAILED, "add node failed."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
node_map_[node_name] = node; | node_map_[node_name] = node; | ||||
@@ -545,7 +548,11 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, | |||||
// checkout op input number with IR | // checkout op input number with IR | ||||
GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def)); | GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def)); | ||||
ge::NodePtr node = graph->AddNode(op); | ge::NodePtr node = graph->AddNode(op); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed."); | |||||
if (node == nullptr) { | |||||
DeleteFuisonNodeDef(); | |||||
GELOGE(FAILED, "add node failed."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
node_map_[node_name] = node; | node_map_[node_name] = node; | ||||
@@ -794,22 +801,24 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||||
GE_CHECK_NOTNULL(out_archor_ptr); | GE_CHECK_NOTNULL(out_archor_ptr); | ||||
ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second); | ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second); | ||||
GE_CHECK_NOTNULL(in_archor_ptr); | GE_CHECK_NOTNULL(in_archor_ptr); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||||
src->GetName().c_str(), dest->GetName().c_str()); | |||||
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", | |||||
src->GetName().c_str(), dest->GetName().c_str()); | |||||
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||||
src->GetName().c_str(), dest->GetName().c_str()); | |||||
GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} else { | } else { | ||||
GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); | GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); | ||||
ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); | ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); | ||||
GE_CHECK_NOTNULL(in_archor_ptr); | GE_CHECK_NOTNULL(in_archor_ptr); | ||||
ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | ||||
GE_CHECK_NOTNULL(out_archor_ptr); | GE_CHECK_NOTNULL(out_archor_ptr); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||||
src->GetName().c_str(), dest->GetName().c_str()); | |||||
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", | |||||
src->GetName().c_str(), dest->GetName().c_str()); | |||||
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||||
src->GetName().c_str(), dest->GetName().c_str()); | |||||
GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} | } | ||||
} | } | ||||
dest_input_map.erase(input_iter); | dest_input_map.erase(input_iter); | ||||
@@ -921,10 +930,11 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co | |||||
} | } | ||||
std::map<std::string, std::string>::const_iterator iterator = parser->adaptedOpTypeMap_.find(node_name); | std::map<std::string, std::string>::const_iterator iterator = parser->adaptedOpTypeMap_.find(node_name); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
iterator == parser->adaptedOpTypeMap_.end(), | |||||
REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str()); | |||||
return FAILED, "get adapted op type failed, node name = %s", node_name.c_str()); | |||||
if (iterator == parser->adaptedOpTypeMap_.cend()) { | |||||
REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str()); | |||||
GELOGE(FAILED, "get adapted op type failed, node name = %s", node_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
string op_type = iterator->second; | string op_type = iterator->second; | ||||
// Log printing for determining operator type | // Log printing for determining operator type | ||||
@@ -1017,10 +1027,12 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co | |||||
node = graph->AddNode(op); | node = graph->AddNode(op); | ||||
} | } | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(node == nullptr), REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(), | |||||
op->GetType().c_str(), graph->GetName().c_str()); | |||||
return INTERNAL_ERROR, "add node failed."); | |||||
if (node == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(), | |||||
op->GetType().c_str(), graph->GetName().c_str()); | |||||
GELOGE(FAILED, "add node failed."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
if (needFusion) { | if (needFusion) { | ||||
shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type); | shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type); | ||||
@@ -1116,10 +1128,11 @@ Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &g | |||||
for (size_t j = 0; j < op_node_list_size; j++) { | for (size_t j = 0; j < op_node_list_size; j++) { | ||||
const string op_node_name = op_node_name_list[j]; | const string op_node_name = op_node_name_list[j]; | ||||
auto iterator = node_map_.find(op_node_name); | auto iterator = node_map_.find(op_node_name); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(iterator == node_map_.end()), | |||||
REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str()); | |||||
return INTERNAL_ERROR, "add node failed."); | |||||
if (iterator == node_map_.end()) { | |||||
REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str()); | |||||
GELOGE(FAILED, "add node failed."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
GE_CHECK_NOTNULL(iterator->second); | GE_CHECK_NOTNULL(iterator->second); | ||||
GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed"); | GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed"); | ||||
graph->AddNode(iterator->second); | graph->AddNode(iterator->second); | ||||
@@ -1178,15 +1191,22 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||||
domi::tensorflow::GraphDef OriDef; | domi::tensorflow::GraphDef OriDef; | ||||
bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef); | bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, REPORT_INNER_ERROR("E19999", "read graph proto from binary failed"); | |||||
return INTERNAL_ERROR, "read_proto_from_binary failed."); | |||||
if (!read) { | |||||
REPORT_INNER_ERROR("E19999", "read graph proto from binary failed"); | |||||
GELOGE(FAILED, "read_proto_from_binary failed."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
domi::tensorflow::GraphDef graph_def; | domi::tensorflow::GraphDef graph_def; | ||||
if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { | |||||
const bool is_empty_input = GetParserContext().input_dims.empty() && GetParserContext().out_nodes_map.empty(); | |||||
if (is_empty_input) { | |||||
graph_def = OriDef; | graph_def = OriDef; | ||||
} else { | } else { | ||||
GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size()); | GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(OriDef, &graph_def), return INTERNAL_ERROR, "Trim Graph fail."); | |||||
if (static_cast<bool>(TrimGraph(OriDef, &graph_def))) { | |||||
GELOGE(FAILED, "Trim Graph fail."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); | GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); | ||||
} | } | ||||
@@ -1236,9 +1256,6 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||||
// This function call affects the return value of prechecker::instance().Haserror() | // This function call affects the return value of prechecker::instance().Haserror() | ||||
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); | GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); | ||||
// Check the input validity of the node, the input attribute must have a corresponding node | |||||
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def)); | |||||
// Building input and input relationships for all OP nodes | // Building input and input relationships for all OP nodes | ||||
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); | GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); | ||||
GELOGD("[TF ParseFromMemory] get op nodes context from graph success"); | GELOGD("[TF ParseFromMemory] get op nodes context from graph success"); | ||||
@@ -1269,9 +1286,12 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
const string &node_op = node_def->op(); | const string &node_op = node_def->op(); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((tensorflow_op_map.find(node_op) == tensorflow_op_map.end()), DeleteFuisonNodeDef(); | |||||
REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str()); | |||||
return INTERNAL_ERROR, "Unsupport op type %s", node_op.c_str()); | |||||
if (tensorflow_op_map.find(node_op) == tensorflow_op_map.cend()) { | |||||
DeleteFuisonNodeDef(); | |||||
REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str()); | |||||
GELOGE(FAILED, "Unsupport op type %s", node_op.c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
ret = AddNode(node_def, graph, scope_graph); | ret = AddNode(node_def, graph, scope_graph); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -1339,7 +1359,10 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr | |||||
// Store objects parsed from pb files | // Store objects parsed from pb files | ||||
domi::tensorflow::GraphDef ori_def; | domi::tensorflow::GraphDef ori_def; | ||||
bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def); | bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, return INTERNAL_ERROR, "read_proto_from_binary failed."); | |||||
if (!read) { | |||||
GELOGE(FAILED, "read tensorflow file failed when the inupt param value of --framework is 3."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
// Trim graph by user input and output. | // Trim graph by user input and output. | ||||
domi::tensorflow::GraphDef graph_def; | domi::tensorflow::GraphDef graph_def; | ||||
@@ -1347,7 +1370,10 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr | |||||
graph_def = ori_def; | graph_def = ori_def; | ||||
} else { | } else { | ||||
GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size()); | GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(ori_def, &graph_def), return INTERNAL_ERROR, "Trim Graph fail."); | |||||
if (static_cast<bool>(TrimGraph(ori_def, &graph_def))) { | |||||
GELOGE(FAILED, "Trim Graph fail."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size()); | GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size()); | ||||
} | } | ||||
@@ -1375,7 +1401,7 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr | |||||
} | } | ||||
} | } | ||||
std::map<std::string, domi::tensorflow::GraphDef>::const_iterator | |||||
const std::map<std::string, domi::tensorflow::GraphDef>::const_iterator | |||||
iter = function_name_to_graphdef.find(arg.function_name); | iter = function_name_to_graphdef.find(arg.function_name); | ||||
if (iter == function_name_to_graphdef.end()) { | if (iter == function_name_to_graphdef.end()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name}); | ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name}); | ||||
@@ -1415,7 +1441,8 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||||
GE_CHECK_NOTNULL(proto); | GE_CHECK_NOTNULL(proto); | ||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
const domi::tensorflow::GraphDef *ori_graph = reinterpret_cast<const domi::tensorflow::GraphDef *>(proto); | |||||
const domi::tensorflow::GraphDef *ori_graph = | |||||
ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto); | |||||
// Make a copy for operation without modifying the original graph def. | // Make a copy for operation without modifying the original graph def. | ||||
domi::tensorflow::GraphDef graph_def = *ori_graph; | domi::tensorflow::GraphDef graph_def = *ori_graph; | ||||
@@ -1441,12 +1468,16 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||||
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()), | GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()), | ||||
"Add node_def to PreChecker failed, node name: %s.", node.name().c_str()); | "Add node_def to PreChecker failed, node name: %s.", node.name().c_str()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&node) != SUCCESS, return FAILED, | |||||
"Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str()); | |||||
GE_CHK_BOOL_EXEC_NOLOG( | |||||
node.op() == TENSORFLOWF_NODE_OP_IDENTITY, | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&node, true) != SUCCESS, return FAILED, | |||||
"Check op[%s]'s optype failed, type is not supported.", node.name().c_str());) | |||||
if (PreChecker::Instance().CheckName(&node) != SUCCESS) { | |||||
GELOGE(FAILED, "Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) { | |||||
if (PreChecker::Instance().CheckType(&node, true) != SUCCESS) { | |||||
GELOGE(FAILED, "Check op[%s]'s optype failed, type is not supported.", node.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | } | ||||
bool has_error = false; | bool has_error = false; | ||||
@@ -1471,10 +1502,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||||
// This function call affects the return value of prechecker::instance().Haserror() | // This function call affects the return value of prechecker::instance().Haserror() | ||||
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); | GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); | ||||
// Check the input validity of the node, the input attribute must have a corresponding node | |||||
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def)); | |||||
GELOGD("[TF Parse] check graph success"); | |||||
// Building input and input relationships for all OP nodes | // Building input and input relationships for all OP nodes | ||||
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); | GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); | ||||
GELOGD("[TF Parse] get op nodes context from graph success"); | GELOGD("[TF Parse] get op nodes context from graph success"); | ||||
@@ -1547,37 +1574,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const { | |||||
// Number of data nodes | |||||
uint32_t data_node_count = 0; | |||||
for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) { | |||||
// Check that all input is valid | |||||
for (const string &node_name : node_def.input()) { | |||||
string tmp_node_name; | |||||
GE_RETURN_IF_ERROR(CheckInputNodeName(node_name, &tmp_node_name, nullptr, nullptr)); | |||||
if (nodedef_map_.find(tmp_node_name) == nodedef_map_.end()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E12009", {"opname", "inputopname"}, | |||||
{node_def.name(), node_name}); | |||||
GELOGE(INTERNAL_ERROR, "Op[%s]'s input op[%s] is not exist in the graph_def.", node_def.name().c_str(), | |||||
node_name.c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
} | |||||
if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::parser::ARG) { | |||||
data_node_count++; | |||||
} | |||||
} | |||||
if (data_node_count == 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E12010"); | |||||
GELOGE(INTERNAL_ERROR, "Model has no Placeholder node."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) { | Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) { | ||||
// Build the input relationship first | // Build the input relationship first | ||||
for (auto &iter : op_node_context_map_) { | for (auto &iter : op_node_context_map_) { | ||||
@@ -1868,7 +1864,7 @@ Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> | |||||
ge::ScopeFusionOpInfo info; | ge::ScopeFusionOpInfo info; | ||||
if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) { | if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) { | ||||
// This node is a fusion operator | // This node is a fusion operator | ||||
std::map<std::string, OpNodeContext>::const_iterator | |||||
const std::map<std::string, OpNodeContext>::const_iterator | |||||
fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name); | fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name); | ||||
if (fusion_iter == tmp_fusion_op_node_context_map.end()) { | if (fusion_iter == tmp_fusion_op_node_context_map.end()) { | ||||
OpNodeContext op_node_context; | OpNodeContext op_node_context; | ||||
@@ -2108,10 +2104,10 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap( | |||||
std::set<std::string> compare_set; | std::set<std::string> compare_set; | ||||
for (auto &pair : pairs) { | for (auto &pair : pairs) { | ||||
bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.end()) || | |||||
(fusion_op_children_.find(iter->first) != fusion_op_children_.end()); | |||||
bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.end()) || | |||||
(fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.end()); | |||||
bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.cend()) || | |||||
(fusion_op_children_.find(iter->first) != fusion_op_children_.cend()); | |||||
bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.cend()) || | |||||
(fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.cend()); | |||||
if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) && | if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) && | ||||
(is_fusion_child || is_fusion_op)) { | (is_fusion_child || is_fusion_op)) { | ||||
// The edge will be cut off at the back, ignoring | // The edge will be cut off at the back, ignoring | ||||
@@ -2119,7 +2115,7 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap( | |||||
} | } | ||||
string name = to_string(pair.first) + ":" + to_string(pair.second); | string name = to_string(pair.first) + ":" + to_string(pair.second); | ||||
std::set<std::string>::const_iterator compare_iter = compare_set.find(name); | |||||
const std::set<std::string>::const_iterator compare_iter = compare_set.find(name); | |||||
if (compare_iter != compare_set.end()) { | if (compare_iter != compare_set.end()) { | ||||
// pair<from,to> repeat, ignore | // pair<from,to> repeat, ignore | ||||
continue; | continue; | ||||
@@ -2158,7 +2154,7 @@ void TensorFlowModelParser::SaveEdgesControlInfo(const string &node_name, const | |||||
} | } | ||||
void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) { | void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) { | ||||
std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name); | |||||
const std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name); | |||||
if (iter != edges_control_map.end()) { | if (iter != edges_control_map.end()) { | ||||
// Delete the original fusion operator node information and add the fusion operator control edge information | // Delete the original fusion operator node information and add the fusion operator control edge information | ||||
edges_control_map.erase(iter); | edges_control_map.erase(iter); | ||||
@@ -2228,7 +2224,8 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
ge::GetParserContext().train_flag = true; | ge::GetParserContext().train_flag = true; | ||||
const domi::tensorflow::GraphDef *graph_def_in = reinterpret_cast<const domi::tensorflow::GraphDef *>(proto); | |||||
const domi::tensorflow::GraphDef *graph_def_in = | |||||
ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto); | |||||
// Make a copy for operation without modifying the original graph def. | // Make a copy for operation without modifying the original graph def. | ||||
domi::tensorflow::GraphDef graph_def_operation = *graph_def_in; | domi::tensorflow::GraphDef graph_def_operation = *graph_def_in; | ||||
domi::tensorflow::GraphDef *graph_def = &graph_def_operation; | domi::tensorflow::GraphDef *graph_def = &graph_def_operation; | ||||
@@ -2415,7 +2412,7 @@ Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge | |||||
GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed"); | GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph); | |||||
return ParseProto(ge::PtrToPtr<domi::tensorflow::GraphDef, const google::protobuf::Message>(&graph_def), graph); | |||||
} | } | ||||
Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | ||||
@@ -2472,95 +2469,18 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// For the identity operator whose output is "_retval", optimize it. | |||||
Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | |||||
const string &curr_node_name, bool &clear_input_flag) { | |||||
auto context_iter = op_node_context_map_.find(curr_node_name); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(context_iter == op_node_context_map_.end()), | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||||
return INTERNAL_ERROR, "Can't find op node context."); | |||||
OpNodeContext op_node_context = context_iter->second; | |||||
std::map<std::string, NodeDef *>::const_iterator node_def_iter = nodedef_map.find(curr_node_name); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(node_def_iter == nodedef_map.end()), | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map, check invalid", curr_node_name.c_str()); | |||||
return INTERNAL_ERROR, "Can't find nodedef"); | |||||
domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second; | |||||
GE_CHECK_NOTNULL(curr_node_def); | |||||
bool has_out_retval = false; | |||||
// For the identity operator whose output is "_retval", optimize it | |||||
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | |||||
for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) { | |||||
const string &output_node_name = output_iter->first; | |||||
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | |||||
GE_CHECK_NOTNULL(output_node_def); | |||||
if (output_node_def->op() == "_Retval") { | |||||
GELOGD("_Retval Identity need optimize."); | |||||
output_node_def->set_input(0, curr_node_def->input(0).c_str()); | |||||
has_out_retval = true; | |||||
GELOGD("op %s set input(0):%s.", output_node_def->name().c_str(), curr_node_def->input(0).c_str()); | |||||
} | |||||
} | |||||
// Deal with non _Retval output operator of Identity. | |||||
if (has_out_retval) { | |||||
for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) { | |||||
const string &output_node_name = output_iter->first; | |||||
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | |||||
GE_CHECK_NOTNULL(output_node_def); | |||||
GE_IF_BOOL_EXEC(output_node_def->op() == "_Retval", continue); | |||||
for (int k = 0; k < output_node_def->input_size(); ++k) { | |||||
GE_IF_BOOL_EXEC( | |||||
output_node_def->input(k) == curr_node_name, output_node_def->set_input(k, curr_node_def->input(0).c_str()); | |||||
GELOGD("%s op set input(%d):%s.", output_node_def->name().c_str(), k, curr_node_def->input(0).c_str());) | |||||
} | |||||
} | |||||
clear_input_flag = true; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status TensorFlowModelParser::GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, | |||||
map<string, NodeDef *> &nodedef_map, | |||||
const vector<NodeDef *> &nodedef_to_optimize) { | |||||
GE_CHECK_NOTNULL(graph_def); | |||||
if (!nodedef_to_optimize.empty()) { | |||||
// Building input and input relationships for all OP nodes | |||||
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); | |||||
} else { | |||||
return SUCCESS; | |||||
} | |||||
for (auto &curr_node_def : nodedef_to_optimize) { | |||||
GE_CHECK_NOTNULL(curr_node_def); | |||||
bool clear_input_flag = false; | |||||
const string &curr_node_name = curr_node_def->name(); | |||||
GE_RETURN_IF_ERROR(OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag)); | |||||
if (clear_input_flag) { | |||||
curr_node_def->clear_input(); | |||||
} | |||||
} | |||||
GELOGI("GraphDefOptimizeIdentity success."); | |||||
return SUCCESS; | |||||
} | |||||
Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, | Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, | ||||
map<string, NodeDef *> &nodedef_map, | map<string, NodeDef *> &nodedef_map, | ||||
const std::pair<string, int> &input_data, | const std::pair<string, int> &input_data, | ||||
const std::vector<string> &control_list) { | const std::vector<string> &control_list) { | ||||
GE_CHECK_NOTNULL(curr_mode_def); | GE_CHECK_NOTNULL(curr_mode_def); | ||||
if (curr_mode_def == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "Param curr_mode_def is nullptr, check invalid"); | |||||
GELOGE(FAILED, "input param is nullptr."); | |||||
return PARAM_INVALID; | |||||
} | |||||
string curr_node_name = curr_mode_def->name(); | string curr_node_name = curr_mode_def->name(); | ||||
auto context_iter = op_node_context_map_.find(curr_node_name); | auto context_iter = op_node_context_map_.find(curr_node_name); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
(context_iter == op_node_context_map_.end()), | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||||
return INTERNAL_ERROR, "Can't find op node context."); | |||||
if (context_iter == op_node_context_map_.end()) { | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||||
GELOGE(FAILED, "Can't find op node context."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
OpNodeContext op_node_context = context_iter->second; | OpNodeContext op_node_context = context_iter->second; | ||||
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | ||||
@@ -2773,7 +2693,7 @@ struct DelTransposeInfo { | |||||
int inputIdx; | int inputIdx; | ||||
}; | }; | ||||
Status GetTransposeInfo(GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, | |||||
Status GetTransposeInfo(domi::tensorflow::GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, | |||||
std::map<std::string, DelTransposeInfo> &transposeInfo) { | std::map<std::string, DelTransposeInfo> &transposeInfo) { | ||||
GE_CHECK_NOTNULL(graph_def); | GE_CHECK_NOTNULL(graph_def); | ||||
for (int i = 0; i < graph_def->node_size(); ++i) { | for (int i = 0; i < graph_def->node_size(); ++i) { | ||||
@@ -2826,7 +2746,7 @@ Status EraseTransposeNode(std::map<std::string, std::string> &softmaxInfo, | |||||
itTranspose->second.node_def->input(0).c_str()); | itTranspose->second.node_def->input(0).c_str()); | ||||
itTranspose = transposeInfo.erase(itTranspose); | itTranspose = transposeInfo.erase(itTranspose); | ||||
} else { | } else { | ||||
itTranspose++; | |||||
++itTranspose; | |||||
} | } | ||||
} | } | ||||
@@ -2847,7 +2767,7 @@ void TensorFlowModelParser::OptimizeTranspose(std::map<std::string, DelTranspose | |||||
} | } | ||||
} | } | ||||
void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *const graph_def) { | |||||
void TensorFlowModelParser::SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def) { | |||||
// The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
for (int i = 0; i < graph_def->node_size(); ++i) { | for (int i = 0; i < graph_def->node_size(); ++i) { | ||||
auto node_def = graph_def->mutable_node(i); | auto node_def = graph_def->mutable_node(i); | ||||
@@ -2864,8 +2784,6 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph | |||||
GE_CHECK_NOTNULL(graph_def); | GE_CHECK_NOTNULL(graph_def); | ||||
map<string, NodeDef *> nodedef_map; | map<string, NodeDef *> nodedef_map; | ||||
vector<string> op_node_name_list; | vector<string> op_node_name_list; | ||||
// Save Identity and ReadVariableOp | |||||
vector<NodeDef *> identity_to_optimize; | |||||
// Save Snapshot | // Save Snapshot | ||||
vector<NodeDef *> snapshot_to_optimize; | vector<NodeDef *> snapshot_to_optimize; | ||||
@@ -2875,16 +2793,12 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph | |||||
const string &node_name = node_def->name(); | const string &node_name = node_def->name(); | ||||
Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list); | Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list); | ||||
GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); | GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); | ||||
if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) { | |||||
identity_to_optimize.push_back(node_def); | |||||
} else if (node_def->op() == ge::parser::SNAPSHOT) { | |||||
if (node_def->op() == ge::parser::SNAPSHOT) { | |||||
snapshot_to_optimize.push_back(node_def); | snapshot_to_optimize.push_back(node_def); | ||||
} | } | ||||
nodedef_map[node_name] = node_def; | nodedef_map[node_name] = node_def; | ||||
} | } | ||||
// Optimize for Identity/ReadVariableOp | |||||
GE_RETURN_IF_ERROR(GraphDefOptimizeIdentity(graph_def, nodedef_map, identity_to_optimize)); | |||||
// Optimize for Snapshot | // Optimize for Snapshot | ||||
GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize)); | GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize)); | ||||
@@ -3055,7 +2969,7 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, | |||||
GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS, | GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS, | ||||
return FAILED); | return FAILED); | ||||
const TensorProto &tensor = attr_value.tensor(); | const TensorProto &tensor = attr_value.tensor(); | ||||
const TensorShapeProto &tensor_shape = tensor.tensor_shape(); | |||||
const domi::tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape(); | |||||
GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE, | GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE, | ||||
return SUCCESS); | return SUCCESS); | ||||
GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS); | GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS); | ||||
@@ -3110,10 +3024,10 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||||
std::set<string> next_inputs; | std::set<string> next_inputs; | ||||
for (const string ¤t_input : current_inputs) { | for (const string ¤t_input : current_inputs) { | ||||
delete_nodes.insert(current_input); | delete_nodes.insert(current_input); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
{"input_shape", current_input}); | |||||
return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | |||||
GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
{"input_shape", current_input}); | |||||
return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | |||||
const NodeDef *current_node = node_lookup[current_input]; | const NodeDef *current_node = node_lookup[current_input]; | ||||
GE_CHECK_NOTNULL(current_node); | GE_CHECK_NOTNULL(current_node); | ||||
for (const string &input_name : current_node->input()) { | for (const string &input_name : current_node->input()) { | ||||
@@ -3128,7 +3042,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||||
domi::tensorflow::GraphDef filtered_graph_def; | domi::tensorflow::GraphDef filtered_graph_def; | ||||
filtered_graph_def.mutable_node()->Clear(); | filtered_graph_def.mutable_node()->Clear(); | ||||
for (const NodeDef &node : input_graph_def.node()) { | for (const NodeDef &node : input_graph_def.node()) { | ||||
if (input_nodes.count(node.name())) { | |||||
if (static_cast<bool>(input_nodes.count(node.name()))) { | |||||
*(filtered_graph_def.mutable_node()->Add()) = node; | *(filtered_graph_def.mutable_node()->Add()) = node; | ||||
} | } | ||||
if (!delete_nodes.count(node.name())) { | if (!delete_nodes.count(node.name())) { | ||||
@@ -3137,12 +3051,12 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||||
} | } | ||||
output_graph_def->Clear(); | output_graph_def->Clear(); | ||||
for (const NodeDef &node : filtered_graph_def.node()) { | for (const NodeDef &node : filtered_graph_def.node()) { | ||||
if (input_nodes.count(node.name())) { | |||||
if (static_cast<bool>(input_nodes.count(node.name()))) { | |||||
NodeDef placeholder_node = node; | NodeDef placeholder_node = node; | ||||
placeholder_node.clear_input(); | placeholder_node.clear_input(); | ||||
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); | GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); | ||||
domi::tensorflow::AttrValue attr_value; | domi::tensorflow::AttrValue attr_value; | ||||
TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||||
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||||
GE_CHECK_NOTNULL(data_shape); | GE_CHECK_NOTNULL(data_shape); | ||||
const ge::ParserContext &ctx = ge::GetParserContext(); | const ge::ParserContext &ctx = ge::GetParserContext(); | ||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
@@ -3185,11 +3099,11 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef | |||||
std::set<string> next_inputs; | std::set<string> next_inputs; | ||||
for (const string ¤t_input : current_inputs) { | for (const string ¤t_input : current_inputs) { | ||||
required_nodes.insert(current_input); | required_nodes.insert(current_input); | ||||
GE_IF_BOOL_EXEC(input_nodes.count(current_input), continue); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
{"out_nodes", current_input}); | |||||
return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | |||||
GE_IF_BOOL_EXEC(static_cast<bool>(input_nodes.count(current_input)), continue); | |||||
GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
{"out_nodes", current_input}); | |||||
return FAILED, "op[%s] not found in graph.", current_input.c_str()); | |||||
const NodeDef *current_node = node_lookup[current_input]; | const NodeDef *current_node = node_lookup[current_input]; | ||||
GE_CHECK_NOTNULL(current_node); | GE_CHECK_NOTNULL(current_node); | ||||
for (const string &input_name : current_node->input()) { | for (const string &input_name : current_node->input()) { | ||||
@@ -3204,18 +3118,18 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef | |||||
domi::tensorflow::GraphDef filtered_graph_def; | domi::tensorflow::GraphDef filtered_graph_def; | ||||
filtered_graph_def.mutable_node()->Clear(); | filtered_graph_def.mutable_node()->Clear(); | ||||
for (const NodeDef &node : input_graph_def.node()) { | for (const NodeDef &node : input_graph_def.node()) { | ||||
if (required_nodes.count(node.name())) { | |||||
if (static_cast<bool>(required_nodes.count(node.name()))) { | |||||
*(filtered_graph_def.mutable_node()->Add()) = node; | *(filtered_graph_def.mutable_node()->Add()) = node; | ||||
} | } | ||||
} | } | ||||
output_graph_def->Clear(); | output_graph_def->Clear(); | ||||
for (const NodeDef &node : filtered_graph_def.node()) { | for (const NodeDef &node : filtered_graph_def.node()) { | ||||
if (input_nodes.count(node.name())) { | |||||
if (static_cast<bool>(input_nodes.count(node.name()))) { | |||||
NodeDef placeholder_node = node; | NodeDef placeholder_node = node; | ||||
placeholder_node.clear_input(); | placeholder_node.clear_input(); | ||||
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); | GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); | ||||
domi::tensorflow::AttrValue attr_value; | domi::tensorflow::AttrValue attr_value; | ||||
TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||||
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||||
GE_CHECK_NOTNULL(data_shape); | GE_CHECK_NOTNULL(data_shape); | ||||
const ge::ParserContext &ctx = ge::GetParserContext(); | const ge::ParserContext &ctx = ge::GetParserContext(); | ||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
@@ -3265,11 +3179,12 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par | |||||
// Find all children of the fusion operator | // Find all children of the fusion operator | ||||
auto iter = fusion_op_nodedef_map_.find(node_def->name()); | auto iter = fusion_op_nodedef_map_.find(node_def->name()); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
iter == fusion_op_nodedef_map_.end(), | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid", | |||||
node_def->name().c_str()); | |||||
return INTERNAL_ERROR, "FusionOp node %s has no children node!", node_def->name().c_str()); | |||||
if (iter == fusion_op_nodedef_map_.end()) { | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid", | |||||
node_def->name().c_str()); | |||||
GELOGE(FAILED, "FusionOp node %s has no children node!", node_def->name().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
(void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op()); | (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op()); | ||||
vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second; | vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second; | ||||
@@ -3325,7 +3240,7 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par | |||||
* @return false optimize failed | * @return false optimize failed | ||||
* | * | ||||
*/ | */ | ||||
Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) { | |||||
Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const { | |||||
GE_CHECK_NOTNULL(graph_def); | GE_CHECK_NOTNULL(graph_def); | ||||
// 1. find all the nodes in the graph and save them to all_nodedef_map | // 1. find all the nodes in the graph and save them to all_nodedef_map | ||||
map<string, NodeDef *> all_nodedef_map; | map<string, NodeDef *> all_nodedef_map; | ||||
@@ -3336,7 +3251,7 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||||
string node_name = current_node->name(); | string node_name = current_node->name(); | ||||
all_nodedef_map[node_name] = current_node; | all_nodedef_map[node_name] = current_node; | ||||
} | } | ||||
GE_CHK_BOOL_EXEC_INFO(!all_nodedef_map.empty(), return SUCCESS, "all_nodedef_map is empty"); | |||||
GELOGD("node size is: %zu", all_nodedef_map.size()); | |||||
// 2. move input to attr. | // 2. move input to attr. | ||||
for (auto &it_node_map : all_nodedef_map) { | for (auto &it_node_map : all_nodedef_map) { | ||||
@@ -3347,14 +3262,14 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||||
// 2.1. check whether the current op is register for move to attr. | // 2.1. check whether the current op is register for move to attr. | ||||
const std::vector<domi::RemoveInputConfigure> &move_input_vec = | const std::vector<domi::RemoveInputConfigure> &move_input_vec = | ||||
domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name); | domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name); | ||||
GE_CHK_BOOL_EXEC_NOLOG(!move_input_vec.empty(), continue); | |||||
GELOGD("Current op %s is registered for remove input.", current_op_name.c_str()); | |||||
// 2.2 check whether the current op is a TVM op. | // 2.2 check whether the current op is a TVM op. | ||||
GE_CHK_BOOL_EXEC_INFO( | |||||
domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) == domi::ImplyType::TVM, continue, | |||||
"op %s is not TVM op", current_op_name.c_str()); | |||||
GELOGD("handle tvm op %s", current_op_name.c_str()); | |||||
const bool is_unknown_custom_op = move_input_vec.empty() || | |||||
(domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) != domi::ImplyType::TVM); | |||||
if (is_unknown_custom_op) { | |||||
GELOGI("op %s is not TVM op, move input size: %zu", current_op_name.c_str(), move_input_vec.size()); | |||||
continue; | |||||
} | |||||
GELOGD("Current op %s is registered for remove input and tvm op", current_op_name.c_str()); | |||||
// 2.3 copy input to attr | // 2.3 copy input to attr | ||||
set<uint32_t> unused_inputs; | set<uint32_t> unused_inputs; | ||||
@@ -3401,7 +3316,8 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||||
} | } | ||||
for (size_t i = 0; i < it.input_order.size(); ++i) { | for (size_t i = 0; i < it.input_order.size(); ++i) { | ||||
int new_index = it.input_order[i]; | int new_index = it.input_order[i]; | ||||
if (new_index < 0 || new_index >= inputs.size()) { | |||||
const bool is_input_invalid = (new_index < 0) || (new_index >= inputs.size()); | |||||
if (is_input_invalid) { | |||||
REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, out of range(0, %d)", | REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, out of range(0, %d)", | ||||
it_node_map.first.c_str(), new_index, inputs.size()); | it_node_map.first.c_str(), new_index, inputs.size()); | ||||
GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index); | GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index); | ||||
@@ -3443,7 +3359,10 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow:: | |||||
if (input_node_def->op() == parser::SWITCH || input_node_def->op() == parser::REFSWITCH) { | if (input_node_def->op() == parser::SWITCH || input_node_def->op() == parser::REFSWITCH) { | ||||
NodeDef *identity_node_def = graph_def->add_node(); | NodeDef *identity_node_def = graph_def->add_node(); | ||||
GE_CHECK_NOTNULL(identity_node_def); | GE_CHECK_NOTNULL(identity_node_def); | ||||
input_node_name = input_node_name + "identity"; | |||||
std::string remove_input_name = remove_input; | |||||
remove_input_name = remove_input_name.find(":") == std::string::npos ? | |||||
input_node_name : (remove_input_name.replace(remove_input_name.find(":"), 1, "_")); | |||||
input_node_name = remove_input_name + "_identity"; | |||||
identity_node_def->set_name(input_node_name); | identity_node_def->set_name(input_node_name); | ||||
identity_node_def->set_op(parser::IDENTITY); | identity_node_def->set_op(parser::IDENTITY); | ||||
identity_node_def->add_input(remove_input); | identity_node_def->add_input(remove_input); | ||||
@@ -3465,7 +3384,7 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow:: | |||||
*/ | */ | ||||
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def, | Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def, | ||||
const set<uint32_t> &remove_index_set, | const set<uint32_t> &remove_index_set, | ||||
const map<string, NodeDef *> &all_node_map) { | |||||
const map<string, NodeDef *> &all_node_map) const { | |||||
GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
if (remove_index_set.empty()) { | if (remove_index_set.empty()) { | ||||
GELOGI("The size of remove_index_set is zero."); | GELOGI("The size of remove_index_set is zero."); | ||||
@@ -3662,7 +3581,7 @@ Status TensorFlowModelParser::RecordFusionResult(const std::shared_ptr<ge::Scope | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status TensorFlowModelParser::SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context, | |||||
Status TensorFlowModelParser::SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context, | |||||
const std::vector<std::pair<std::string, int32_t>> &inputs, | const std::vector<std::pair<std::string, int32_t>> &inputs, | ||||
const std::vector<std::pair<std::string, int32_t>> &outputs) { | const std::vector<std::pair<std::string, int32_t>> &outputs) { | ||||
int32_t in_index = 0; | int32_t in_index = 0; | ||||
@@ -3752,7 +3671,7 @@ void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, Op | |||||
++iter; | ++iter; | ||||
} | } | ||||
} | } | ||||
op_node_context.input_map.insert(tmp_input_map.begin(), tmp_input_map.end()); | |||||
op_node_context.input_map.insert(tmp_input_map.cbegin(), tmp_input_map.cend()); | |||||
// update output map of pre node | // update output map of pre node | ||||
for (const auto &in_iter : op_node_context.input_map) { | for (const auto &in_iter : op_node_context.input_map) { | ||||
auto src_iter = op_node_context_map_.find(in_iter.first); | auto src_iter = op_node_context_map_.find(in_iter.first); | ||||
@@ -3801,7 +3720,7 @@ void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, O | |||||
++iter; | ++iter; | ||||
} | } | ||||
} | } | ||||
op_node_context.output_map.insert(tmp_output_map.begin(), tmp_output_map.end()); | |||||
op_node_context.output_map.insert(tmp_output_map.cbegin(), tmp_output_map.cend()); | |||||
// update input map of pre node | // update input map of pre node | ||||
for (const auto &out_iter : op_node_context.output_map) { | for (const auto &out_iter : op_node_context.output_map) { | ||||
auto dst_iter = op_node_context_map_.find(out_iter.first); | auto dst_iter = op_node_context_map_.find(out_iter.first); | ||||
@@ -3902,7 +3821,7 @@ Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope | |||||
DumpAllNodeContext("BeforeAddFusionNodeDef"); | DumpAllNodeContext("BeforeAddFusionNodeDef"); | ||||
for (size_t i = 0; i < op_node_list_size; ++i) { | for (size_t i = 0; i < op_node_list_size; ++i) { | ||||
const string op_node_name = node_name_list[i]; | const string op_node_name = node_name_list[i]; | ||||
auto iter = fusion_op_nodedef_map_.find(op_node_name); | |||||
std::map<string, vector<const NodeDef *>>::const_iterator iter = fusion_op_nodedef_map_.find(op_node_name); | |||||
if (iter != fusion_op_nodedef_map_.end()) { | if (iter != fusion_op_nodedef_map_.end()) { | ||||
vector<string> fusion_op_info = fusion_op_type_map_[op_node_name]; | vector<string> fusion_op_info = fusion_op_type_map_[op_node_name]; | ||||
if (fusion_op_info[0] != ge::kScopeToMultiNodes) { | if (fusion_op_info[0] != ge::kScopeToMultiNodes) { | ||||
@@ -3943,7 +3862,8 @@ Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope | |||||
} | } | ||||
Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | ||||
std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def) { | |||||
std::mutex *const graph_mutex, | |||||
const domi::tensorflow::NodeDef *node_def) { | |||||
// This is an internal function. The pointer input parameter is not empty when this function is invoked. | // This is an internal function. The pointer input parameter is not empty when this function is invoked. | ||||
string node_name = node_def->name(); | string node_name = node_def->name(); | ||||
string node_op = node_def->op(); | string node_op = node_def->op(); | ||||
@@ -4059,7 +3979,7 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph | |||||
std::string model_data; | std::string model_data; | ||||
if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) { | if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) { | ||||
ge::Model model; | ge::Model model; | ||||
auto load_ret = ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.data()), model_data.size(), model); | |||||
auto load_ret = ge::Model::Load(ge::PtrToPtr<char_t, const uint8_t>(model_data.data()), model_data.size(), model); | |||||
if (load_ret != GRAPH_SUCCESS) { | if (load_ret != GRAPH_SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str()); | ||||
REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str()); | REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str()); | ||||
@@ -35,6 +35,7 @@ | |||||
#include "omg/parser/model_parser.h" | #include "omg/parser/model_parser.h" | ||||
#include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
#include "omg/parser/weights_parser.h" | #include "omg/parser/weights_parser.h" | ||||
#include "common/pre_checker.h" | |||||
#include "parser/tensorflow/tensorflow_fusion_op_parser.h" | #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | ||||
#include "parser/tensorflow/tensorflow_util.h" | #include "parser/tensorflow/tensorflow_util.h" | ||||
#include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
@@ -46,15 +47,6 @@ | |||||
#include "scope/scope_pass_manager.h" | #include "scope/scope_pass_manager.h" | ||||
#include "common/parser_utils.h" | #include "common/parser_utils.h" | ||||
using ge::ScopePassManager; | |||||
using domi::tensorflow::GraphDef; | |||||
using domi::tensorflow::DT_HALF; | |||||
using domi::tensorflow::NodeDef; | |||||
using domi::tensorflow::GraphDef; | |||||
using domi::tensorflow::AttrValue; | |||||
using domi::tensorflow::DataType; | |||||
using ge::OpParser; | |||||
namespace ge { | namespace ge { | ||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
@@ -130,7 +122,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, | Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, | ||||
domi::GetGraphCallback callback, | domi::GetGraphCallback callback, | ||||
ge::ComputeGraphPtr &graph) override; | |||||
ge::ComputeGraphPtr &root_graph) override; | |||||
/* | /* | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -163,6 +155,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
*/ | */ | ||||
Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | ||||
ge::ComputeGraphPtr &root_graph) override; | ge::ComputeGraphPtr &root_graph) override; | ||||
bool HasError() override { | |||||
return PreChecker::Instance().HasError(); | |||||
} | |||||
Status Save(const string &file) override { | |||||
return PreChecker::Instance().Save(file); | |||||
} | |||||
void Clear() override { | |||||
PreChecker::Instance().Clear(); | |||||
} | |||||
private: | private: | ||||
Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph); | Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph); | ||||
@@ -241,15 +245,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief Verifying the validity of graphdef object parsed by pb | |||||
* @param [in] graph_def Parsed tensorflow:: graphdef object | |||||
* @return SUCCESS check successfully | |||||
* @return FAILED check failed | |||||
*/ | |||||
Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const; | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief whether const OP need to update context | * @brief whether const OP need to update context | ||||
* @param const op name | * @param const op name | ||||
* @return true or false | * @return true or false | ||||
@@ -433,28 +428,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
* @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef | * @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef | ||||
*/ | */ | ||||
Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def); | Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def); | ||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Optimize for Identity/ReadVariableOp operator | |||||
* @param [in] graph_def GraphDef to be optimized | |||||
* @param [in] nodedef_map Map of all nodes in graph | |||||
* @param [in] nodedef_to_optimize vector of NodeDef to be optimized | |||||
* @return SUCCESS optimize successfully | |||||
* @return others failed | |||||
*/ | |||||
Status GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | |||||
const vector<NodeDef *> &nodedef_to_optimize); | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief For the identity operator whose output is "_retval", optimize it. | |||||
* @param [in] nodedef_map Map of all nodes in graph | |||||
* @param [in] curr_node_name Name of node to be optimized | |||||
* @param [in] clear_input_flag Flag of whether to clear the input of the current node | |||||
* @return SUCCESS optimize successfully | |||||
* @return others failed | |||||
*/ | |||||
Status OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, const string &curr_node_name, | |||||
bool &clear_input_flag); | |||||
Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | ||||
const vector<NodeDef *> &nodedef_to_optimize); | const vector<NodeDef *> &nodedef_to_optimize); | ||||
Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | ||||
@@ -469,7 +443,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def, | void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def, | ||||
domi::tensorflow::NodeDef *const nodeCurrent, bool &clearInputFlag) const; | domi::tensorflow::NodeDef *const nodeCurrent, bool &clearInputFlag) const; | ||||
static void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo); | static void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo); | ||||
static void SoftmaxAddAttr(GraphDef *const graph_def); | |||||
static void SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def); | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -551,7 +525,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
* @return false optimize failed | * @return false optimize failed | ||||
* | * | ||||
*/ | */ | ||||
Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def); | |||||
Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const; | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -565,7 +539,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
Status RemoveInputs(domi::tensorflow::GraphDef *graph_def, | Status RemoveInputs(domi::tensorflow::GraphDef *graph_def, | ||||
domi::tensorflow::NodeDef *node_def, | domi::tensorflow::NodeDef *node_def, | ||||
const set<uint32_t> &remove_index_set, | const set<uint32_t> &remove_index_set, | ||||
const map<string, NodeDef *> &all_node_map); | |||||
const map<string, NodeDef *> &all_node_map) const; | |||||
Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, | Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, | ||||
domi::tensorflow::NodeDef *node_def, | domi::tensorflow::NodeDef *node_def, | ||||
@@ -611,7 +585,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
static Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library); | static Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library); | ||||
Status SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context, | |||||
Status SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context, | |||||
const std::vector<std::pair<std::string, int32_t>> &inputs, | const std::vector<std::pair<std::string, int32_t>> &inputs, | ||||
const std::vector<std::pair<std::string, int32_t>> &outputs); | const std::vector<std::pair<std::string, int32_t>> &outputs); | ||||
@@ -642,7 +616,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
Status AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &node_name_list); | Status AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &node_name_list); | ||||
static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | ||||
std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def); | |||||
std::mutex *const graph_mutex, const domi::tensorflow::NodeDef *node_def); | |||||
static void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase); | static void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase); | ||||
void DumpAllNodeContext(const string &phase) const; | void DumpAllNodeContext(const string &phase) const; | ||||
@@ -725,6 +699,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowWeightsParser : public domi::WeightsParse | |||||
Status Parse(const char *file, ge::Graph &graph) override; | Status Parse(const char *file, ge::Graph &graph) override; | ||||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | ||||
bool HasError() override { | |||||
return PreChecker::Instance().HasError(); | |||||
} | |||||
Status Save(const string &file) override { | |||||
return PreChecker::Instance().Save(file); | |||||
} | |||||
void Clear() override { | |||||
PreChecker::Instance().Clear(); | |||||
} | |||||
}; | }; | ||||
} // namespace domi | } // namespace domi | ||||
#endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ | #endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ |
@@ -30,8 +30,6 @@ | |||||
#include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
#include "proto/tensorflow/node_def.pb.h" | #include "proto/tensorflow/node_def.pb.h" | ||||
using domi::tensorflow::NodeDef; | |||||
namespace ge { | namespace ge { | ||||
class PARSER_FUNC_VISIBILITY TensorflowFinalizeable { | class PARSER_FUNC_VISIBILITY TensorflowFinalizeable { | ||||
public: | public: | ||||
@@ -20,8 +20,6 @@ | |||||
#include "common/op_def/ref_switch_op.h" | #include "common/op_def/ref_switch_op.h" | ||||
#include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
using domi::tensorflow::NodeDef; | |||||
namespace ge { | namespace ge { | ||||
class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpParser { | class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpParser { | ||||
// AUTO GEN PLEASE DO NOT MODIFY IT | // AUTO GEN PLEASE DO NOT MODIFY IT | ||||
@@ -61,7 +61,7 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
const NodeDef *node_src = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||||
GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | ||||
domi::tensorflow::AttrValue input_attr_value; | domi::tensorflow::AttrValue input_attr_value; | ||||
@@ -34,7 +34,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser | |||||
* @return FAILED parse failed | * @return FAILED parse failed | ||||
* @author | * @author | ||||
*/ | */ | ||||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -94,7 +94,7 @@ Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, Sha | |||||
Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | ||||
GE_CHECK_NOTNULL(op_dest); | GE_CHECK_NOTNULL(op_dest); | ||||
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
ShapeNOperator op; | ShapeNOperator op; | ||||
op.Name(node->name()); | op.Name(node->name()); | ||||
@@ -154,13 +154,13 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
} | } | ||||
// AUTO GEN PLEASE DO NOT MODIFY IT | // AUTO GEN PLEASE DO NOT MODIFY IT | ||||
Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { | |||||
Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) { | |||||
(void)node; | (void)node; | ||||
(void)op; | (void)op; | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { | |||||
Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) { | |||||
(void)node; | (void)node; | ||||
(void)op; | (void)op; | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -20,8 +20,6 @@ | |||||
#include "common/op_def/shape_n_op.h" | #include "common/op_def/shape_n_op.h" | ||||
#include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
using domi::tensorflow::NodeDef; | |||||
namespace ge { | namespace ge { | ||||
class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser { | class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser { | ||||
// AUTO GEN PLEASE DO NOT MODIFY IT | // AUTO GEN PLEASE DO NOT MODIFY IT | ||||
@@ -29,8 +27,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser | |||||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | ||||
protected: | protected: | ||||
Status PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
Status PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
Status PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op); | |||||
Status PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op); | |||||
static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | ||||
static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | ||||
@@ -66,7 +66,7 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
bool has_axis = true; | bool has_axis = true; | ||||
@@ -22,7 +22,7 @@ | |||||
namespace ge { | namespace ge { | ||||
class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser { | class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser { | ||||
public: | public: | ||||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||||
private: | private: | ||||
static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | ||||
@@ -207,7 +207,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( | ||||
const NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) { | |||||
const domi::tensorflow::NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) { | |||||
GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
std::string node_name = node_src->name(); | std::string node_name = node_src->name(); | ||||
@@ -301,7 +301,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( | ||||
const std::string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *const node_def) { | const std::string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *const node_def) { | ||||
GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | |||||
if (node_def == nullptr) { | |||||
GELOGI("input parameter is null."); | |||||
return; | |||||
} | |||||
node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -18,14 +18,11 @@ | |||||
#define OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_ | #define OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_ | ||||
#include <map> | #include <map> | ||||
#include <set> | |||||
#include <string> | #include <string> | ||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <vector> | |||||
#include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
#include "external/graph/attr_value.h" | #include "external/graph/attr_value.h" | ||||
#include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
#include "external/graph/operator.h" | |||||
#include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
#include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
@@ -37,11 +34,6 @@ | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
using domi::tensorflow::NodeDef; | |||||
using domi::tensorflow::FunctionDef; | |||||
using domi::tensorflow::AttrValue_ListValue; | |||||
using domi::tensorflow::FunctionDefLibrary; | |||||
namespace ge { | namespace ge { | ||||
/***************************TensorFlow attribute type, constant definition*******************************************/ | /***************************TensorFlow attribute type, constant definition*******************************************/ | ||||
extern const std::string TENSORFLOW_ATTR_TYPE_STRING; | extern const std::string TENSORFLOW_ATTR_TYPE_STRING; | ||||
@@ -167,7 +159,7 @@ class TensorFlowUtil { | |||||
* @return FAILED parsing failed | * @return FAILED parsing failed | ||||
* | * | ||||
*/ | */ | ||||
static domi::Status ParseDataType(const NodeDef *node_src, | |||||
static domi::Status ParseDataType(const domi::tensorflow::NodeDef *node_src, | |||||
const std::string &attr_src, | const std::string &attr_src, | ||||
domi::tensorflow::DataType &data_type); | domi::tensorflow::DataType &data_type); | ||||
@@ -25,7 +25,7 @@ using namespace ge::parser; | |||||
namespace ge { | namespace ge { | ||||
Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *const op) { | Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *const op) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
const domi::tensorflow::NodeDef *node = ge::PtrToPtr<Message, domi::tensorflow::NodeDef>(op_src); | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
op->Name(node->name()); | op->Name(node->name()); | ||||
@@ -19,7 +19,6 @@ | |||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/operator.h" | |||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "parser/common/op_def/variable_op.h" | #include "parser/common/op_def/variable_op.h" | ||||
@@ -253,7 +252,7 @@ static void ParseMemType(const domi::tensorflow::NodeDef *node, VariableOperator | |||||
Status ParseParams(const Message *op_src, VariableOperator *op) { | Status ParseParams(const Message *op_src, VariableOperator *op) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
const NodeDef *node = ge::PtrToPtr<Message, NodeDef>(op_src); | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
string node_op = node->op(); | string node_op = node->op(); | ||||
@@ -121,7 +121,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||||
GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), | GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), | ||||
REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed"); | REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed"); | ||||
return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); | return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); | ||||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||||
op_desc->ext_attrs_ = org_op_desc->ext_attrs_; | |||||
// This function may be called by some passes of fusion engine, in this condition, do not need these attribute | // This function may be called by some passes of fusion engine, in this condition, do not need these attribute | ||||
if (op_desc->impl_ == nullptr) { | if (op_desc->impl_ == nullptr) { | ||||
@@ -164,7 +164,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||||
op_desc->ext_attrs_ = org_op_desc->ext_attrs_; | |||||
if (op_desc->impl_ == nullptr) { | if (op_desc->impl_ == nullptr) { | ||||
REPORT_INNER_ERROR("E19999", "op desc impl is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "op desc impl is nullptr, check invalid"); | ||||
@@ -158,6 +158,17 @@ mmTimespec mmGetTickCount() { | |||||
return rts; | return rts; | ||||
} | } | ||||
INT32 mmGetSystemTime(mmSystemTime_t *sysTime) { | |||||
// Beijing olympics | |||||
sysTime->wYear = 2008; | |||||
sysTime->wMonth = 8; | |||||
sysTime->wDay = 8; | |||||
sysTime->wHour = 20; | |||||
sysTime->wMinute = 8; | |||||
sysTime->wSecond = 0; | |||||
return 0; | |||||
} | |||||
INT32 mmGetTid() { | INT32 mmGetTid() { | ||||
INT32 ret = (INT32)syscall(SYS_gettid); | INT32 ret = (INT32)syscall(SYS_gettid); | ||||
@@ -61,7 +61,7 @@ target_link_libraries(st_parser_proto PRIVATE | |||||
################################################################################ | ################################################################################ | ||||
set(DUPLICATE_PROTO_LIST | set(DUPLICATE_PROTO_LIST | ||||
"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||||
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto" | |||||
) | ) | ||||
protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) | protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) | ||||
@@ -118,7 +118,10 @@ set(MATEDEF_SRC_FILES | |||||
"${PARSER_DIR}/metadef/graph/resource_context_mgr.cc" | "${PARSER_DIR}/metadef/graph/resource_context_mgr.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/constant_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/constant_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/file_utils.cc" | |||||
"${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/connection_matrix.cc" | |||||
"${PARSER_DIR}/metadef/graph/utils/cycle_detector.cc" | |||||
"${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/node_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/node_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
@@ -126,6 +129,8 @@ set(MATEDEF_SRC_FILES | |||||
"${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/type_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/type_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/trace/trace_manager.cc" | |||||
"${PARSER_DIR}/metadef/graph/common/large_bm.cc" | |||||
"${PARSER_DIR}/metadef/ops/op_imp.cpp" | "${PARSER_DIR}/metadef/ops/op_imp.cpp" | ||||
"${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" | "${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" | ||||
"${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" | "${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" | ||||
@@ -21,7 +21,9 @@ | |||||
#include <google/protobuf/io/zero_copy_stream_impl.h> | #include <google/protobuf/io/zero_copy_stream_impl.h> | ||||
#include <google/protobuf/text_format.h> | #include <google/protobuf/text_format.h> | ||||
#include <fstream> | #include <fstream> | ||||
#include <sys/types.h> | |||||
#include <sys/stat.h> | |||||
#include <fcntl.h> | |||||
namespace ge { | namespace ge { | ||||
void ParerSTestsUtils::ClearParserInnerCtx() { | void ParerSTestsUtils::ClearParserInnerCtx() { | ||||
@@ -131,4 +133,14 @@ void ParerSTestsUtils::WriteProtoToBinaryFile(const google::protobuf::Message &p | |||||
out.close(); | out.close(); | ||||
delete[] buf; | delete[] buf; | ||||
} | } | ||||
void ParerSTestsUtils::WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename) { | |||||
const int32_t fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 384U); | |||||
if (fd >= 0) { | |||||
google::protobuf::io::FileOutputStream output(fd); | |||||
google::protobuf::TextFormat::Print(proto, &output); | |||||
output.Close(); | |||||
close(fd); | |||||
} | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -31,6 +31,7 @@ class ParerSTestsUtils { | |||||
static MemBuffer* MemBufferFromFile(const char *path); | static MemBuffer* MemBufferFromFile(const char *path); | ||||
static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | ||||
static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); | static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); | ||||
static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -36,6 +36,7 @@ | |||||
#include "parser/caffe/caffe_op_parser.h" | #include "parser/caffe/caffe_op_parser.h" | ||||
#include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
#include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
#include "common/op_map.h" | |||||
#undef protected | #undef protected | ||||
#undef private | #undef private | ||||
@@ -173,7 +174,7 @@ void STestCaffeParser::RegisterCustomOp() { | |||||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
domi::OpRegistry::Instance()->Register(reg_data); | domi::OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -223,6 +224,29 @@ TEST_F(STestCaffeParser, acl_caffe_parser) { | |||||
EXPECT_EQ(ret, GRAPH_FAILED); | EXPECT_EQ(ret, GRAPH_FAILED); | ||||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); | ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); | ||||
EXPECT_EQ(ret, GRAPH_FAILED); | EXPECT_EQ(ret, GRAPH_FAILED); | ||||
caffe_op_map.clear(); | |||||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), parser_params, graph); | |||||
EXPECT_EQ(ret, GRAPH_FAILED); | |||||
{ | |||||
proto.set_name("empty_layer"); | |||||
auto &layers = *proto.add_layers(); | |||||
layers.set_name("layers"); | |||||
proto.clear_layer(); | |||||
const std::string empty_layer = case_dir + "/origin_models/empty_layer.pbtxt"; | |||||
ParerSTestsUtils::WriteProtoToTextFile(proto, empty_layer.c_str()); | |||||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layer.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||||
proto.clear_layers(); | |||||
const std::string empty_layers = case_dir + "/origin_models/empty_layers.pbtxt"; | |||||
ParerSTestsUtils::WriteProtoToTextFile(proto, empty_layers.c_str()); | |||||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layers.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||||
unlink(empty_layer.c_str()); | |||||
unlink(empty_layers.c_str()); | |||||
} | |||||
} | } | ||||
TEST_F(STestCaffeParser, modelparser_parsefrommemory_success) | TEST_F(STestCaffeParser, modelparser_parsefrommemory_success) | ||||
@@ -24,6 +24,7 @@ | |||||
#include "st/parser_st_utils.h" | #include "st/parser_st_utils.h" | ||||
#include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
#include "tests/depends/ops_stub/ops_stub.h" | #include "tests/depends/ops_stub/ops_stub.h" | ||||
#include "framework/omg/parser/parser_factory.h" | |||||
#include "parser/onnx/onnx_parser.h" | #include "parser/onnx/onnx_parser.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -96,7 +97,7 @@ void STestOnnxParser::RegisterCustomOp() { | |||||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
domi::OpRegistry::Instance()->Register(reg_data); | domi::OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -64,6 +64,7 @@ | |||||
#include "parser/common/data_op_parser.h" | #include "parser/common/data_op_parser.h" | ||||
#include "parser/common/model_saver.h" | #include "parser/common/model_saver.h" | ||||
#include "framework/omg/parser/parser_api.h" | #include "framework/omg/parser/parser_api.h" | ||||
#include "framework/omg/parser/parser_factory.h" | |||||
#include "parser/common/parser_fp16_t.h" | #include "parser/common/parser_fp16_t.h" | ||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "parser/common/prototype_pass_manager.h" | #include "parser/common/prototype_pass_manager.h" | ||||
@@ -151,7 +152,7 @@ void STestTensorflowParser::RegisterCustomOp() { | |||||
.ParseParamsFn(ParseParams); | .ParseParamsFn(ParseParams); | ||||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
domi::OpRegistry::Instance()->Register(reg_data); | domi::OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -584,7 +585,7 @@ namespace { | |||||
void register_tbe_op() { | void register_tbe_op() { | ||||
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; | ||||
for (OpRegistrationData reg_data : registrationDatas) { | for (OpRegistrationData reg_data : registrationDatas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
OpRegistry::Instance()->Register(reg_data); | OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
OpRegistry::Instance()->registrationDatas.clear(); | OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -1124,7 +1125,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserfrommemory_failed) | |||||
ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | ||||
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
ret = modelParser.ParseFromMemory(data, size, compute_graph); | ret = modelParser.ParseFromMemory(data, size, compute_graph); | ||||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
EXPECT_NE(ret, SUCCESS); | |||||
} | } | ||||
TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success) | TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success) | ||||
@@ -1259,7 +1260,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed) | |||||
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
TensorFlowModelParser tensorflow_parser; | TensorFlowModelParser tensorflow_parser; | ||||
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | ||||
EXPECT_EQ(INTERNAL_ERROR, ret); | |||||
ASSERT_NE(ret, SUCCESS); | |||||
} | } | ||||
TEST_F(STestTensorflowParser, test_parse_acl_output_nodes) | TEST_F(STestTensorflowParser, test_parse_acl_output_nodes) | ||||
@@ -1913,6 +1914,7 @@ TEST_F(STestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test) | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
op_dest->SetType(ge::parser::SHAPE); | op_dest->SetType(ge::parser::SHAPE); | ||||
op_dest->AddOutputDesc(GeTensorDesc()); | |||||
ret = autoMappingParser.ParseParams(node_def, op_dest); | ret = autoMappingParser.ParseParams(node_def, op_dest); | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
@@ -2648,29 +2650,6 @@ TEST_F(STestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test) | |||||
model_parser.UpdateEdgesControlInfo(info); | model_parser.UpdateEdgesControlInfo(info); | ||||
} | } | ||||
TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) | |||||
{ | |||||
TensorFlowModelParser model_parser; | |||||
NodeDef *node_def = new NodeDef(); | |||||
node_def->set_name("Placeholder"); | |||||
node_def->set_op("Placeholder_0"); | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("Placeholder", node_def); | |||||
std::string curr_node_name = "Placeholder"; | |||||
bool clear_input_flag = true; | |||||
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
GraphDef graph; | |||||
curr_node_name = "pre_node_a"; | |||||
nodedef_map.emplace("pre_node_a", node_def); | |||||
node_def->set_op("pre_node_a"); | |||||
GenOriginContext(&model_parser, curr_node_name); | |||||
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
delete node_def; | |||||
} | |||||
TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test) | TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test) | ||||
{ | { | ||||
TensorFlowModelParser model_parser; | TensorFlowModelParser model_parser; | ||||
@@ -2831,27 +2810,17 @@ TEST_F(STestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test) | |||||
removed_inputs_vec.emplace_back("Add0"); | removed_inputs_vec.emplace_back("Add0"); | ||||
Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); | Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
tensorflow::NodeDef *node_swith = initNodeDef(); | |||||
node_swith->set_name("switch_op"); | |||||
node_swith->set_op(parser::SWITCH); | |||||
all_node_map.emplace("switch_op", node_swith); | |||||
removed_inputs_vec.clear(); | |||||
removed_inputs_vec.emplace_back("switch_op"); | |||||
ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_swith, all_node_map, removed_inputs_vec); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | } | ||||
TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||||
{ | |||||
tensorflow::GraphDef graph_def; | |||||
TensorFlowModelParser tensorflow_parser; | |||||
tensorflow::NodeDef *node_def = initNodeDef(); | |||||
node_def->set_name("post_node_d"); | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("post_node_d", node_def); | |||||
nodedef_map.emplace("post_node_a", node_def); | |||||
nodedef_map.emplace("post_node_b", node_def); | |||||
std::vector<NodeDef *> nodedef_to_optimize; | |||||
nodedef_to_optimize.emplace_back(node_def); | |||||
std::string curr_node_name = "post_node_b"; | |||||
GenOriginContext(&tensorflow_parser, curr_node_name); | |||||
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
} | |||||
TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | ||||
std::string caseDir = __FILE__; | std::string caseDir = __FILE__; | ||||
std::size_t idx = caseDir.find_last_of("/"); | std::size_t idx = caseDir.find_last_of("/"); | ||||
@@ -3534,7 +3503,8 @@ TEST_F(STestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test) | |||||
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); | ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); | ||||
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); | field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); | ||||
mess2Op.ParseField(reflection, node_def, field, depth, ops); | mess2Op.ParseField(reflection, node_def, field, depth, ops); | ||||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str); | |||||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1); | |||||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5); | |||||
delete field; | delete field; | ||||
} | } | ||||
@@ -62,7 +62,7 @@ target_link_libraries(ut_parser_proto PRIVATE | |||||
################################################################################ | ################################################################################ | ||||
set(DUPLICATE_PROTO_LIST | set(DUPLICATE_PROTO_LIST | ||||
"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||||
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto" | |||||
) | ) | ||||
protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) | protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) | ||||
@@ -119,14 +119,19 @@ set(MATEDEF_SRC_FILES | |||||
"${PARSER_DIR}/metadef/graph/tensor.cc" | "${PARSER_DIR}/metadef/graph/tensor.cc" | ||||
"${PARSER_DIR}/metadef/graph/types.cc" | "${PARSER_DIR}/metadef/graph/types.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/file_utils.cc" | |||||
"${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/connection_matrix.cc" | |||||
"${PARSER_DIR}/metadef/graph/utils/cycle_detector.cc" | |||||
"${PARSER_DIR}/metadef/graph/utils/node_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/node_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/type_utils.cc" | "${PARSER_DIR}/metadef/graph/utils/type_utils.cc" | ||||
"${PARSER_DIR}/metadef/graph/utils/trace/trace_manager.cc" | |||||
"${PARSER_DIR}/metadef/graph/common/large_bm.cc" | |||||
"${PARSER_DIR}/metadef/ops/op_imp.cpp" | "${PARSER_DIR}/metadef/ops/op_imp.cpp" | ||||
"${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" | "${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" | ||||
"${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" | "${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" | ||||
@@ -21,6 +21,9 @@ | |||||
#include <google/protobuf/io/zero_copy_stream_impl.h> | #include <google/protobuf/io/zero_copy_stream_impl.h> | ||||
#include <google/protobuf/text_format.h> | #include <google/protobuf/text_format.h> | ||||
#include <limits.h> | #include <limits.h> | ||||
#include <sys/types.h> | |||||
#include <sys/stat.h> | |||||
#include <fcntl.h> | |||||
namespace ge { | namespace ge { | ||||
void ParerUTestsUtils::ClearParserInnerCtx() { | void ParerUTestsUtils::ClearParserInnerCtx() { | ||||
@@ -131,6 +134,16 @@ void ParerUTestsUtils::WriteProtoToBinaryFile(const google::protobuf::Message &p | |||||
delete[] buf; | delete[] buf; | ||||
} | } | ||||
void ParerUTestsUtils::WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename) { | |||||
const int32_t fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 384U); | |||||
if (fd >= 0) { | |||||
google::protobuf::io::FileOutputStream output(fd); | |||||
google::protobuf::TextFormat::Print(proto, &output); | |||||
output.Close(); | |||||
close(fd); | |||||
} | |||||
} | |||||
namespace ut { | namespace ut { | ||||
NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, | NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, | ||||
DataType data_type, std::vector<int64_t> shape) { | DataType data_type, std::vector<int64_t> shape) { | ||||
@@ -32,6 +32,7 @@ class ParerUTestsUtils { | |||||
static MemBuffer* MemBufferFromFile(const char *path); | static MemBuffer* MemBufferFromFile(const char *path); | ||||
static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | ||||
static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); | static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); | ||||
static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename); | |||||
}; | }; | ||||
namespace ut { | namespace ut { | ||||
@@ -39,6 +39,7 @@ | |||||
#include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
#include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
#include "parser/caffe/caffe_reshape_parser.h" | #include "parser/caffe/caffe_reshape_parser.h" | ||||
#include "common/op_map.h" | |||||
#undef protected | #undef protected | ||||
#undef private | #undef private | ||||
@@ -162,7 +163,7 @@ static ge::NodePtr GenNodeFromOpDesc(ge::OpDescPtr opDesc){ | |||||
void UtestCaffeParser::RegisterCustomOp() { | void UtestCaffeParser::RegisterCustomOp() { | ||||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
domi::OpRegistry::Instance()->Register(reg_data); | domi::OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -266,6 +267,29 @@ TEST_F(UtestCaffeParser, acl_caffe_parser) { | |||||
EXPECT_EQ(ret, GRAPH_FAILED); | EXPECT_EQ(ret, GRAPH_FAILED); | ||||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); | ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); | ||||
EXPECT_EQ(ret, GRAPH_FAILED); | EXPECT_EQ(ret, GRAPH_FAILED); | ||||
caffe_op_map.clear(); | |||||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), parser_params, graph); | |||||
EXPECT_EQ(ret, GRAPH_FAILED); | |||||
{ | |||||
proto.set_name("empty_layer"); | |||||
auto &layers = *proto.add_layers(); | |||||
layers.set_name("layers"); | |||||
proto.clear_layer(); | |||||
const std::string empty_layer = case_dir + "/caffe_model/empty_layer.pbtxt"; | |||||
ParerUTestsUtils::WriteProtoToTextFile(proto, empty_layer.c_str()); | |||||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layer.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||||
proto.clear_layers(); | |||||
const std::string empty_layers = case_dir + "/caffe_model/empty_layers.pbtxt"; | |||||
ParerUTestsUtils::WriteProtoToTextFile(proto, empty_layers.c_str()); | |||||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layers.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||||
unlink(empty_layer.c_str()); | |||||
unlink(empty_layers.c_str()); | |||||
} | |||||
} | } | ||||
TEST_F(UtestCaffeParser, ParseFromMemory_success) | TEST_F(UtestCaffeParser, ParseFromMemory_success) | ||||
@@ -34,6 +34,7 @@ | |||||
#include "parser/common/pass_manager.h" | #include "parser/common/pass_manager.h" | ||||
#include "parser/common/tbe_plugin_loader.h" | #include "parser/common/tbe_plugin_loader.h" | ||||
#include "parser/common/parser_fp16_t.h" | #include "parser/common/parser_fp16_t.h" | ||||
#include "parser/common/pre_checker.h" | |||||
#undef protected | #undef protected | ||||
#undef private | #undef private | ||||
@@ -342,4 +343,15 @@ TEST_F(UtestAclGraphParser, test_operatoreq) | |||||
int8 = fp16; | int8 = fp16; | ||||
} | } | ||||
TEST_F(UtestAclGraphParser, test_pre_checker) { | |||||
PreChecker::Instance().fmk_op_types_ = nullptr; | |||||
const char* str = "iiii"; | |||||
PreChecker::OpId id = str; | |||||
std::string type("ddd"); | |||||
std::string name("lll"); | |||||
Status ret = PreChecker::Instance().CheckTypeSupported(id, type, name, false); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ret = PreChecker::Instance().CheckTypeSupported(id, type, name, true); | |||||
EXPECT_EQ(ret, FAILED); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -24,6 +24,7 @@ | |||||
#include "external/parser/onnx_parser.h" | #include "external/parser/onnx_parser.h" | ||||
#include "ut/parser/parser_ut_utils.h" | #include "ut/parser/parser_ut_utils.h" | ||||
#include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
#include "framework/omg/parser/parser_factory.h" | |||||
#include "tests/depends/ops_stub/ops_stub.h" | #include "tests/depends/ops_stub/ops_stub.h" | ||||
#define protected public | #define protected public | ||||
@@ -103,7 +104,7 @@ void UtestOnnxParser::RegisterCustomOp() { | |||||
.ParseParamsFn(ParseParams); | .ParseParamsFn(ParseParams); | ||||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
domi::OpRegistry::Instance()->Register(reg_data); | domi::OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -176,7 +176,7 @@ void UtestTensorflowParser::RegisterCustomOp() { | |||||
.ParseParamsFn(ParseParams); | .ParseParamsFn(ParseParams); | ||||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | ||||
for (auto reg_data : reg_datas) { | for (auto reg_data : reg_datas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
domi::OpRegistry::Instance()->Register(reg_data); | domi::OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -599,7 +599,7 @@ namespace { | |||||
void register_tbe_op() { | void register_tbe_op() { | ||||
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; | ||||
for (OpRegistrationData reg_data : registrationDatas) { | for (OpRegistrationData reg_data : registrationDatas) { | ||||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||||
OpRegistry::Instance()->Register(reg_data); | OpRegistry::Instance()->Register(reg_data); | ||||
} | } | ||||
OpRegistry::Instance()->registrationDatas.clear(); | OpRegistry::Instance()->registrationDatas.clear(); | ||||
@@ -1288,7 +1288,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserfrommemory_failed) | |||||
ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | ||||
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
ret = modelParser.ParseFromMemory(data, size, compute_graph); | ret = modelParser.ParseFromMemory(data, size, compute_graph); | ||||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
EXPECT_NE(ret, SUCCESS); | |||||
} | } | ||||
TEST_F(UtestTensorflowParser, modelparser_parsefrommemory_success) | TEST_F(UtestTensorflowParser, modelparser_parsefrommemory_success) | ||||
@@ -1419,7 +1419,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed) | |||||
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
TensorFlowModelParser tensorflow_parser; | TensorFlowModelParser tensorflow_parser; | ||||
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | ||||
EXPECT_EQ(INTERNAL_ERROR, ret); | |||||
ASSERT_NE(ret, SUCCESS); | |||||
} | } | ||||
TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes) | TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes) | ||||
@@ -2082,6 +2082,7 @@ TEST_F(UtestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test) | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
op_dest->SetType(ge::parser::SHAPE); | op_dest->SetType(ge::parser::SHAPE); | ||||
op_dest->AddOutputDesc(GeTensorDesc()); | |||||
ret = autoMappingParser.ParseParams(node_def, op_dest); | ret = autoMappingParser.ParseParams(node_def, op_dest); | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
@@ -2824,29 +2825,6 @@ TEST_F(UtestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test) | |||||
model_parser.UpdateEdgesControlInfo(info); | model_parser.UpdateEdgesControlInfo(info); | ||||
} | } | ||||
TEST_F(UtestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) | |||||
{ | |||||
TensorFlowModelParser model_parser; | |||||
NodeDef *node_def = new NodeDef(); | |||||
node_def->set_name("Placeholder"); | |||||
node_def->set_op("Placeholder_0"); | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("Placeholder", node_def); | |||||
std::string curr_node_name = "Placeholder"; | |||||
bool clear_input_flag = true; | |||||
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
GraphDef graph; | |||||
curr_node_name = "pre_node_a"; | |||||
nodedef_map.emplace("pre_node_a", node_def); | |||||
node_def->set_op("pre_node_a"); | |||||
GenOriginContext(&model_parser, curr_node_name); | |||||
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
delete node_def; | |||||
} | |||||
TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test) | TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test) | ||||
{ | { | ||||
TensorFlowModelParser model_parser; | TensorFlowModelParser model_parser; | ||||
@@ -3007,27 +2985,18 @@ TEST_F(UtestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test) | |||||
removed_inputs_vec.emplace_back("Add0"); | removed_inputs_vec.emplace_back("Add0"); | ||||
Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); | Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | |||||
TEST_F(UtestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||||
{ | |||||
tensorflow::GraphDef graph_def; | |||||
TensorFlowModelParser tensorflow_parser; | |||||
tensorflow::NodeDef *node_def = initNodeDef(); | |||||
node_def->set_name("post_node_d"); | |||||
tensorflow::NodeDef *node_swith = initNodeDef(); | |||||
node_swith->set_name("switch_op"); | |||||
node_swith->set_op(parser::SWITCH); | |||||
all_node_map.emplace("switch_op", node_swith); | |||||
removed_inputs_vec.clear(); | |||||
removed_inputs_vec.emplace_back("switch_op"); | |||||
ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_swith, all_node_map, removed_inputs_vec); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("post_node_d", node_def); | |||||
nodedef_map.emplace("post_node_a", node_def); | |||||
nodedef_map.emplace("post_node_b", node_def); | |||||
std::vector<NodeDef *> nodedef_to_optimize; | |||||
nodedef_to_optimize.emplace_back(node_def); | |||||
std::string curr_node_name = "post_node_b"; | |||||
GenOriginContext(&tensorflow_parser, curr_node_name); | |||||
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
} | |||||
TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | ||||
std::string caseDir = __FILE__; | std::string caseDir = __FILE__; | ||||
std::size_t idx = caseDir.find_last_of("/"); | std::size_t idx = caseDir.find_last_of("/"); | ||||
@@ -3696,7 +3665,8 @@ TEST_F(UtestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test) | |||||
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); | ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); | ||||
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); | field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); | ||||
mess2Op.ParseField(reflection, node_def, field, depth, ops); | mess2Op.ParseField(reflection, node_def, field, depth, ops); | ||||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str); | |||||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1); | |||||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5); | |||||
delete field; | delete field; | ||||
} | } | ||||