@@ -3,6 +3,11 @@ approvers: | |||
- wqtshg | |||
- ljl0711 | |||
- liu-jisheng | |||
- zhangfan_hq | |||
- lipeiyang3699 | |||
reviewers: | |||
- xchu42 | |||
- sheng-nan | |||
- tangqunzhang | |||
- wangxiaotian22 | |||
- stevenaw |
@@ -1 +1 @@ | |||
Subproject commit 0a2335712484f85cd44a0f2402eac6932b22b40a | |||
Subproject commit 8fb59a00c6291207f3491fee0c4064efff94d79f |
@@ -94,7 +94,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||
const ge::ParserContext &ctx = GetParserContext(); | |||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
string name = layer->name(); | |||
auto search = input_dims.find(name); | |||
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | |||
if (search == input_dims.end()) { | |||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | |||
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | |||
@@ -139,7 +139,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||
const ge::ParserContext &ctx = GetParserContext(); | |||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
string name = layer->name(); | |||
auto search = input_dims.find(name); | |||
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | |||
if (search == input_dims.end()) { | |||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | |||
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | |||
@@ -19,6 +19,7 @@ | |||
#include "parser/common/op_parser_factory.h" | |||
#include "common/util/error_manager/error_manager.h" | |||
#include "framework/omg/parser/parser_types.h" | |||
#include "graph/def_types.h" | |||
using namespace ge::parser; | |||
using domi::caffe::BlobProto; | |||
@@ -107,7 +108,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||
for (int i = 0; i < size; ++i) { | |||
buf[i] = proto.double_data(i); | |||
} | |||
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<float, uint8_t>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||
} else if (proto.int8_data().length() > 0) { | |||
if (size != static_cast<int>(proto.int8_data().length())) { | |||
@@ -121,7 +122,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||
const char *data_ptr = proto.int8_data().data(); | |||
GE_CHECK_NOTNULL(data_ptr); | |||
GE_IF_BOOL_EXEC( | |||
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS, | |||
weight->SetData(PtrToPtr<const char, const uint8_t>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS, | |||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||
dtype = ge::DT_INT8; | |||
} else if (proto.int32_data_size() > 0) { | |||
@@ -139,7 +140,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||
int32_weight_buf[i] = proto.int32_data(i); | |||
} | |||
GE_IF_BOOL_EXEC( | |||
weight->SetData(reinterpret_cast<uint8_t *>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS, | |||
weight->SetData(PtrToPtr<int32_t, uint8_t>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS, | |||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||
dtype = ge::DT_INT32; | |||
} else if (proto.uint64_data_size() > 0) { | |||
@@ -156,7 +157,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||
for (int i = 0; i < size; ++i) { | |||
uint64_weight_buf[i] = proto.uint64_data(i); | |||
} | |||
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(uint64_weight_buf.get()), size * sizeof(uint64_t)) != | |||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<uint64_t, uint8_t>(uint64_weight_buf.get()), size * sizeof(uint64_t)) != | |||
ge::GRAPH_SUCCESS, | |||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||
dtype = ge::DT_UINT64; | |||
@@ -173,7 +174,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||
const float *data_ptr = proto.data().data(); | |||
GE_CHECK_NOTNULL(data_ptr); | |||
GE_IF_BOOL_EXEC( | |||
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||
weight->SetData(PtrToPtr<const float, const uint8_t>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||
} | |||
ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); | |||
@@ -45,7 +45,6 @@ | |||
#include "parser/caffe/caffe_custom_parser_adapter.h" | |||
#include "parser/caffe/caffe_op_parser.h" | |||
#include "parser/common/op_parser_factory.h" | |||
#include "parser/common/pre_checker.h" | |||
#include "parser/common/prototype_pass_manager.h" | |||
#include "framework/omg/parser/parser_types.h" | |||
#include "parser/common/model_saver.h" | |||
@@ -61,13 +60,7 @@ using domi::caffe::InnerProductParameter; | |||
using domi::caffe::LayerParameter; | |||
using domi::caffe::NetParameter; | |||
using domi::ParseParamByOpFunc; | |||
using ge::caffe_op_map; | |||
using ge::CaffeOpParser; | |||
using ge::parser::ModelSaver; | |||
using ge::OpParser; | |||
using ge::OpParserFactory; | |||
using ge::Pb2Json; | |||
using ge::PreChecker; | |||
using std::ifstream; | |||
#define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \ | |||
@@ -299,16 +292,17 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | |||
return FAILED; | |||
} | |||
int input_dim_size = proto_message.input_dim_size(); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((input_dim_size / proto_message.input_size() != parser::DIM_DEFAULT_SIZE || | |||
input_dim_size % proto_message.input_size() != 0), | |||
ErrorManager::GetInstance().ATCReportErrMessage( | |||
"E11003", {"input_dim_size", "input_size"}, | |||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||
return FAILED, | |||
"[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||
input_dim_size, proto_message.input_size()) | |||
const int32_t input_dim_size = proto_message.input_dim_size(); | |||
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||
((input_dim_size % proto_message.input_size()) != 0)); | |||
if (is_input_invalid) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||
input_dim_size, proto_message.input_size()); | |||
return FAILED; | |||
} | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
@@ -329,12 +323,14 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||
input_data_flag = true; | |||
} | |||
} else if (proto_message.input_shape_size() > 0) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto_message.input_shape_size() != proto_message.input_size(), | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||
{std::to_string(proto_message.input_shape_size()), | |||
std::to_string(proto_message.input_size())}); | |||
return FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||
proto_message.input_shape_size(), proto_message.input_size()); | |||
if (proto_message.input_shape_size() != proto_message.input_size()) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||
{std::to_string(proto_message.input_shape_size()), | |||
std::to_string(proto_message.input_size())}); | |||
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||
proto_message.input_shape_size(), proto_message.input_size()); | |||
return FAILED; | |||
} | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
int dim_size = proto_message.input_shape(i).dim_size(); | |||
@@ -755,7 +751,8 @@ Status CaffeModelParser::GetCustomOp(const domi::caffe::LayerParameter &layer, v | |||
} | |||
if (is_search_built_in_layer) { | |||
const google::protobuf::Message *layer_message = reinterpret_cast<const google::protobuf::Message *>(&layer); | |||
const google::protobuf::Message *layer_message = PtrToPtr<const domi::caffe::LayerParameter, | |||
const google::protobuf::Message>(&layer); | |||
Status status = CreateCustomOperator(op_name, op_type, layer_message, 0, operators); | |||
if (status != SUCCESS || operators.empty()) { | |||
GELOGE(status, "[Create][CustomOperator] failed, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); | |||
@@ -838,11 +835,11 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C | |||
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE); | |||
GE_CHECK_NOTNULL(factory); | |||
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_parser == nullptr, | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11009", {"opname", "optype"}, | |||
{layer.name(), op_type}); | |||
return FAILED, "op_parser is null, op_type: %s.", | |||
op_type.c_str()); | |||
if (op_parser == nullptr) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11009", {"opname", "optype"}, {layer.name(), op_type}); | |||
GELOGE(FAILED, "op_parser is null, op_type: %s.", op_type.c_str()); | |||
return FAILED; | |||
} | |||
ge::OpDescPtr op; | |||
// Process change of tensordesc initialization of opdesc, | |||
@@ -994,7 +991,7 @@ Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const | |||
GELOGI("op [%s], type[%s], update output(%d) with name %s %s", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
i, op_desc->GetOutputNameByIndex(i).c_str(), | |||
ret == ge::GRAPH_SUCCESS ? "success" : "failed"); | |||
ret == ge::GRAPH_SUCCESS ? "success" : "not success"); | |||
} | |||
} | |||
return SUCCESS; | |||
@@ -1025,7 +1022,8 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||
// Find the layer for this output | |||
auto top_node_iter = node_map.find(top_blob_layer_pair.first); | |||
// Find the layer for this input | |||
auto bottom_node_iter = node_map.find(bottom_blob_layer_pair.first); | |||
std::map<std::string, ge::NodePtr>::const_iterator bottom_node_iter = | |||
node_map.find(bottom_blob_layer_pair.first); | |||
if (top_node_iter != node_map.end() && bottom_node_iter != node_map.end()) { | |||
// Output node top_node_iter->second, | |||
// Output index top_blob_layer_pair.second | |||
@@ -1057,7 +1055,7 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||
{top_blob_layer_pair.first}); | |||
GELOGE(INTERNAL_ERROR, "[Find][TopLayer] %s failed.", top_blob_layer_pair.first.c_str()); | |||
return ge::FAILED;) | |||
GE_IF_BOOL_EXEC(top_node_iter == node_map.end(), | |||
GE_IF_BOOL_EXEC(bottom_node_iter == node_map.end(), | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11015", {"opname"}, | |||
{bottom_blob_layer_pair.first}); | |||
GELOGE(INTERNAL_ERROR, "[Find][BottomLayer] %s failed.", bottom_blob_layer_pair.first.c_str()); | |||
@@ -1095,7 +1093,7 @@ Status CaffeModelParser::AddUserOutNodesTop() { | |||
const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes; | |||
int net_output_num = user_out_nodes.size(); | |||
for (const auto &out_pair : user_out_nodes) { | |||
auto layer_iter = layer_tops_map_.find(out_pair.first); | |||
std::map<std::string, std::vector<std::string>>::const_iterator layer_iter = layer_tops_map_.find(out_pair.first); | |||
GELOGI("Add to output, node name: %s", out_pair.first.c_str()); | |||
if (layer_iter != layer_tops_map_.end()) { | |||
if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) { | |||
@@ -1110,7 +1108,7 @@ Status CaffeModelParser::AddUserOutNodesTop() { | |||
} | |||
string top_name = layer_iter->second[out_pair.second]; | |||
auto top_node_iter = node_map.find(out_pair.first); | |||
std::map<std::string, ge::NodePtr>::const_iterator top_node_iter = node_map.find(out_pair.first); | |||
if (top_node_iter != node_map.end()) { | |||
ge::GetParserContext().out_tensor_names.push_back(top_name); | |||
GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); | |||
@@ -1142,7 +1140,8 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes | |||
top = RemapTopNameByLayer(layer, top, i); | |||
} | |||
auto t_iter = top_blobs_map_.find(top); | |||
std::map<std::string, std::vector<std::pair<std::string, int32_t>>>::const_iterator t_iter = | |||
top_blobs_map_.find(top); | |||
GE_RETURN_WITH_LOG_IF_FALSE(t_iter != top_blobs_map_.end(), | |||
"[Check][Param]Failed to find top: %s, layer name:%s", top.c_str(), | |||
@@ -1156,7 +1155,7 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes | |||
// If not found, add to the output side of the output | |||
// Find the layer for this output | |||
auto top_node_iter = node_map.find(layer.name()); | |||
std::map<std::string, ge::NodePtr>::const_iterator top_node_iter = node_map.find(layer.name()); | |||
GELOGI("output in top_blob: %s", layer.name().c_str()); | |||
if (top_node_iter != node_map.end()) { | |||
ge::GetParserContext().out_tensor_names.push_back(top_origin); | |||
@@ -1226,12 +1225,14 @@ Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) { | |||
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&layer, layer.name(), layer.type()), | |||
"[Invoke][AddOp]Add layer to PreChecker failed, layer name: %s.", | |||
layer.name().c_str()); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&layer) != SUCCESS, return FAILED, | |||
"[Invoke][CheckName]Check op[%s] failed, name repeat in caffe prototxt.", | |||
layer.name().c_str()); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&layer) != SUCCESS, return FAILED, | |||
"[Invoke][CheckType]Check op[%s]'s optype failed, type is not supported.", | |||
layer.name().c_str()); | |||
if (PreChecker::Instance().CheckName(&layer) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][CheckName]Check op[%s] failed, name repeat in caffe prototxt.", layer.name().c_str()); | |||
return FAILED; | |||
} | |||
if (PreChecker::Instance().CheckType(&layer) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][CheckType]Check op[%s]'s optype failed, type is not supported.", layer.name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
return SUCCESS; | |||
@@ -1290,9 +1291,11 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||
for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { | |||
domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); | |||
GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, | |||
"[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||
layer.name().c_str(), layer.type().c_str()); | |||
if (!CheckValidLayer(layer)) { | |||
GELOGI("[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||
layer.name().c_str(), layer.type().c_str()); | |||
continue; | |||
} | |||
CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; | |||
REPORT_INNER_ERROR("E19999", "net %s has input and data layer simultaneously, check invalid." | |||
@@ -1392,7 +1395,7 @@ void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer) | |||
for (auto top : layer.top()) { | |||
tops.push_back(top); | |||
} | |||
auto it = layer_tops_map_.find(name); | |||
std::map<std::string, std::vector<std::string>>::const_iterator it = layer_tops_map_.find(name); | |||
if (it == layer_tops_map_.end()) { | |||
layer_tops_map_[name] = tops; | |||
} | |||
@@ -1431,11 +1434,23 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la | |||
return SUCCESS; | |||
} | |||
Status CaffeModelParser::ReportLayerInvalid(const domi::caffe::NetParameter &proto, const std::string &path) const { | |||
if (proto.layers_size() > 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11021", {"realpath"}, {path}); | |||
GELOGE(FAILED, "[Check][Size]The model file[%s] is consisted of layers-structure which is deprecated in Caffe " | |||
"and unsupported in ATC. The \"layers\" should be changed to \"layer\".", path.c_str()); | |||
} else { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11022"); | |||
GELOGE(FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); | |||
} | |||
return FAILED; | |||
} | |||
Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) { | |||
bool has_error = false; | |||
GE_CHECK_NOTNULL(model_path); | |||
GE_CHECK_NOTNULL(graph); | |||
GELOGI("Caffe Parse model file %s", model_path); | |||
GELOGI("Caffe Parse model file [%s]", model_path); | |||
PreChecker::Instance().Clear(); | |||
@@ -1450,22 +1465,12 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
// parse network model by custom proto and get custom operators | |||
string custom_proto_path = ge::GetParserContext().custom_proto_path + "custom.proto"; | |||
string caffe_proto_path = ge::GetParserContext().caffe_proto_path + "caffe.proto"; | |||
Status result = CustomProtoParse(model_path, custom_proto_path, caffe_proto_path, custom_operator_); | |||
if (result != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][Model] by custom proto failed, model path: %s.", model_path); | |||
return FAILED; | |||
} | |||
GE_CHK_STATUS(CustomProtoParse(model_path, custom_proto_path, caffe_proto_path, custom_operator_), | |||
"[Parse][Model] by custom proto failed, model path: %s.", model_path); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
proto_message.layer_size() == 0 && proto_message.layers_size() > 0, | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11021", {"realpath"}, {model_path}); | |||
return FAILED, | |||
"[Check][Size]The model file[%s] is consisted of layers-structure which is deprecated in Caffe " | |||
"and unsupported in ATC. The \"layers\" should be changed to \"layer\".", | |||
model_path); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto_message.layer_size() == 0), | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11022"); | |||
return FAILED, "[Check][Size]net layer num is zero, prototxt file may be invalid."); | |||
if (proto_message.layer_size() == 0) { | |||
return ReportLayerInvalid(proto_message, model_path); | |||
} | |||
GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&proto_message, domi::CAFFE), | |||
"Run ProtoType Pass Failed"); | |||
@@ -1476,8 +1481,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
GE_RETURN_IF_ERROR(PreCheck(proto_message)); | |||
if (PreChecker::Instance().HasError()) { | |||
REPORT_INNER_ERROR("E19999", "Precheck failed. Please read check report."); | |||
GELOGE(INTERNAL_ERROR, "[Has][Error]Precheck failed. Please read check report."); | |||
REPORT_INNER_ERROR("E19999", "Precheck failed. a report of json format will be create, Please read it."); | |||
GELOGE(INTERNAL_ERROR, "[Has][Error]Precheck failed. a report of json format will be create, Please read it."); | |||
return FAILED; | |||
} | |||
@@ -1512,9 +1517,11 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
for (int32_t layer_index = 0; layer_index < layer_count; ++layer_index) { | |||
domi::caffe::LayerParameter &layer = const_cast<domi::caffe::LayerParameter &>(proto_message.layer(layer_index)); | |||
SaveOrigionLayerTops(layer); | |||
GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, | |||
"[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||
layer.name().c_str(), layer.type().c_str()); | |||
if (!CheckValidLayer(layer)) { | |||
GELOGI("[Check][Layer]layer phase is train, skip this layer, name:%s, type:%s.", | |||
layer.name().c_str(), layer.type().c_str()); | |||
continue; | |||
} | |||
CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && input_data_flag), has_error = true; | |||
GELOGE(FAILED, "[Check][Layer]net %s has input and data layer simultaneously, check invalid." | |||
@@ -1679,7 +1686,7 @@ Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge:: | |||
// Resolve proto file to netparameter | |||
NetParameter proto; | |||
bool success = ge::parser::ReadProtoFromArray(reinterpret_cast<const char *>(data), static_cast<int>(size), &proto); | |||
bool success = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &proto); | |||
if (!success) { | |||
REPORT_CALL_ERROR("E19999", "ReadProtoFromArray failed."); | |||
GELOGE(domi::PARSE_WEIGHTS_FAILED, "[Read][Proto] from Memory fail"); | |||
@@ -1920,7 +1927,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r | |||
const google::protobuf::FieldDescriptor *field, | |||
google::protobuf::Message *layer) { | |||
GELOGD("Start to parse field: %s.", field->name().c_str()); | |||
domi::caffe::LayerParameter *layer_proto = reinterpret_cast<domi::caffe::LayerParameter *>(layer); | |||
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | |||
string filed_name = field->name(); | |||
#define CASE_FIELD_NAME(kName, method) \ | |||
if (filed_name == kField##kName) { \ | |||
@@ -1975,8 +1982,7 @@ Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *me | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
blobs_reflection->ListFields(*message, &field_desc); | |||
domi::caffe::BlobProto *blobs_proto = reinterpret_cast<domi::caffe::BlobProto *>(blobs); | |||
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); | |||
for (auto &field : field_desc) { | |||
GE_CHECK_NOTNULL(field); | |||
@@ -2025,7 +2031,7 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
domi::caffe::BlobShape *shape_proto = reinterpret_cast<domi::caffe::BlobShape *>(dest_message); | |||
domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); | |||
for (auto &field : field_desc) { | |||
if (field->name() != kFieldDim) { | |||
@@ -2048,7 +2054,7 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message | |||
reflection->ListFields(*message, &field_desc); | |||
domi::caffe::ConvolutionParameter *conv_param_proto = | |||
reinterpret_cast<domi::caffe::ConvolutionParameter *>(dest_message); | |||
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); | |||
for (auto &field : field_desc) { | |||
if (field->name() != kFieldBiasTerm) { | |||
@@ -2068,7 +2074,7 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess | |||
reflection->ListFields(*message, &field_desc); | |||
domi::caffe::InnerProductParameter *inner_product_proto = | |||
reinterpret_cast<domi::caffe::InnerProductParameter *>(dest_message); | |||
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); | |||
for (auto &field : field_desc) { | |||
if (field->name() != kFieldBiasTerm) { | |||
@@ -2110,22 +2116,25 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess | |||
} | |||
} | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0, | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||
return FAILED, | |||
"[Check][Param]The weight file is consisted of layers-structure which is deprecated " | |||
"in Caffe and unsupported in ATC. The \"layers\" should be changed to \"layer\"."); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||
return FAILED, | |||
"[Check][Param] Weight layer num is zero, weight file may be invalid."); | |||
if (num_layer == 0 && num_layers > 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||
GELOGE(FAILED, "[Check][Param]The weight file is consisted of layers-structure which is deprecated " | |||
"in Caffe and unsupported in ATC. The \"layers\" should be changed to \"layer\"."); | |||
return FAILED; | |||
} | |||
if (num_layer == 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||
GELOGE(FAILED, "[Check][Param] Weight layer num is zero, weight file may be invalid."); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message *layer_message, | |||
ge::ComputeGraphPtr &graph) { | |||
vector<string> need_share_layers; | |||
const domi::caffe::LayerParameter *layer = reinterpret_cast<const domi::caffe::LayerParameter *>(layer_message); | |||
const domi::caffe::LayerParameter *layer = | |||
PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer_message); | |||
const string &shared_layer_name = layer->name(); | |||
const string &layer_type = layer->type(); | |||
for (auto p_iter = params_share_map.begin(); p_iter != params_share_map.end(); ++p_iter) { | |||
@@ -2159,7 +2168,7 @@ Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message | |||
} | |||
// The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. | |||
auto iter = caffe_op_map.find(layer_type); | |||
std::map<std::string, std::string>::const_iterator iter = caffe_op_map.find(layer_type); | |||
if (iter == caffe_op_map.end()) { | |||
GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer_type.c_str(), layer_name.c_str()); | |||
continue; | |||
@@ -2172,20 +2181,20 @@ Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message | |||
GE_CHECK_NOTNULL(factory); | |||
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(op_parser.get() == nullptr), | |||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||
std::vector<std::string>({layer_name, op_type})); | |||
return FAILED, | |||
"[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||
if (op_parser.get() == nullptr) { | |||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||
std::vector<std::string>({layer_name, op_type})); | |||
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||
return FAILED; | |||
} | |||
// Parsing weight information through op parser | |||
Status status = op_parser->ParseWeights(layer_message, node); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(status != SUCCESS), | |||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||
return status, | |||
"[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||
if (status != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||
return status; | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -2233,13 +2242,18 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||
// Operator name and occurrence map, handle duplicate operators | |||
std::map<std::string, int32_t> layer_name_map; | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0, | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||
return FAILED, "[Check][Param] The weight file is consisted of layers-structure " | |||
"which is deprecated in Caffe and unsupported in ATC. " | |||
"The \"layers\" should be changed to \"layer\"."); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||
return FAILED, "weight layer num is zero, weight file may be invalid."); | |||
if (num_layer == 0 && num_layers > 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11023"); | |||
GELOGE(FAILED, "[Check][Param] The weight file is consisted of layers-structure " | |||
"which is deprecated in Caffe and unsupported in ATC. " | |||
"The \"layers\" should be changed to \"layer\"."); | |||
return FAILED; | |||
} | |||
if (num_layer == 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11024"); | |||
GELOGE(FAILED, "weight layer num is zero, weight file may be invalid."); | |||
return FAILED; | |||
} | |||
for (int i = 0; i < num_layer; ++i) { | |||
const LayerParameter &layer = param.layer(i); | |||
@@ -2285,7 +2299,7 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||
} | |||
// The weight processing also needs to judge the duplicate operator, which is reserved here and processed later. | |||
auto iter = caffe_op_map.find(layer.type()); | |||
std::map<std::string, std::string>::const_iterator iter = caffe_op_map.find(layer.type()); | |||
if (iter == caffe_op_map.end()) { | |||
GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer.type().c_str(), layer_name.c_str()); | |||
continue; | |||
@@ -2298,18 +2312,20 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||
GE_CHECK_NOTNULL(factory); | |||
std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(op_parser.get() == nullptr), | |||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||
std::vector<std::string>({layer_name, op_type})); | |||
return FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||
if (op_parser.get() == nullptr) { | |||
REPORT_INPUT_ERROR("E11009", std::vector<std::string>({"opname", "optype"}), | |||
std::vector<std::string>({layer_name, op_type})); | |||
GELOGE(FAILED, "[Create][OpParser] failed for Op[%s], optype is %s", layer_name.c_str(), op_type.c_str()); | |||
return FAILED; | |||
} | |||
// Parsing weight information through op parser | |||
Status status = op_parser->ParseWeights(&layer, node); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(status != SUCCESS), | |||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||
return status, "[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||
if (status != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Parse weight for op:%s(%s) failed", layer_name.c_str(), op_type.c_str()); | |||
GELOGE(FAILED, "[Parse][Weights] for op[%s] failed", layer_name.c_str()); | |||
return status; | |||
} | |||
} | |||
} | |||
@@ -40,6 +40,7 @@ | |||
#include "omg/parser/op_parser.h" | |||
#include "omg/parser/model_parser.h" | |||
#include "omg/parser/weights_parser.h" | |||
#include "common/pre_checker.h" | |||
#include "proto/caffe/caffe.pb.h" | |||
#include "proto/om.pb.h" | |||
@@ -123,6 +124,17 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||
return domi::SUCCESS; | |||
} | |||
bool HasError() override { | |||
return PreChecker::Instance().HasError(); | |||
} | |||
Status Save(const string &file) override { | |||
return PreChecker::Instance().Save(file); | |||
} | |||
void Clear() override { | |||
PreChecker::Instance().Clear(); | |||
} | |||
private: | |||
Status Parse(const char *model_path, ge::ComputeGraphPtr &graph); | |||
@@ -313,6 +325,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||
Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer); | |||
Status ReportLayerInvalid(const domi::caffe::NetParameter &proto, const std::string &path) const; | |||
std::map<std::string, ge::NodePtr> node_map; | |||
// key: blob name, value: layer name and index | |||
@@ -344,6 +358,18 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | |||
bool HasError() override { | |||
return PreChecker::Instance().HasError(); | |||
} | |||
Status Save(const string &file) override { | |||
return PreChecker::Instance().Save(file); | |||
} | |||
void Clear() override { | |||
PreChecker::Instance().Clear(); | |||
} | |||
private: | |||
Status CheckNodes(ge::ComputeGraphPtr &graph); | |||
/** | |||
@@ -128,7 +128,7 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) { | |||
data[i] = attr_shape[i]; | |||
} | |||
GE_IF_BOOL_EXEC( | |||
constTensor->SetData(reinterpret_cast<uint8_t *>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS, | |||
constTensor->SetData(PtrToPtr<int64_t, uint8_t>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS, | |||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||
// construct const node and add edge | |||
@@ -492,7 +492,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||
} | |||
domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | |||
if (!default_out_nodes.empty()) { | |||
for (size_t i = 0; i < default_out_nodes.size(); ++i) { | |||
@@ -613,24 +613,27 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||
ge::GetParserContext().out_tensor_names.clear(); | |||
ge::GetParserContext().data_tensor_names.clear(); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, | |||
return PARAM_INVALID, "[Check][Options] Parse paragrams invalid, graph:%s.", | |||
graph_name.c_str()); | |||
if (CheckOptions(parser_params) != SUCCESS) { | |||
GELOGE(FAILED, "[Check][Options] Parse paragrams invalid, graph:%s.", graph_name.c_str()); | |||
return PARAM_INVALID; | |||
} | |||
// support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes | |||
SetDefaultFormat(); | |||
string out_nodes; | |||
GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputNodes(out_nodes) != SUCCESS, | |||
return PARAM_INVALID, | |||
"[Invoke][ParseAclOutputNodes] Parse out_nodes failed, graph:%s.", graph_name.c_str()); | |||
if (ParseAclOutputNodes(out_nodes) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ParseAclOutputNodes] Parse out_nodes failed, graph:%s.", graph_name.c_str()); | |||
return PARAM_INVALID; | |||
} | |||
string is_output_adjust_hw_layout; | |||
GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS, | |||
return PARAM_INVALID, | |||
"[Invoke][ParseAclOutputFp16NodesFormat] Parse is_output_adjust_hw_layout failed, " | |||
"graph:%s.", graph_name.c_str()); | |||
if (ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ParseAclOutputFp16NodesFormat] Parse is_output_adjust_hw_layout failed, graph:%s.", | |||
graph_name.c_str()); | |||
return PARAM_INVALID; | |||
} | |||
string tmp_name; | |||
GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); | |||
@@ -638,10 +641,11 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||
string enable_scope_fusion_passes; | |||
GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS, | |||
return PARAM_INVALID, | |||
"[Invoke][ParseAclEnableScope] Parse enable_scope_fusion_passes failed, graph:%s.", | |||
graph_name.c_str()); | |||
if (ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ParseAclEnableScope] Parse enable_scope_fusion_passes failed, graph:%s.", | |||
graph_name.c_str()); | |||
return PARAM_INVALID; | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -657,10 +661,11 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||
string is_input_adjust_hw_layout; | |||
GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, | |||
return PARAM_INVALID, "[Invoke][ParseAclInputFp16Nodes] Parse input_fp16_nodes failed, graph:%s", | |||
compute_graph->GetName().c_str()); | |||
if (ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ParseAclInputFp16Nodes] Parse input_fp16_nodes failed, graph:%s", | |||
compute_graph->GetName().c_str()); | |||
return PARAM_INVALID; | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -689,30 +694,35 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char | |||
// Get file length | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), | |||
REPORT_INNER_ERROR("E19999", "input_file path is null, check invalid."); | |||
return -1, "[Check][Param] input_file path is null."); | |||
if (input_file.empty()) { | |||
REPORT_INNER_ERROR("E19999", "input_file path is null, check invalid."); | |||
GELOGE(FAILED, "[Check][Param] input_file path is null."); | |||
return -1; | |||
} | |||
std::string real_path = RealPath(input_file.c_str()); | |||
char_t err_buf[kMaxErrStrLen + 1U] = {}; | |||
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||
REPORT_INPUT_ERROR("E19000", std::vector<std::string>({"path", "errmsg"}), | |||
std::vector<std::string>({real_path, err_msg})); | |||
return -1, "[Get][Path] input_file path '%s' not valid", input_file.c_str()); | |||
if (real_path.empty()) { | |||
REPORT_INPUT_ERROR("E19000", std::vector<std::string>({"path", "errmsg"}), | |||
std::vector<std::string>({real_path, err_msg})); | |||
GELOGE(FAILED, "[Get][Path] input_file path '%s' not valid", input_file.c_str()); | |||
return -1; | |||
} | |||
unsigned long long file_length = 0; | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, | |||
{input_file, err_msg}); | |||
return -1, "[Open][File] [%s] failed. %s", input_file.c_str(), err_msg); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0 || file_length > kMaxFileSizeLimit), | |||
REPORT_INPUT_ERROR( | |||
"E19015", std::vector<std::string>({"file", "size", "maxsize"}), | |||
std::vector<std::string>({input_file, std::to_string(file_length), | |||
std::to_string(kMaxFileSizeLimit)})); | |||
return -1, "[Check][Param] File[%s] size %lld is out of range(0,%d).", | |||
input_file.c_str(), file_length, kMaxFileSizeLimit); | |||
if (mmGetFileSize(input_file.c_str(), &file_length) != EN_OK) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, err_msg}); | |||
GELOGE(FAILED, "[Open][File] [%s] failed. %s", input_file.c_str(), err_msg); | |||
return -1; | |||
} | |||
if ((file_length == 0) || (file_length > kMaxFileSizeLimit)) { | |||
REPORT_INPUT_ERROR("E19015", std::vector<std::string>({ "file", "size", "maxsize" }), | |||
std::vector<std::string>({ input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit) })); | |||
GELOGE(FAILED, "[Check][Param] File[%s] size %lld is out of range(0,%d).", | |||
input_file.c_str(), file_length, kMaxFileSizeLimit); | |||
return -1; | |||
} | |||
return static_cast<long>(file_length); | |||
} | |||
@@ -725,9 +735,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() | |||
} | |||
static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, | |||
REPORT_INNER_ERROR("E19999", "param proto is nullptr, check invalid"); | |||
return false, "[Check][Param] incorrect parameter. nullptr == proto"); | |||
if (proto == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "param proto is nullptr, check invalid"); | |||
GELOGE(FAILED, "[Check][Param] incorrect parameter. nullptr == proto"); | |||
return false; | |||
} | |||
coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit); | |||
return proto->ParseFromCodedStream(&coded_stream); | |||
@@ -743,17 +755,23 @@ static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Messag | |||
*/ | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, | |||
int &length) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), | |||
REPORT_INNER_ERROR("E19999", "param file_name is nullptr, check invalid"); | |||
return false, "[Check][Param] incorrect parameter. file is nullptr"); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), | |||
REPORT_INNER_ERROR("E19999", "param buffer is nullptr, check invalid"); | |||
return false, "[Check][Param] incorrect parameter. buffer is nullptr"); | |||
if (file_name == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "param file_name is nullptr, check invalid"); | |||
GELOGE(FAILED, "[Check][Param] incorrect parameter. file is nullptr"); | |||
return false; | |||
} | |||
if (buffer == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "param buffer is nullptr, check invalid"); | |||
GELOGE(FAILED, "[Check][Param] incorrect parameter. buffer is nullptr"); | |||
return false; | |||
} | |||
std::string real_path = RealPath(file_name); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file_name); | |||
return false, "[Check][Param]file path '%s' not valid, realpath failed", file_name); | |||
if (real_path.empty()) { | |||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file_name); | |||
GELOGE(FAILED, "[Check][Param]file path '%s' not valid, realpath failed", file_name); | |||
return false; | |||
} | |||
std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); | |||
if (!file.is_open()) { | |||
@@ -763,16 +781,22 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co | |||
} | |||
length = static_cast<int>(file.tellg()); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); REPORT_INNER_ERROR("E19999", "file length <= 0"); | |||
return false, "[Check][Param] file length <= 0"); | |||
if ((length <= 0)) { | |||
file.close(); | |||
REPORT_INNER_ERROR("E19999", "file length <= 0"); | |||
GELOGE(FAILED, "[Check][Param] file length <= 0"); | |||
return false; | |||
} | |||
file.seekg(0, std::ios::beg); | |||
*buffer = new(std::nothrow) char[length](); | |||
GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(); | |||
REPORT_CALL_ERROR("E19999", "new an object failed."), | |||
"[Create][Buffer] new an object failed."); | |||
if (*buffer == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "[Create][Buffer] new an object failed, length=%d.", length); | |||
GELOGE(FAILED, "[Create][Buffer] new an object failed, length=%d.", length); | |||
file.close(); | |||
return false; | |||
} | |||
file.read(*buffer, length); | |||
file.close(); | |||
@@ -780,16 +804,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co | |||
} | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), | |||
REPORT_INNER_ERROR("E19999", "param file or proto is nullptr, check invalid"); | |||
return false, "[Check][Param] Input parameter file or proto is nullptr!"); | |||
if ((file == nullptr) || (proto == nullptr)) { | |||
REPORT_INNER_ERROR("E19999", "param file or proto is nullptr, check invalid"); | |||
GELOGE(FAILED, "[Check][Param] Input parameter file or proto is nullptr!"); | |||
return false; | |||
} | |||
std::string real_path = RealPath(file); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file); | |||
return false, "[Check][Param]pb file path '%s' not valid, realpath failed", file); | |||
if (real_path.empty()) { | |||
REPORT_INNER_ERROR("E19999", "file path '%s' not valid, realpath failed", file); | |||
GELOGE(FAILED, "[Check][Param]pb file path '%s' not valid, realpath failed", file); | |||
return false; | |||
} | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "[Get][FileLength]file size not valid."); | |||
if (GetFileLength(real_path) == -1) { | |||
GELOGE(FAILED, "[Get][FileLength]file size not valid."); | |||
return false; | |||
} | |||
std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | |||
if (!fs.is_open()) { | |||
@@ -815,32 +846,37 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||
} | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), | |||
REPORT_INNER_ERROR("E19999", "param proto or data is nullptr " | |||
"or size is 0, check invalid"); return false, | |||
"[Check][Param]incorrect parameter. proto is nullptr || data is nullptr || size is 0"); | |||
if ((proto == nullptr) || (data == nullptr) || (size == 0)) { | |||
REPORT_INNER_ERROR("E19999", "param proto or data is nullptr or size is 0, check invalid"); | |||
GELOGE(FAILED, "[Check][Param]incorrect parameter. proto is nullptr || data is nullptr || size is 0"); | |||
return false; | |||
} | |||
google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size); | |||
google::protobuf::io::CodedInputStream coded_stream(PtrToPtr<void, uint8_t>(const_cast<void *>(data)), size); | |||
return ReadProtoFromCodedInputStream(coded_stream, proto); | |||
} | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, | |||
google::protobuf::Message *message) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), | |||
REPORT_INNER_ERROR("E19999", "param file or message is nullptr, check invalid"); | |||
return false, | |||
"[Check][Param]incorrect parameter. nullptr == file || nullptr == message"); | |||
if ((file == nullptr) || (message == nullptr)) { | |||
REPORT_INNER_ERROR("E19999", "param file or message is nullptr, check invalid"); | |||
GELOGE(FAILED, "[Check][Param]incorrect parameter. nullptr == file || nullptr == message"); | |||
return false; | |||
} | |||
std::string real_path = RealPath(file); | |||
char_t err_buf[kMaxErrStrLen + 1U] = {}; | |||
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, | |||
{file, err_msg}); | |||
return false, "[Check][Param]Path[%s]'s realpath is empty, errmsg[%s]", file, | |||
err_msg); | |||
if (real_path.empty()) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file, err_msg}); | |||
GELOGE(FAILED, "[Check][Param]Path[%s]'s realpath is empty, errmsg[%s]", file, err_msg); | |||
return false; | |||
} | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "[Check][Param] file size not valid."); | |||
if (GetFileLength(real_path) == -1) { | |||
GELOGE(FAILED, "[Check][Param] file size not valid."); | |||
return false; | |||
} | |||
std::ifstream fs(real_path.c_str(), std::ifstream::in); | |||
@@ -863,10 +899,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, | |||
google::protobuf::Message *message) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), | |||
REPORT_INNER_ERROR("E19999", "param data or message is nullptr,check invalid"); | |||
return false, | |||
"[Check][Param] incorrect parameter. data is nullptr || message is nullptr"); | |||
if ((data == nullptr) || (message == nullptr)) { | |||
REPORT_INNER_ERROR("E19999", "param data or message is nullptr,check invalid"); | |||
GELOGE(FAILED, "[Check][Param] incorrect parameter. data is nullptr || message is nullptr"); | |||
return false; | |||
} | |||
std::string str(data, static_cast<size_t>(size)); | |||
std::istringstream fs(str); | |||
@@ -901,7 +938,7 @@ Status GetOriginalType(const ge::NodePtr &node, string &type) { | |||
return SUCCESS; | |||
} | |||
FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) { | |||
FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &filePath, const std::string &mode) { | |||
char ebuff[kMaxBuffSize]; | |||
regex_t reg; | |||
int cflags = REG_EXTENDED | REG_NOSUB; | |||
@@ -913,7 +950,7 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::str | |||
return true; | |||
} | |||
ret = regexec(®, str.c_str(), 0, nullptr, 0); | |||
ret = regexec(®, filePath.c_str(), 0, nullptr, 0); | |||
if (ret) { | |||
regerror(ret, ®, ebuff, kMaxBuffSize); | |||
GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff); | |||
@@ -21,11 +21,11 @@ | |||
#include "graph/op_desc.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/debug/ge_util.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "register/register_fmk_types.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "framework/common/util.h" | |||
namespace ge { | |||
namespace { | |||
@@ -31,11 +31,17 @@ using std::string; | |||
namespace ge { | |||
namespace { | |||
const int kSignificantDigits = 10; | |||
const int kMaxParseDepth = 20; | |||
} | |||
// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, | |||
const set<string> &black_fields, Json &json, | |||
bool enum2str) { | |||
bool enum2str, int depth) { | |||
if (depth > kMaxParseDepth) { | |||
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); | |||
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); | |||
return; | |||
} | |||
auto descriptor = message.GetDescriptor(); | |||
auto reflection = message.GetReflection(); | |||
if (descriptor == nullptr || reflection == nullptr) { | |||
@@ -57,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons | |||
if (field->is_repeated()) { | |||
if (reflection->FieldSize(message, field) > 0) { | |||
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); | |||
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str, depth); | |||
} | |||
continue; | |||
} | |||
@@ -66,18 +72,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons | |||
continue; | |||
} | |||
OneField2Json(message, field, reflection, black_fields, json, enum2str); | |||
OneField2Json(message, field, reflection, black_fields, json, enum2str, depth); | |||
} | |||
} | |||
void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||
bool enum2str) { | |||
bool enum2str, int depth) { | |||
switch (field->type()) { | |||
case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | |||
if (0UL != tmp_message.ByteSizeLong()) { | |||
Message2Json(tmp_message, black_fields, json[field->name()], enum2str); | |||
Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1); | |||
} | |||
break; | |||
} | |||
@@ -163,9 +169,9 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||
void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||
bool enum2str) { | |||
bool enum2str, int depth) { | |||
if ((field == nullptr) || (reflection == nullptr)) { | |||
Message2Json(message, black_fields, json, enum2str); | |||
Message2Json(message, black_fields, json, enum2str, depth + 1); | |||
return; | |||
} | |||
@@ -175,7 +181,7 @@ void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFie | |||
case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); | |||
if (0UL != tmp_message.ByteSizeLong()) { | |||
Message2Json(tmp_message, black_fields, tmp_json, enum2str); | |||
Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1); | |||
} | |||
} break; | |||
@@ -45,11 +45,11 @@ class Pb2Json { | |||
* @author | |||
*/ | |||
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | |||
bool enum2str = false); | |||
bool enum2str = false, int depth = 0); | |||
static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | |||
Json &json, bool enum2str); | |||
Json &json, bool enum2str, int depth = 0); | |||
protected: | |||
static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | |||
@@ -59,7 +59,7 @@ class Pb2Json { | |||
static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json, | |||
bool enum2str); | |||
bool enum2str, int depth); | |||
static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); | |||
}; | |||
@@ -60,7 +60,7 @@ class DataOpParser { | |||
* @param [in] 4D shape information (dimensions) | |||
* @param [out] Save converted shap information | |||
*/ | |||
static Status Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensorDesc); | |||
static Status Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensor_desc); | |||
/** | |||
* @ingroup domi_omg | |||
@@ -98,7 +98,7 @@ class DataOpParser { | |||
* @return SUCCESS Convert success | |||
* @return FAILED Convert failed | |||
*/ | |||
static Status InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &desc); | |||
static Status InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &tensor_desc); | |||
}; | |||
} // namespace ge | |||
@@ -55,9 +55,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi | |||
} | |||
char real_path[PATH_MAX] = {0}; | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, | |||
REPORT_INNER_ERROR("E19999", "file path %s is too long!", file_path); | |||
return FAILED, "[Check][Param] file path %s is too long!", file_path); | |||
if (strlen(file_path) >= PATH_MAX) { | |||
REPORT_INNER_ERROR("E19999", "file path %s is too long!", file_path); | |||
GELOGE(FAILED, "[Check][Param] file path %s is too long!", file_path); | |||
return FAILED; | |||
} | |||
if (realpath(file_path, real_path) == nullptr) { | |||
GELOGI("File %s does not exit, it will be created.", file_path); | |||
} | |||
@@ -32,9 +32,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOpera | |||
} | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) { | |||
Attr(VAR_ATTR_DTYPE, (int64_t)t); | |||
Attr(VAR_ATTR_DTYPE, static_cast<int64_t>(t)); | |||
return *this; | |||
} | |||
ge::DataType ConstantOperator::GetDType() const { return (ge::DataType)GetIntAttr(VAR_ATTR_DTYPE); } | |||
ge::DataType ConstantOperator::GetDType() const { return static_cast<ge::DataType>(GetIntAttr(VAR_ATTR_DTYPE)); } | |||
} // namespace ge |
@@ -32,7 +32,7 @@ static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair, | |||
vector<int64_t> v_i; | |||
for (int32_t i = 0; i < a_list.i_size(); i++) { | |||
v_i.push_back((int64_t)a_list.i(i)); | |||
v_i.push_back(static_cast<int64_t>(a_list.i(i))); | |||
} | |||
if (v_i.size() > 0) { | |||
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); | |||
@@ -56,7 +56,7 @@ static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair, | |||
} | |||
vector<int32_t> v_u; | |||
for (int32_t i = 0; i < a_list.u_size(); i++) { | |||
v_u.push_back((int32_t)a_list.u(i)); | |||
v_u.push_back(static_cast<int32_t>(a_list.u(i))); | |||
} | |||
if (v_u.size() > 0) { | |||
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u); | |||
@@ -114,7 +114,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(co | |||
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) { | |||
auto &buffer = op_attr_pair.second.value_.bt(); | |||
(void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first, | |||
ge::Buffer::CopyFrom(reinterpret_cast<uint8_t *>(const_cast<char *>(buffer.data())), buffer.size())); | |||
ge::Buffer::CopyFrom(PtrToPtr<void, uint8_t>(const_cast<char *>(buffer.data())), buffer.size())); | |||
} | |||
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) { | |||
@@ -23,7 +23,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOpe | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {} | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) { | |||
Attr("T", (int64_t)t); | |||
Attr("T", static_cast<int64_t>(t)); | |||
return *this; | |||
} | |||
} // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT |
@@ -32,20 +32,20 @@ FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::N(int64_t n) { | |||
FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); } | |||
FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) { | |||
Attr(SHAPEN_ATTR_IN_TYPE, (int64_t)t); | |||
Attr(SHAPEN_ATTR_IN_TYPE, static_cast<int64_t>(t)); | |||
return *this; | |||
} | |||
FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const { | |||
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_IN_TYPE); | |||
return static_cast<ge::DataType>(GetIntAttr(SHAPEN_ATTR_IN_TYPE)); | |||
} | |||
FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) { | |||
Attr(SHAPEN_ATTR_OUT_TYPE, (int64_t)t); | |||
Attr(SHAPEN_ATTR_OUT_TYPE, static_cast<int64_t>(t)); | |||
return *this; | |||
} | |||
FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const { | |||
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_OUT_TYPE); | |||
return static_cast<ge::DataType>(GetIntAttr(SHAPEN_ATTR_OUT_TYPE)); | |||
} | |||
} // namespace ge |
@@ -50,7 +50,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Insta | |||
// Instances cannot be a member of a class because they may be used before initialization, resulting in a run error. | |||
static std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>> instances; | |||
auto iter = instances.find(framework); | |||
std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>>::const_iterator iter = instances.find(framework); | |||
if (iter == instances.end()) { | |||
std::shared_ptr<OpParserFactory> instance(new (std::nothrow) OpParserFactory()); | |||
if (instance == nullptr) { | |||
@@ -67,7 +67,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Insta | |||
FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpParser(const std::string &op_type) { | |||
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | |||
auto iter = op_parser_creator_map_.find(op_type); | |||
std::map<std::string, CREATOR_FUN>::const_iterator iter = op_parser_creator_map_.find(op_type); | |||
if (iter != op_parser_creator_map_.end()) { | |||
return iter->second(); | |||
} | |||
@@ -78,7 +78,7 @@ FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpPars | |||
FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateFusionOpParser(const std::string &op_type) { | |||
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | |||
auto iter = fusion_op_parser_creator_map_.find(op_type); | |||
std::map<std::string, CREATOR_FUN>::const_iterator iter = fusion_op_parser_creator_map_.find(op_type); | |||
if (iter != fusion_op_parser_creator_map_.end()) { | |||
return iter->second(); | |||
} | |||
@@ -102,12 +102,12 @@ FMK_FUNC_HOST_VISIBILITY void OpParserFactory::RegisterCreator(const std::string | |||
FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) { | |||
if (is_fusion_op) { | |||
auto iter = fusion_op_parser_creator_map_.find(op_type); | |||
std::map<std::string, CREATOR_FUN>::const_iterator iter = fusion_op_parser_creator_map_.find(op_type); | |||
if (iter != fusion_op_parser_creator_map_.end()) { | |||
return true; | |||
} | |||
} else { | |||
auto iter = op_parser_creator_map_.find(op_type); | |||
std::map<std::string, CREATOR_FUN>::const_iterator iter = op_parser_creator_map_.find(op_type); | |||
if (iter != op_parser_creator_map_.end()) { | |||
return true; | |||
} | |||
@@ -1,61 +0,0 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef PARSER_COMMON_OP_TYPES_H_ | |||
#define PARSER_COMMON_OP_TYPES_H_ | |||
#include <set> | |||
#include <string> | |||
namespace ge { | |||
class GE_FUNC_VISIBILITY OpTypeContainer { | |||
public: | |||
static OpTypeContainer *Instance() { | |||
static OpTypeContainer instance; | |||
return &instance; | |||
} | |||
~OpTypeContainer() = default; | |||
void Register(const std::string &op_type) { op_type_list_.insert(op_type); } | |||
bool IsExisting(const std::string &op_type) { | |||
return op_type_list_.count(op_type) > 0UL; | |||
} | |||
protected: | |||
OpTypeContainer() {} | |||
private: | |||
std::set<std::string> op_type_list_; | |||
}; | |||
class GE_FUNC_VISIBILITY OpTypeRegistrar { | |||
public: | |||
explicit OpTypeRegistrar(const std::string &op_type) { OpTypeContainer::Instance()->Register(op_type); } | |||
~OpTypeRegistrar() {} | |||
}; | |||
#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \ | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *var_name; | |||
#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \ | |||
const char *var_name = str_name; \ | |||
const OpTypeRegistrar g_##var_name##_reg(str_name); | |||
#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | |||
} // namespace ge | |||
#endif // PARSER_COMMON_OP_TYPES_H_ |
@@ -16,6 +16,7 @@ | |||
#include "omg/parser/parser_factory.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "common/register_tbe.h" | |||
namespace domi { | |||
FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() { | |||
@@ -77,4 +78,13 @@ FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::Fr | |||
ModelParserFactory::~ModelParserFactory() { | |||
creator_map_.clear(); | |||
} | |||
FMK_FUNC_HOST_VISIBILITY OpRegTbeParserFactory *OpRegTbeParserFactory::Instance() { | |||
static OpRegTbeParserFactory instance; | |||
return &instance; | |||
} | |||
void OpRegTbeParserFactory::Finalize(const domi::OpRegistrationData ®_data) { | |||
(void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
} | |||
} // namespace domi |
@@ -17,15 +17,16 @@ | |||
#include "parser/common/parser_fp16_t.h" | |||
#include "external/register/register_types.h" | |||
#include "graph/def_types.h" | |||
namespace { | |||
constexpr uint16_t kManBitLength = 11; | |||
constexpr uint16_t kManBitLength = 11U; | |||
} | |||
namespace ge { | |||
namespace parser { | |||
/// @ingroup fp16_t global filed | |||
/// @brief round mode of last valid digital | |||
enum TagFp16RoundMode g_round_mode = TagFp16RoundMode::kRoundToNearest; | |||
const TagFp16RoundMode g_round_mode = TagFp16RoundMode::kRoundToNearest; | |||
void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { | |||
// 1.Extract | |||
@@ -99,12 +100,12 @@ static float Fp16ToFloat(const uint16_t &fp_val) { | |||
e_ret = 0; | |||
m_ret = 0; | |||
} else { | |||
e_ret = hf_exp - kFp16ExpBias + kFp32ExpBias; | |||
e_ret = (static_cast<uint32_t>(hf_exp) - static_cast<uint32_t>(kFp16ExpBias)) + static_cast<uint32_t>(kFp32ExpBias); | |||
m_ret = hf_man & kFp16ManMask; | |||
m_ret = m_ret << (kFp32ManLen - kFp16ManLen); | |||
} | |||
uint32_t f_val = FP32_CONSTRUCTOR(s_ret, e_ret, m_ret); | |||
auto p_ret_v = reinterpret_cast<float *>(&f_val); | |||
auto p_ret_v = ge::PtrToPtr<uint32_t, float>(&f_val); | |||
return *p_ret_v; | |||
} | |||
@@ -131,12 +132,12 @@ static double Fp16ToDouble(const uint16_t &fp_val) { | |||
e_ret = 0; | |||
m_ret = 0; | |||
} else { | |||
e_ret = hf_exp - kFp16ExpBias + kFp64ExpBias; | |||
e_ret = (static_cast<uint64_t>(hf_exp) - static_cast<uint64_t>(kFp16ExpBias)) + static_cast<uint64_t>(kFp64ExpBias); | |||
m_ret = hf_man & kFp16ManMask; | |||
m_ret = m_ret << (kFp64ManLen - kFp16ManLen); | |||
} | |||
uint64_t f_val = (s_ret << kFp64SignIndex) | (e_ret << kFp64ManLen) | (m_ret); | |||
auto p_ret_v = reinterpret_cast<double *>(&f_val); | |||
auto p_ret_v = ge::PtrToPtr<uint64_t, double>(&f_val); | |||
return *p_ret_v; | |||
} | |||
@@ -154,13 +155,13 @@ static uint8_t GetUint8ValByMan(uint8_t s_ret, const uint64_t &long_int_m, const | |||
if (need_round) { | |||
m_ret++; | |||
} | |||
if (s_ret) { | |||
m_ret = (~m_ret) + 1; | |||
if (static_cast<bool>(s_ret)) { | |||
m_ret = (~m_ret) + 1U; | |||
} | |||
if (m_ret == 0) { | |||
s_ret = 0; | |||
} | |||
return static_cast<uint8_t>((s_ret << kBitShift7) | (m_ret)); | |||
return static_cast<uint8_t>((s_ret << static_cast<uint8_t>(kBitShift7)) | (m_ret)); | |||
} | |||
/// @ingroup fp16_t math conversion static method | |||
@@ -178,7 +179,7 @@ static int8_t Fp16ToInt8(const uint16_t &fp_val) { | |||
if (FP16_IS_DENORM(fp_val)) { // Denormalized number | |||
ret_v = 0; | |||
ret = *(reinterpret_cast<uint8_t *>(&ret_v)); | |||
ret = *(ge::PtrToPtr<uint8_t, uint8_t>(&ret_v)); | |||
return ret; | |||
} | |||
@@ -207,14 +208,14 @@ static int8_t Fp16ToInt8(const uint16_t &fp_val) { | |||
} | |||
} | |||
} | |||
if (overflow_flag) { | |||
if (static_cast<bool>(overflow_flag)) { | |||
ret_v = kInt8Max + s_ret; | |||
} else { | |||
// Generate final result | |||
ret_v = GetUint8ValByMan(s_ret, long_int_m, shift_out); | |||
} | |||
ret = *(reinterpret_cast<uint8_t *>(&ret_v)); | |||
ret = *(ge::PtrToPtr<uint8_t, uint8_t>(&ret_v)); | |||
return ret; | |||
} | |||
@@ -283,8 +284,8 @@ static uint16_t GetUint16ValByMan(uint16_t s_ret, const uint64_t &long_int_m, co | |||
if (need_round && m_ret < kInt16Max) { | |||
m_ret++; | |||
} | |||
if (s_ret) { | |||
m_ret = (~m_ret) + 1; | |||
if (static_cast<bool>(s_ret)) { | |||
m_ret = (~m_ret) + 1U; | |||
} | |||
if (m_ret == 0) { | |||
s_ret = 0; | |||
@@ -307,7 +308,7 @@ static int16_t Fp16ToInt16(const uint16_t &fp_val) { | |||
if (FP16_IS_DENORM(fp_val)) { // Denormalized number | |||
ret_v = 0; | |||
ret = *(reinterpret_cast<uint8_t *>(&ret_v)); | |||
ret = *(ge::PtrToPtr<uint16_t, uint8_t>(&ret_v)); | |||
return ret; | |||
} | |||
@@ -336,13 +337,13 @@ static int16_t Fp16ToInt16(const uint16_t &fp_val) { | |||
} | |||
} | |||
} | |||
if (overflow_flag) { | |||
if (static_cast<bool>(overflow_flag)) { | |||
ret_v = kInt16Max + s_ret; | |||
} else { | |||
// Generate final result | |||
ret_v = GetUint16ValByMan(s_ret, long_int_m, shift_out); | |||
} | |||
ret = *(reinterpret_cast<int16_t *>(&ret_v)); | |||
ret = *(ge::PtrToPtr<uint16_t, uint16_t>(&ret_v)); | |||
return ret; | |||
} | |||
@@ -433,7 +434,7 @@ static int32_t Fp16ToInt32(const uint16_t &fp_val) { | |||
ret_v = (s_ret << kBitShift31) | (m_ret); | |||
} | |||
return *(reinterpret_cast<int32_t *>(&ret_v)); | |||
return *(ge::PtrToPtr<uint32_t, uint32_t>(&ret_v)); | |||
} | |||
/// @ingroup fp16_t math conversion static method | |||
@@ -498,8 +499,8 @@ static uint16_t Fp16AddCalVal(uint16_t s_ret, int16_t e_ret, uint16_t m_ret, uin | |||
} | |||
bool b_last_bit = ((m_ret & 1) > 0); | |||
bool b_trunc_high = 0; | |||
bool b_trunc_left = 0; | |||
bool b_trunc_high = false; | |||
bool b_trunc_left = false; | |||
b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); | |||
b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); | |||
m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); | |||
@@ -561,7 +562,7 @@ static uint16_t Fp16Add(uint16_t v_1, uint16_t v_2) { | |||
int16_t e_ret = std::max(e_a, e_b); | |||
int16_t e_tmp = std::abs(e_a - e_b); | |||
if (e_a > e_b) { | |||
m_trunc = (m_b << (kBitShift32 - static_cast<uint16_t>(e_tmp))); | |||
m_trunc = (m_b << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp))); | |||
m_b = RightShift(m_b, e_tmp); | |||
} else if (e_a < e_b) { | |||
m_trunc = (m_a << (kBitShift32 - static_cast<uint16_t>(e_tmp))); | |||
@@ -602,7 +603,7 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { | |||
m_a = m_a_tmp; | |||
m_b = m_b_tmp; | |||
e_ret = e_a + e_b - kFp16ExpBias - kDim10; | |||
e_ret = ((e_a + e_b) - kFp16ExpBias) - kDim10; | |||
mul_m = m_a * m_b; | |||
s_ret = s_a ^ s_b; | |||
@@ -621,8 +622,8 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { | |||
e_ret = e_ret + 1; | |||
} | |||
bool b_last_bit = ((mul_m & 1) > 0); | |||
bool b_trunc_high = 0; | |||
bool b_trunc_left = 0; | |||
bool b_trunc_high = false; | |||
bool b_trunc_left = false; | |||
b_trunc_high = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32SignMask) > 0); | |||
b_trunc_left = (g_round_mode == TagFp16RoundMode::kRoundToNearest) && ((m_trunc & kFp32AbsMax) > 0); | |||
mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); | |||
@@ -675,14 +676,14 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) { | |||
uint64_t m_tmp; | |||
if (e_a > e_b) { | |||
m_tmp = m_a; | |||
uint16_t tmp = e_a - e_b; | |||
uint16_t tmp = static_cast<uint16_t>(e_a - e_b); | |||
for (int i = 0; i < tmp; i++) { | |||
m_tmp = m_tmp << 1; | |||
} | |||
m_a = m_tmp; | |||
} else if (e_a < e_b) { | |||
m_tmp = m_b; | |||
uint16_t tmp = e_b - e_a; | |||
uint16_t tmp = static_cast<uint16_t>(e_b - e_a); | |||
for (int i = 0; i < tmp; i++) { | |||
m_tmp = m_tmp << 1; | |||
} | |||
@@ -853,7 +854,7 @@ fp16_t &fp16_t::operator=(const float &f_val) { | |||
uint16_t s_ret, m_ret; | |||
int16_t e_ret; | |||
uint32_t e_f, m_f; | |||
const uint32_t ui32_v = *(reinterpret_cast<const uint32_t *>(&f_val)); // 1:8:23bit sign:exp:man | |||
const uint32_t ui32_v = *(ge::PtrToPtr<const float, const uint32_t>(&f_val)); // 1:8:23bit sign:exp:man | |||
uint32_t m_len_delta; | |||
s_ret = static_cast<uint16_t>((ui32_v & kFp32SignMask) >> kFp32SignIndex); // 4Byte->2Byte | |||
@@ -891,7 +892,7 @@ fp16_t &fp16_t::operator=(const float &f_val) { | |||
if (need_round) { | |||
m_ret++; | |||
} | |||
if (m_ret & kFp16ManHideBit) { | |||
if (static_cast<bool>(m_ret & kFp16ManHideBit)) { | |||
e_ret++; | |||
} | |||
} | |||
@@ -910,14 +911,14 @@ fp16_t &fp16_t::operator=(const int8_t &i_val) { | |||
if (m_ret == 0) { | |||
e_ret = 0; | |||
} else { | |||
if (s_ret) { // negative number(<0) | |||
if (static_cast<bool>(s_ret)) { // negative number(<0) | |||
m_ret = static_cast<uint16_t>(std::abs(i_val)); // complement | |||
} | |||
e_ret = kFp16ManLen; | |||
while ((m_ret & kFp16ManHideBit) == 0) { | |||
m_ret = m_ret << 1; | |||
e_ret = e_ret - 1; | |||
e_ret = e_ret - 1U; | |||
} | |||
e_ret = e_ret + kFp16ExpBias; | |||
} | |||
@@ -931,11 +932,11 @@ fp16_t &fp16_t::operator=(const uint8_t &ui_val) { | |||
s_ret = 0; | |||
e_ret = 0; | |||
m_ret = ui_val; | |||
if (m_ret) { | |||
if (static_cast<bool>(m_ret)) { | |||
e_ret = kFp16ManLen; | |||
while ((m_ret & kFp16ManHideBit) == 0) { | |||
m_ret = m_ret << 1; | |||
e_ret = e_ret - 1; | |||
e_ret = e_ret - 1U; | |||
} | |||
e_ret = e_ret + kFp16ExpBias; | |||
} | |||
@@ -949,11 +950,11 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||
uint16_t m_min = kFp16ManHideBit; | |||
uint16_t m_max = m_min << 1; | |||
uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp)); | |||
if (m_tmp) { | |||
if (static_cast<bool>(m_tmp)) { | |||
int16_t e_ret; | |||
if (len > kDim11) { | |||
e_ret = kFp16ExpBias + kFp16ManLen; | |||
uint16_t e_tmp = len - kDim11; | |||
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | |||
uint32_t trunc_mask = 1; | |||
for (int i = 1; i < e_tmp; i++) { | |||
trunc_mask = (trunc_mask << 1) + 1; | |||
@@ -964,8 +965,8 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||
e_ret = e_ret + 1; | |||
} | |||
bool b_last_bit = ((m_tmp & 1) > 0); | |||
bool b_trunc_high = 0; | |||
bool b_trunc_left = 0; | |||
bool b_trunc_high = false; | |||
bool b_trunc_left = false; | |||
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | |||
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | |||
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | |||
@@ -976,7 +977,7 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||
e_ret = e_ret + 1; | |||
} | |||
} else { | |||
e_ret = kFp16ExpBias; | |||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||
m_tmp = m_tmp << (kManBitLength - len); | |||
e_ret = e_ret + (len - 1); | |||
} | |||
@@ -989,11 +990,11 @@ fp16_t &fp16_t::operator=(const int16_t &i_val) { | |||
if (i_val == 0) { | |||
val = 0; | |||
} else { | |||
uint16_t ui_val = *(reinterpret_cast<const uint16_t *>(&i_val)); | |||
uint16_t ui_val = *(ge::PtrToPtr<const int16_t, const int16_t>(&i_val)); | |||
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift15); | |||
if (s_ret) { | |||
if (static_cast<bool>(s_ret)) { | |||
int16_t iValM = -i_val; | |||
ui_val = *(reinterpret_cast<uint16_t *>(&iValM)); | |||
ui_val = *(ge::PtrToPtr<int16_t, uint16_t>(&iValM)); | |||
} | |||
SetValByUint16Val(ui_val, s_ret, val); | |||
} | |||
@@ -1023,8 +1024,8 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) { | |||
e_ret = e_ret + 1; | |||
} | |||
bool b_last_bit = ((m_ret & 1) > 0); | |||
bool b_trunc_high = 0; | |||
bool b_trunc_left = 0; | |||
bool b_trunc_high = false; | |||
bool b_trunc_left = false; | |||
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | |||
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | |||
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | |||
@@ -1038,7 +1039,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) { | |||
val = kFp16Max; | |||
} | |||
} else { | |||
e_ret = kFp16ExpBias; | |||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||
m_ret = m_ret << (kDim11 - len); | |||
e_ret = e_ret + (len - 1); | |||
} | |||
@@ -1057,7 +1058,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||
e_ret = kFp16ExpBias + kFp16ManLen; | |||
uint32_t m_trunc = 0; | |||
uint32_t trunc_mask = 1; | |||
uint16_t e_tmp = len - kDim11; | |||
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | |||
for (int i = 1; i < e_tmp; i++) { | |||
trunc_mask = (trunc_mask << 1) + 1; | |||
} | |||
@@ -1067,8 +1068,8 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||
e_ret = e_ret + 1; | |||
} | |||
bool b_last_bit = ((m_tmp & 1) > 0); | |||
bool b_trunc_high = 0; | |||
bool b_trunc_left = 0; | |||
bool b_trunc_high = false; | |||
bool b_trunc_left = false; | |||
if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | |||
b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | |||
b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | |||
@@ -1083,7 +1084,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||
m_tmp = kFp16MaxMan; | |||
} | |||
} else { | |||
e_ret = kFp16ExpBias; | |||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||
m_tmp = m_tmp << (kDim11 - len); | |||
e_ret = e_ret + (len - 1); | |||
} | |||
@@ -1095,11 +1096,11 @@ fp16_t &fp16_t::operator=(const int32_t &i_val) { | |||
if (i_val == 0) { | |||
val = 0; | |||
} else { | |||
uint32_t ui_val = *(reinterpret_cast<const uint32_t *>(&i_val)); | |||
uint32_t ui_val = *(ge::PtrToPtr<const int32_t, const uint32_t>(&i_val)); | |||
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift31); | |||
if (s_ret) { | |||
if (static_cast<bool>(s_ret)) { | |||
int32_t iValM = -i_val; | |||
ui_val = *(reinterpret_cast<uint32_t *>(&iValM)); | |||
ui_val = *(ge::PtrToPtr<int32_t, uint32_t>(&iValM)); | |||
} | |||
SetValByUint32Val(ui_val, s_ret, val); | |||
} | |||
@@ -1119,7 +1120,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) { | |||
e_ret = kFp16ExpBias + kFp16ManLen; | |||
uint32_t m_trunc = 0; | |||
uint32_t trunc_mask = 1; | |||
uint16_t e_tmp = len - kDim11; | |||
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | |||
for (int i = 1; i < e_tmp; i++) { | |||
trunc_mask = (trunc_mask << 1) + 1; | |||
} | |||
@@ -1145,7 +1146,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) { | |||
m_tmp = kFp16MaxMan; | |||
} | |||
} else { | |||
e_ret = kFp16ExpBias; | |||
e_ret = static_cast<int16_t>(kFp16ExpBias); | |||
m_tmp = m_tmp << (kDim11 - len); | |||
e_ret = e_ret + (len - 1); | |||
} | |||
@@ -1161,7 +1162,7 @@ fp16_t &fp16_t::operator=(const double &d_val) { | |||
int16_t e_ret; | |||
uint64_t e_d; | |||
uint64_t m_d; | |||
uint64_t ui64_v = *(reinterpret_cast<const uint64_t *>(&d_val)); // 1:11:52bit sign:exp:man | |||
uint64_t ui64_v = *(ge::PtrToPtr<const double, const uint64_t>(&d_val)); // 1:11:52bit sign:exp:man | |||
uint32_t m_len_delta; | |||
s_ret = static_cast<uint16_t>((ui64_v & kFp64SignMask) >> kFp64SignIndex); // 4Byte | |||
@@ -1204,7 +1205,7 @@ fp16_t &fp16_t::operator=(const double &d_val) { | |||
if (need_round) { | |||
m_ret++; | |||
} | |||
if (m_ret & kFp16ManHideBit) { | |||
if (static_cast<bool>(m_ret & kFp16ManHideBit)) { | |||
e_ret++; | |||
} | |||
} | |||
@@ -1239,7 +1240,7 @@ fp16_t::operator uint64_t() const { return 0; } | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() const { | |||
if ((val & kFp16AbsMax) == kFp16ExpMask) { | |||
if (val & kFp16SignMask) { | |||
if (static_cast<bool>(val & kFp16SignMask)) { | |||
return -1; | |||
} else { | |||
return 1; | |||
@@ -91,16 +91,16 @@ using BitShift = enum { | |||
}; | |||
/// @ingroup fp16 basic parameter | |||
/// @brief fp16 exponent bias | |||
constexpr uint16_t kFp16ExpBias = 15; | |||
constexpr uint16_t kFp16ExpBias = 15U; | |||
/// @ingroup fp16 basic parameter | |||
/// @brief the exponent bit length of fp16 is 5 | |||
constexpr uint16_t kFp16ExpLen = 5; | |||
constexpr uint16_t kFp16ExpLen = 5U; | |||
/// @ingroup fp16 basic parameter | |||
/// @brief the mantissa bit length of fp16 is 10 | |||
constexpr uint16_t kFp16ManLen = 10; | |||
constexpr uint16_t kFp16ManLen = 10U; | |||
/// @ingroup fp16 basic parameter | |||
/// @brief bit index of sign in fp16 | |||
constexpr uint16_t kFp16SignIndex = 15; | |||
constexpr uint16_t kFp16SignIndex = 15U; | |||
/// @ingroup fp16 basic parameter | |||
/// @brief sign mask of fp16 (1 00000 00000 00000) | |||
constexpr uint16_t kFp16SignMask = 0x8000; | |||
@@ -164,16 +164,16 @@ constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14); | |||
#define FP16_IS_INVALID(x) (((x) & kFp16ExpMask) == kFp16ExpMask) | |||
/// @ingroup fp32 basic parameter | |||
/// @brief fp32 exponent bias | |||
constexpr uint16_t kFp32ExpBias = 127; | |||
constexpr uint16_t kFp32ExpBias = 127U; | |||
/// @ingroup fp32 basic parameter | |||
/// @brief the exponent bit length of float/fp32 is 8 | |||
constexpr uint16_t kFp32ExpLen = 8; | |||
constexpr uint16_t kFp32ExpLen = 8U; | |||
/// @ingroup fp32 basic parameter | |||
/// @brief the mantissa bit length of float/fp32 is 23 | |||
constexpr uint16_t kFp32ManLen = 23; | |||
constexpr uint16_t kFp32ManLen = 23U; | |||
/// @ingroup fp32 basic parameter | |||
/// @brief bit index of sign in float/fp32 | |||
constexpr uint16_t kFp32SignIndex = 31; | |||
constexpr uint16_t kFp32SignIndex = 31U; | |||
/// @ingroup fp32 basic parameter | |||
/// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) | |||
constexpr uint32_t kFp32SignMask = 0x80000000u; | |||
@@ -191,10 +191,10 @@ constexpr uint32_t kFp32ManHideBit = 0x00800000u; | |||
constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; | |||
/// @ingroup fp32 basic parameter | |||
/// @brief maximum exponent value of fp32 is 255(1111 1111) | |||
constexpr uint32_t kFp32MaxExp = 0xFF; | |||
constexpr uint32_t kFp32MaxExp = 0xFFU; | |||
/// @ingroup fp32 basic parameter | |||
/// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) | |||
constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | |||
constexpr uint32_t kFp32MaxMan = 0x7FFFFFU; | |||
/// @ingroup fp32 special value judgment | |||
/// @brief whether a fp32 is NaN | |||
#define FP32_IS_NAN(x) ((((x) & kFp32ExpMask) == kFp32ExpMask) && ((x) & kFp32ManMask)) | |||
@@ -218,16 +218,16 @@ constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | |||
#define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m) & kFp32MaxMan)) | |||
/// @ingroup fp64 basic parameter | |||
/// @brief fp64 exponent bias | |||
constexpr uint16_t kFp64ExpBias = 1023; | |||
constexpr uint16_t kFp64ExpBias = 1023U; | |||
/// @ingroup fp64 basic parameter | |||
/// @brief the exponent bit length of double/fp64 is 11 | |||
constexpr uint16_t kFp64ExpLen = 11; | |||
constexpr uint16_t kFp64ExpLen = 11U; | |||
/// @ingroup fp64 basic parameter | |||
/// @brief the mantissa bit length of double/fp64 is 52 | |||
constexpr uint16_t kFp64ManLen = 52; | |||
constexpr uint16_t kFp64ManLen = 52U; | |||
/// @ingroup fp64 basic parameter | |||
/// @brief bit index of sign in double/fp64 is 63 | |||
constexpr uint16_t kFp64SignIndex = 63; | |||
constexpr uint16_t kFp64SignIndex = 63U; | |||
/// @ingroup fp64 basic parameter | |||
/// @brief sign mask of fp64 (1 000 (total 63bits 0)) | |||
constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; | |||
@@ -269,14 +269,14 @@ constexpr int16_t kInt16Max = 0x7FFF; | |||
constexpr uint16_t kBitLen16Max = 0xFFFF; | |||
/// @ingroup integer special value judgment | |||
/// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) | |||
constexpr int32_t kInt32Max = 0x7FFFFFFFu; | |||
constexpr int32_t kInt32Max = 0x7FFFFFFF; | |||
/// @ingroup integer special value judgment | |||
/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) | |||
constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; | |||
/// @ingroup integer special value judgment | |||
/// @brief maximum positive value of int64_t | |||
/// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | |||
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu; | |||
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFF; | |||
/// @ingroup integer special value judgment | |||
/// @brief maximum value of a data with 64 bits length | |||
/// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | |||
@@ -62,7 +62,7 @@ public: | |||
/// @return others optimized failed | |||
/// @author | |||
/// | |||
static Status Run(const ge::ComputeGraphPtr &graph, std::vector<std::pair<std::string, GraphPass *>> &passes); | |||
static Status Run(const ge::ComputeGraphPtr &graph, std::vector<std::pair<std::string, GraphPass *>> &names_to_passes); | |||
~PassManager(); | |||
@@ -99,9 +99,9 @@ Status PreChecker::CheckName(OpId id) { | |||
if (id != v.first && info.name == v.second.name) { | |||
Cause cause; | |||
cause.code = ErrorCode::NAME_REPEATED; | |||
cause.message = "The name is repeated."; | |||
cause.message = "The name is repeated in the graph."; | |||
GELOGI("Name %s repeated.", info.name.c_str()); | |||
GELOGE(FAILED, "opname %s repeated, same name op in the graph", info.name.c_str()); | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name}); | |||
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); | |||
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "[Add][Cause] failed."); | |||
@@ -200,7 +200,7 @@ FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() { | |||
return false; | |||
} | |||
Status PreChecker::Save(string file) { | |||
Status PreChecker::Save(const string &file) { | |||
uint32_t fail_num = 0; | |||
for (auto id : ops_) { | |||
if (HasError(id)) { | |||
@@ -250,7 +250,7 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string | |||
Cause cause; | |||
cause.code = ErrorCode::TYPE_UNSUPPORTED; | |||
cause.message = "The type is not supported."; | |||
GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||
GELOGI("Check op[%s]'s type[%s], The type is not supported.", name.c_str(), type.c_str()); | |||
if (!is_tensorflow) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | |||
} | |||
@@ -265,9 +265,11 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string | |||
cause.code = ErrorCode::TYPE_UNSUPPORTED; | |||
cause.message = "The type is not supported."; | |||
GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||
if (!is_tensorflow) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | |||
GELOGE(FAILED, "Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||
} else { | |||
GELOGI("Check op[%s]'s type[%s] is not supported.", name.c_str(), type.c_str()); | |||
} | |||
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "[Add][Cause] failed."); | |||
} | |||
@@ -142,7 +142,7 @@ class PreChecker { | |||
* @ingroup domi_omg | |||
* @brief Save inspection results(JSON) | |||
*/ | |||
Status Save(string file); | |||
Status Save(const string &file); | |||
private: | |||
/** | |||
@@ -425,7 +425,8 @@ Status ProtoFileParser::FindConflictLine(const char *proto_file, int identifier, | |||
void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, | |||
std::map<std::string, std::pair<int, string>> &caffe_op_identifier_map, | |||
std::map<std::string, std::pair<int, string>> &custom_op_identifier_map) { | |||
for (auto iter = custom_op_identifier_map.begin(); iter != custom_op_identifier_map.end(); ++iter) { | |||
std::map<std::string, std::pair<int, string>>::const_iterator iter = custom_op_identifier_map.begin(); | |||
for (; iter != custom_op_identifier_map.end(); ++iter) { | |||
if (caffe_op_identifier_map.count(iter->first) > 0) { | |||
string message_name = iter->first; | |||
auto caffe_pair = caffe_op_identifier_map[iter->first]; | |||
@@ -452,7 +453,8 @@ void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char * | |||
void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, | |||
std::map<int, std::pair<string, string>> caffe_identifier_op_map, | |||
std::map<int, std::pair<string, string>> custom_identifier_op_map) { | |||
for (auto iter = custom_identifier_op_map.begin(); iter != custom_identifier_op_map.end(); ++iter) { | |||
std::map<int, std::pair<string, string>>::const_iterator iter = custom_identifier_op_map.begin(); | |||
for (; iter != custom_identifier_op_map.end(); ++iter) { | |||
if (caffe_identifier_op_map.count(iter->first) > 0) { | |||
int identifier = iter->first; | |||
auto caffe_pair = caffe_identifier_op_map[iter->first]; | |||
@@ -97,7 +97,8 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||
return false; | |||
} | |||
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | |||
domi::TENSORFLOW, GetOmOptype(reg_data), [=]() -> std::shared_ptr<OpParser> { return tf_parser_adapter; }); | |||
domi::TENSORFLOW, GetOmOptype(reg_data), [tf_parser_adapter]() -> std::shared_ptr<OpParser> | |||
{ return tf_parser_adapter; }); | |||
} | |||
if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) { | |||
bool is_registed = factory->OpParserIsRegistered(GetOmOptype(reg_data), true); | |||
@@ -115,7 +116,7 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||
} | |||
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | |||
domi::TENSORFLOW, GetOmOptype(reg_data), | |||
[=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true); | |||
[tf_fusion_parser_adapter]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true); | |||
} | |||
} else { | |||
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); | |||
@@ -105,7 +105,7 @@ void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { | |||
GELOGI("Enter get custom op path schedule"); | |||
std::string fmk_type; | |||
domi::FrameworkType type = domi::TENSORFLOW; | |||
auto it = options_.find(FRAMEWORK_TYPE); | |||
std::map<string, string>::const_iterator it = options_.find(FRAMEWORK_TYPE); | |||
if (it != options_.end()) { | |||
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10)); | |||
} | |||
@@ -227,9 +227,9 @@ def convert_subgraphs(graph_def, filename): | |||
print(graph_def_library.graph_def[i]) | |||
# Write to prototxt | |||
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename))) | |||
print("graph_def_file: ", graph_def_file) | |||
try: | |||
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename))) | |||
print("graph_def_file: ", graph_def_file) | |||
with open(graph_def_file, "w") as f: | |||
print(graph_def_library, file=f) | |||
except IOError: | |||
@@ -261,18 +261,18 @@ if __name__ == '__main__': | |||
model = '' | |||
try: | |||
opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model=']) | |||
for opt_name, opt_value in opts: | |||
if opt_name in ('-m', '--model'): | |||
model = opt_value | |||
print("INFO: Input model file is", model) | |||
convert_graphs(model) | |||
elif opt_name in ('-h', '--help'): | |||
usage() | |||
break | |||
elif opt_name in ('-v', '--version'): | |||
print("version 1.0.0") | |||
break | |||
except getopt.GetoptError: | |||
print("ERROR: Input parameters is invalid, use '--help' to view the help.") | |||
for opt_name, opt_value in opts: | |||
if opt_name in ('-m', '--model'): | |||
model = opt_value | |||
print("INFO: Input model file is", model) | |||
convert_graphs(model) | |||
elif opt_name in ('-h', '--help'): | |||
usage() | |||
break | |||
elif opt_name in ('-v', '--version'): | |||
print("version 1.0.0") | |||
break | |||
if len(sys.argv) == 1: | |||
print("INFO: Please specify the input parameters, and use '--help' to view the help.") |
@@ -71,7 +71,7 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_ | |||
}; | |||
int32_t datatype_val_size = 0; | |||
auto iter = datatype_val_size_map.find(data_type); | |||
std::map<uint32_t, int32_t>::const_iterator iter = datatype_val_size_map.find(data_type); | |||
if (iter != datatype_val_size_map.end()) { | |||
datatype_val_size = iter->second; | |||
} else { | |||
@@ -91,7 +91,7 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_ | |||
if (data_type == OnnxDataType::STRING) { | |||
tensor.SetData(tensor_proto.raw_data().c_str()); | |||
} else { | |||
tensor.SetData(reinterpret_cast<const uint8_t *>(tensor_proto.raw_data().c_str()), | |||
tensor.SetData(PtrToPtr<const char_t, const uint8_t>(tensor_proto.raw_data().c_str()), | |||
tensor_proto.raw_data().size()); | |||
} | |||
GELOGD("Raw data size is : %zu", tensor_proto.raw_data().size()); | |||
@@ -65,7 +65,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||
*(addr_trans.get() + i) = static_cast<bool>( | |||
std::fabs(*((addr).get() + i)) > std::numeric_limits<T>::epsilon()); | |||
} | |||
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(bool)); | |||
(tensor).SetData(PtrToPtr<bool, uint8_t>(addr_trans.get()), (count) * sizeof(bool)); | |||
break; | |||
} | |||
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ | |||
@@ -32,7 +32,6 @@ | |||
#include "onnx_op_parser.h" | |||
#include "onnx_util.h" | |||
#include "parser/common/op_parser_factory.h" | |||
#include "parser/common/pre_checker.h" | |||
#include "parser/common/acl_graph_parser_util.h" | |||
#include "parser/common/model_saver.h" | |||
#include "parser/common/parser_utils.h" | |||
@@ -240,7 +239,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra | |||
if (node->GetOpDesc() == nullptr) { | |||
continue; | |||
} | |||
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName()); | |||
node->GetOpDesc()->SetName(OnnxUtil::GenUniqueNodeName(sub_graph->GetName(), node->GetName())); | |||
} | |||
auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); | |||
@@ -384,7 +383,7 @@ Status OnnxModelParser::ConstructOriType(const ge::onnx::NodeProto *node_proto, | |||
std::string domain = node_proto->domain(); | |||
int64_t version = 0; | |||
if (!domain.empty()) { | |||
auto it = domain_verseion_.find(domain); | |||
std::map<std::string, int64_t>::const_iterator it = domain_verseion_.find(domain); | |||
if (it != domain_verseion_.end()) { | |||
version = it->second; | |||
} else { | |||
@@ -493,14 +492,14 @@ Status OnnxModelParser::SetOperatorInputs() { | |||
std::vector<std::pair<std::string, int>> &output_node_indexs = out_iter->second; | |||
for (auto input_node_index : input_node_indexs) { | |||
for (auto out_node_index : output_node_indexs) { | |||
auto input_op_iter = name_operator_.find(input_node_index.first); | |||
std::map<std::string, ge::Operator>::const_iterator input_op_iter = name_operator_.find(input_node_index.first); | |||
if (input_op_iter == name_operator_.end()) { | |||
REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", input_node_index.first.c_str()); | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", | |||
input_node_index.first.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
auto output_op_iter = name_operator_.find(out_node_index.first); | |||
std::map<std::string, ge::Operator>::const_iterator output_op_iter = name_operator_.find(out_node_index.first); | |||
if (output_op_iter == name_operator_.end()) { | |||
REPORT_INNER_ERROR("E19999", "Node: %s can not find in name_operator map.", out_node_index.first.c_str()); | |||
GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s can not find in name_operator map.", | |||
@@ -594,6 +593,7 @@ Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge:: | |||
} | |||
Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | |||
bool has_error = false; | |||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | |||
std::string node_name = node_proto->name(); | |||
@@ -605,7 +605,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||
if (status != SUCCESS) { | |||
GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str()); | |||
REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str()); | |||
return status; | |||
has_error = true; | |||
continue; | |||
} | |||
node_proto->set_op_type(ori_type); | |||
@@ -616,7 +617,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||
if (status != SUCCESS) { | |||
GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | |||
REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | |||
return status; | |||
has_error = true; | |||
continue; | |||
} | |||
// 7. op parser | |||
@@ -627,7 +629,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||
status = ParseOpParam(node_proto, op, op_parser); | |||
if (status != SUCCESS) { | |||
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | |||
return status; | |||
has_error = true; | |||
continue; | |||
} | |||
GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", | |||
@@ -638,7 +641,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||
if (graph_status != ge::GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); | |||
REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", ParserUtils::GetOperatorName(op).c_str()); | |||
return FAILED; | |||
has_error = true; | |||
continue; | |||
} | |||
name_operator_[ParserUtils::GetOperatorName(op)] = op; | |||
@@ -647,11 +651,12 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||
if (status != SUCCESS) { | |||
REPORT_INNER_ERROR("E19999", "ConstructInputOutputContext failed."); | |||
GELOGE(status, "[Construct][RelationMap] to input and output failed."); | |||
return status; | |||
has_error = true; | |||
continue; | |||
} | |||
} | |||
GELOGI("Parse all node proto success."); | |||
return SUCCESS; | |||
GELOGI("Parse all node proto end."); | |||
return has_error ? FAILED : SUCCESS; | |||
} | |||
Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | |||
@@ -665,7 +670,7 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||
} | |||
} | |||
for (auto in_name : input_node_names_) { | |||
auto in_op = name_operator_.find(in_name); | |||
std::map<std::string, ge::Operator>::const_iterator in_op = name_operator_.find(in_name); | |||
if (in_op == name_operator_.end()) { | |||
GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.", | |||
in_name.c_str()); | |||
@@ -682,7 +687,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||
Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops, | |||
ParserUtils::OutputMapping &out_tensor_to_nodes) { | |||
for (auto output_name : output_node_names_) { | |||
auto itr = outputs_map_.find(output_name); | |||
std::map<std::string, std::vector<std::pair<std::string, int>>>::const_iterator itr = | |||
outputs_map_.find(output_name); | |||
if (itr == outputs_map_.end()) { | |||
GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | |||
REPORT_INNER_ERROR("E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | |||
@@ -692,7 +698,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec | |||
std::vector<std::pair<std::string, int>> node_names_indexes = itr->second; | |||
for (const auto &node_name_index : node_names_indexes) { | |||
auto node_name = node_name_index.first; | |||
auto out_op_itr = name_operator_.find(node_name); | |||
std::map<std::string, ge::Operator>::const_iterator out_op_itr = name_operator_.find(node_name); | |||
if (out_op_itr == name_operator_.end()) { | |||
GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str()); | |||
REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str()); | |||
@@ -749,6 +755,14 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph( | |||
while (!onnx_graph_tasks.empty()) { | |||
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); | |||
onnx_graph_tasks.pop(); | |||
std::string graph_name; | |||
for (const auto &graph_iter : name_to_onnx_graph) { | |||
if (graph_iter.second == onnx_graph) { | |||
graph_name = graph_iter.first; | |||
break; | |||
} | |||
} | |||
for (int i = 0; i < onnx_graph->node_size(); i++) { | |||
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); | |||
if (node_proto->name().empty()) { | |||
@@ -766,7 +780,8 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph( | |||
} | |||
std::vector<ge::onnx::GraphProto *> onnx_graphs; | |||
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph; | |||
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) { | |||
if (subgraph_adapter->AdaptAndFindAllSubgraphs( | |||
node_proto, onnx_graphs, name_to_onnx_subgraph, graph_name) != SUCCESS) { | |||
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); | |||
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); | |||
return FAILED; | |||
@@ -815,7 +830,7 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
bool is_subgraph = (arg.parent_node != nullptr) ? true : false; | |||
if (arg.onnx_graph == nullptr) { | |||
auto itr = name_to_onnx_graph.find(arg.graph_name); | |||
std::map<std::string, ge::onnx::GraphProto *>::const_iterator itr = name_to_onnx_graph.find(arg.graph_name); | |||
if (itr == name_to_onnx_graph.end()) { | |||
GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | |||
REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | |||
@@ -38,6 +38,7 @@ | |||
#include "omg/parser/op_parser.h" | |||
#include "omg/parser/weights_parser.h" | |||
#include "common/parser_utils.h" | |||
#include "common/pre_checker.h" | |||
#include "proto/onnx/ge_onnx.pb.h" | |||
namespace ge { | |||
@@ -81,6 +82,18 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
return domi::SUCCESS; | |||
} | |||
bool HasError() override { | |||
return PreChecker::Instance().HasError(); | |||
} | |||
Status Save(const string &file) override { | |||
return PreChecker::Instance().Save(file); | |||
} | |||
void Clear() override { | |||
PreChecker::Instance().Clear(); | |||
} | |||
private: | |||
Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | |||
@@ -96,7 +109,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | |||
Status AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, std::string &om_type); | |||
Status AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, std::string &op_type); | |||
Status TransNodeToOperator(const ge::onnx::NodeProto *node_proto, ge::Operator &op, const string &op_type) const; | |||
@@ -106,7 +119,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops); | |||
Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs, | |||
Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops, | |||
ParserUtils::OutputMapping &out_tensor_to_nodes); | |||
Status Prechecker(ge::onnx::GraphProto &onnx_graph); | |||
@@ -115,7 +128,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; | |||
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); | |||
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph); | |||
Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | |||
@@ -161,6 +174,18 @@ class PARSER_FUNC_VISIBILITY OnnxWeightsParser : public domi::WeightsParser { | |||
(void)graph; | |||
return domi::SUCCESS; | |||
} | |||
bool HasError() override { | |||
return PreChecker::Instance().HasError(); | |||
} | |||
Status Save(const string &file) override { | |||
return PreChecker::Instance().Save(file); | |||
} | |||
void Clear() override { | |||
PreChecker::Instance().Clear(); | |||
} | |||
}; | |||
} // namespace domi | |||
#endif // PARSER_ONNX_ONNX_PARSER_H_ |
@@ -45,4 +45,8 @@ void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &orig | |||
const std::string &parent_node_name, std::string &unique_subgraph_name) { | |||
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; | |||
} | |||
std::string OnnxUtil::GenUniqueNodeName(const std::string &graph_name, const std::string &node_name) { | |||
return graph_name + "/" + node_name; | |||
} | |||
} // namespace ge |
@@ -54,6 +54,7 @@ class OnnxUtil { | |||
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | |||
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | |||
const std::string &parent_node_name, std::string &unique_subgraph_name); | |||
static std::string GenUniqueNodeName(const std::string &graph_name, const std::string &node_name); | |||
}; | |||
} // namespace ge | |||
@@ -15,6 +15,7 @@ | |||
*/ | |||
#include "if_subgraph_adapter.h" | |||
#include <unordered_set> | |||
#include "subgraph_adapter_factory.h" | |||
#include "common/util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
@@ -27,12 +28,12 @@ const int kIfNodeAttrSize = 2; | |||
} // namespace | |||
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) { | |||
GE_CHECK_NOTNULL(parent_node); | |||
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | |||
parent_node->op_type().c_str()); | |||
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph); | |||
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "[Parse][Node] Parse if node failed."); | |||
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | |||
@@ -44,7 +45,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) { | |||
if (parent_node->attribute_size() != kIfNodeAttrSize) { | |||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||
@@ -67,7 +68,11 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
return FAILED; | |||
} | |||
std::string unique_subgraph_name; | |||
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name); | |||
std::string node_name = parent_node->name(); | |||
if (!parent_graph_name.empty()) { | |||
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); | |||
} | |||
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, node_name, unique_subgraph_name); | |||
GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); | |||
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); | |||
name_to_onnx_graph[unique_subgraph_name] = onnx_graph; | |||
@@ -91,8 +96,8 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | |||
std::set<std::string> &all_inputs) const { | |||
std::set<std::string> graph_inputs; | |||
std::set<std::string> graph_outputs; | |||
std::unordered_set<std::string> graph_inputs; | |||
std::unordered_set<std::string> graph_outputs; | |||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | |||
for (int j = 0; j < node_proto->input_size(); j++) { | |||
@@ -102,10 +107,12 @@ domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx | |||
graph_outputs.emplace(node_proto->output(j)); | |||
} | |||
} | |||
std::unordered_set<std::string> graph_initializer_tensors; | |||
for (int32_t i = 0; i < onnx_graph.initializer_size(); i++) { | |||
graph_initializer_tensors.emplace(onnx_graph.initializer(i).name()); | |||
} | |||
for (const auto &input : graph_inputs) { | |||
auto out_iter = graph_outputs.find(input); | |||
if (out_iter == graph_outputs.end()) { | |||
if (graph_outputs.count(input) == 0 && graph_initializer_tensors.count(input) == 0) { | |||
// Record input node need to be constructed | |||
all_inputs.emplace(input); | |||
} | |||
@@ -24,13 +24,15 @@ | |||
namespace ge { | |||
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||
public: | |||
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | |||
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node, | |||
std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override; | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||
const std::string &parent_graph_name = "") override; | |||
private: | |||
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||
const std::string &parent_graph_name); | |||
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | |||
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const; | |||
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const; | |||
@@ -49,10 +49,12 @@ class PARSER_FUNC_VISIBILITY SubgraphAdapter { | |||
/// @return FAILED Parse failed | |||
virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | |||
std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||
const std::string &parent_graph_name = "") { | |||
(void)parent_op; | |||
(void)onnx_graphs; | |||
(void)name_to_onnx_graph; | |||
(void)parent_graph_name; | |||
return domi::SUCCESS; | |||
} | |||
}; | |||
@@ -26,7 +26,7 @@ SubgraphAdapterFactory* SubgraphAdapterFactory::Instance() { | |||
std::shared_ptr<SubgraphAdapter> SubgraphAdapterFactory::CreateSubgraphAdapter( | |||
const std::string &op_type) { | |||
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter. | |||
auto iter = subgraph_adapter_creator_map_.find(op_type); | |||
std::map<std::string, CREATOR_FUN>::const_iterator iter = subgraph_adapter_creator_map_.find(op_type); | |||
if (iter != subgraph_adapter_creator_map_.end()) { | |||
return iter->second(); | |||
} | |||
@@ -161,7 +161,6 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||
if (node_def->input(i).find("^") != string::npos) { | |||
// Control input | |||
const string normalized = node_names.Renormalize(node_def->input(i).substr(1)); | |||
GE_IF_BOOL_EXEC(normalized.empty(), | |||
REPORT_INNER_ERROR("E19999", "Could not remap control input %s of node %s in function %s", | |||
node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); | |||
@@ -172,7 +171,6 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||
*node_def->mutable_input(i) = "^" + normalized; | |||
} else { | |||
const auto iter = tensor_renaming.find(node_def->input(i)); | |||
GE_IF_BOOL_EXEC(iter == tensor_renaming.end(), | |||
REPORT_INNER_ERROR("E19999", "Could not remap input %s of node %s in function %s", | |||
node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); | |||
@@ -188,14 +186,12 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||
// Remap return values. | |||
for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { | |||
const string &ret_name = fdef->signature().output_arg(r).name(); | |||
GE_IF_BOOL_EXEC(ret_name.empty(), | |||
REPORT_INNER_ERROR("E19999", "Missing output %d to function %s", r, name.c_str()); | |||
GELOGE(domi::INTERNAL_ERROR, "Missing output %d to function %s .", r, name.c_str()); | |||
return domi::INTERNAL_ERROR); | |||
const string &return_value = return_values[ret_name]; | |||
GE_IF_BOOL_EXEC(return_value.empty(), | |||
REPORT_INNER_ERROR("E19999", "Could not remap return value %d ,%s of %s in function %s", r, | |||
ret_name.c_str(), return_value.c_str(), name.c_str()); | |||
@@ -204,12 +200,11 @@ domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelp | |||
return domi::INTERNAL_ERROR); | |||
const auto iter = tensor_renaming.find(return_value); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == tensor_renaming.end(), | |||
REPORT_INNER_ERROR("E19999", "can not find value[%s] in tensor_renaming map", | |||
return_value.c_str()); | |||
return domi::INTERNAL_ERROR, | |||
"can not find value[%s] in tensor_renaming map.", return_value.c_str()); | |||
if (iter == tensor_renaming.end()) { | |||
REPORT_INNER_ERROR("E19999", "can not find value[%s] in tensor_renaming map", return_value.c_str()); | |||
GELOGE(FAILED, "can not find value[%s] in tensor_renaming map.", return_value.c_str()); | |||
return domi::INTERNAL_ERROR; | |||
} | |||
(*fdef->mutable_ret())[ret_name] = iter->second; | |||
} | |||
@@ -227,7 +222,7 @@ domi::Status GraphToFunctionDef::RecordResult(ge::ComputeGraphPtr graph, | |||
GE_CHECK_NOTNULL(anchor); | |||
GE_CHECK_NOTNULL(anchor->GetOwnerNode()->GetOpDesc()); | |||
int32_t type = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()).GetDataType(); | |||
auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||
GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | |||
REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", | |||
type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), | |||
@@ -304,7 +299,7 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect | |||
GE_CHECK_NOTNULL_EXEC(tensor_desc_ptr, return domi::FAILED); | |||
int32_t type = tensor_desc_ptr->GetDataType(); | |||
auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); | |||
GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | |||
REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", | |||
type, anchor->GetIdx(), anchor->GetOwnerNode()->GetName().c_str(), | |||
@@ -325,8 +320,8 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect | |||
return FAILED; | |||
} | |||
(void)ge::AttrUtils::SetInt(op, "T", (int32_t)dtype); | |||
(void)ge::AttrUtils::SetInt(op, "arg_index", (int32_t)index); | |||
(void)ge::AttrUtils::SetInt(op, "T", static_cast<int32_t>(dtype)); | |||
(void)ge::AttrUtils::SetInt(op, "arg_index", static_cast<int32_t>(index)); | |||
ge::NodePtr arg_node = graph->AddNode(op); | |||
GE_CHECK_NOTNULL(arg_node); | |||
bool node_exists = false; | |||
@@ -378,7 +373,6 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph | |||
GE_CHECK_NOTNULL(node); | |||
if (node->GetOpDesc()->GetType() == ge::parser::DATA) { | |||
int64_t index = 0; | |||
int64_t type = 1; | |||
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, | |||
"Get type attr failed"); | |||
@@ -400,7 +394,6 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph | |||
if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) { | |||
int64_t index = 0; | |||
int64_t type = 1; | |||
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, | |||
"Get type attr failed"); | |||
@@ -589,7 +582,10 @@ bool GraphToFunctionDef::FindAttrValue(const domi::tensorflow::NodeDef *node_def | |||
void GraphToFunctionDef::AddNodeAttr(const string &attr_name, const domi::tensorflow::AttrValue &value, | |||
domi::tensorflow::NodeDef *node_def) { | |||
GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | |||
if (node_def == nullptr) { | |||
GELOGI("input parameter is null."); | |||
return; | |||
} | |||
node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | |||
} | |||
} // namespace ge |
@@ -52,7 +52,7 @@ class GraphToFunctionDef { | |||
const string &name, FunctionDef *fdef); | |||
static domi::Status BuildFunctionDef(ge::ComputeGraphPtr &graph, | |||
const string &nme_in, | |||
const string &name_in, | |||
FunctionDefLibrary *library, | |||
NodeDef *call_node_def, | |||
vector<ge::InDataAnchorPtr> &in_anchor, | |||
@@ -15,7 +15,7 @@ | |||
*/ | |||
#include "graph_optimizer.h" | |||
#include "common/op_types.h" | |||
#include "graph/op_types.h" | |||
#include "common/types_map.h" | |||
#include "common/util.h" | |||
#include "framework/omg/parser/parser_inner_ctx.h" | |||
@@ -93,6 +93,7 @@ Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const boo | |||
GE_CHECK_NOTNULL(graph_); | |||
for (auto node : graph_->GetDirectNode()) { | |||
GE_CHECK_NOTNULL(node); | |||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) | |||
string type; | |||
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | |||
@@ -178,7 +179,7 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library) | |||
GE_IF_BOOL_EXEC( | |||
AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; | |||
GE_CHECK_NOTNULL(funcDefBytes.GetData()); | |||
string str(reinterpret_cast<char *>(funcDefBytes.GetData()), funcDefBytes.GetSize()); | |||
string str(PtrToPtr<uint8_t, char_t>(funcDefBytes.GetData()), funcDefBytes.GetSize()); | |||
GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC( | |||
funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib))); | |||
} | |||
@@ -206,9 +207,11 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||
std::unique_ptr<FunctionDefLibrary> func_def_lib(new (std::nothrow) FunctionDefLibrary()); | |||
GE_CHECK_NOTNULL(func_def_lib); | |||
// convert graph to FunctionDef | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0, | |||
REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0"); | |||
return PARAM_INVALID, "node size must greater than 0 ."); | |||
if (nodes.size() == 0) { | |||
REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0"); | |||
GELOGE(FAILED, "node size must greater than 0 ."); | |||
return PARAM_INVALID; | |||
} | |||
GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed."); | |||
GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, nodes[0]->GetName(), func_def_lib.get(), | |||
node_def.get(), input_anchors, output_anchors), | |||
@@ -226,7 +229,10 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||
GELOGE(PARAM_INVALID, "Serialize func_def to string failed."); | |||
return PARAM_INVALID); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0, return PARAM_INVALID, "nodes is empty."); | |||
if (nodes.size() == 0) { | |||
GELOGE(FAILED, "nodes is empty."); | |||
return PARAM_INVALID; | |||
} | |||
std::string fusion_op_name; | |||
for (auto node : nodes) { | |||
@@ -250,10 +256,10 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||
(void)AttrUtils::SetZeroCopyBytes( | |||
fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, | |||
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(funcdefStr.data()), funcdefStr.length())); | |||
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(funcdefStr.data()), funcdefStr.length())); | |||
(void)AttrUtils::SetZeroCopyBytes( | |||
fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | |||
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | |||
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(nodefStr.data()), nodefStr.length())); | |||
(void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type); | |||
@@ -284,6 +290,7 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge | |||
GE_CHECK_NOTNULL(node); | |||
OpDescPtr op_def = node->GetOpDesc(); | |||
NodePtr new_node = sub_graph->AddNode(op_def); | |||
GE_CHECK_NOTNULL(new_node); | |||
node_map[node->GetName()] = new_node; | |||
// Input | |||
@@ -381,7 +388,8 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> & | |||
GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED); | |||
ge::DataType data_type = src_out_desc.GetDataType(); | |||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type); | |||
const std::map<int32_t, int32_t>::const_iterator iter = | |||
GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type)); | |||
GE_IF_BOOL_EXEC( | |||
iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | |||
REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", | |||
@@ -390,7 +398,7 @@ Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> & | |||
return PARAM_INVALID); | |||
int32_t dtype = iter->second; | |||
output_list.push_back((int64_t)dtype); | |||
output_list.push_back(static_cast<int64_t>(dtype)); | |||
GELOGI("FUNCDEF: output_list push_back %d.", dtype); | |||
} | |||
GE_IF_BOOL_EXEC(!output_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list)); | |||
@@ -410,14 +418,15 @@ Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &in | |||
auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((fusion_op_desc->AddInputDesc(*tensorDescPtr)) != GRAPH_SUCCESS, | |||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||
fusion_op_desc->GetName().c_str(), | |||
fusion_op_desc->GetType().c_str()); | |||
return FAILED, | |||
"Add fusion_op_desc AddInputDesc failed"); | |||
if (fusion_op_desc->AddInputDesc(*tensorDescPtr) != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||
fusion_op_desc->GetName().c_str(), fusion_op_desc->GetType().c_str()); | |||
GELOGE(FAILED, "Add fusion_op_desc AddInputDesc failed"); | |||
return FAILED; | |||
} | |||
ge::DataType data_type = tensorDescPtr->GetDataType(); | |||
std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type); | |||
const std::map<int32_t, int32_t>::const_iterator iter = | |||
GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type)); | |||
GE_IF_BOOL_EXEC( | |||
iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), | |||
REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", | |||
@@ -426,7 +435,7 @@ Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &in | |||
return PARAM_INVALID); | |||
int32_t dtype = iter->second; | |||
input_list.push_back((int64_t)dtype); | |||
input_list.push_back(static_cast<int64_t>(dtype)); | |||
GELOGI("FUNCDEF: input_list push_back %d.", dtype); | |||
} | |||
GE_IF_BOOL_EXEC(!input_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list)); | |||
@@ -440,6 +449,7 @@ Status ParserGraphOptimizer::RebuildFusionNode(vector<ge::InDataAnchorPtr> &inpu | |||
vector<ge::InControlAnchorPtr> &input_control_anchors, | |||
vector<ge::OutControlAnchorPtr> &output_control_anchors, | |||
ge::NodePtr fusion_node) { | |||
GE_CHECK_NOTNULL(fusion_node); | |||
int32_t src_index = 0; | |||
for (auto out_anchor : output_anchors) { | |||
@@ -32,7 +32,7 @@ const char *const kSerializeFormat = "serialize_format"; | |||
Status ParseParams(const Message *op_src, ArgOpOperator *const op) { | |||
GE_CHECK_NOTNULL(op_src); | |||
GE_CHECK_NOTNULL(op); | |||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
domi::tensorflow::AttrValue output_attr_value; | |||
if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { | |||
@@ -44,7 +44,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||
GELOGE(PARAM_INVALID, "Op src is null"); | |||
return PARAM_INVALID; | |||
} | |||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
if (op_dest == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); | |||
@@ -109,8 +109,11 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||
return FAILED; | |||
} | |||
} | |||
const auto out_desc = op_dest->MutableOutputDesc(0); | |||
GE_CHECK_NOTNULL(out_desc); | |||
out_desc->SetDataType(out_type); | |||
std::shared_ptr<NodeDef> pkg_node = ge::parser::MakeShared<NodeDef>(); | |||
std::shared_ptr<domi::tensorflow::NodeDef> pkg_node = ge::parser::MakeShared<domi::tensorflow::NodeDef>(); | |||
GE_CHECK_NOTNULL(pkg_node); | |||
pkg_node->CopyFrom(*node); | |||
@@ -130,7 +133,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||
(void)AttrUtils::SetZeroCopyBytes( | |||
op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | |||
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(serialized_node.data()), serialized_node.length())); | |||
Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(serialized_node.data()), serialized_node.length())); | |||
GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str()); | |||
} | |||
@@ -27,7 +27,7 @@ using domi::ParseParamByOpFunc; | |||
namespace ge { | |||
Status TensorFlowCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | |||
GE_CHECK_NOTNULL(op_src); | |||
const NodeDef *node_src = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||
GE_CHECK_NOTNULL(node_src); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | |||
GE_CHECK_NOTNULL(op_dest); | |||
@@ -93,7 +93,7 @@ Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, const ge | |||
GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE), | |||
"check Attr %s failed", TENSORFLOW_ATTR_SHAPE.c_str()); | |||
const TensorShapeProto &data_shape = attr_value.shape(); | |||
const domi::tensorflow::TensorShapeProto &data_shape = attr_value.shape(); | |||
for (auto i = 0; i < data_shape.dim_size(); i++) { | |||
model_input_dims_v.push_back(data_shape.dim(i).size()); | |||
} | |||
@@ -110,7 +110,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge: | |||
std::string name = op_def->GetName(); | |||
if (input_dims.count(name) == 0) { | |||
GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str()); | |||
for (uint32_t i = 0; i < model_input_dims_v.size(); i++) { | |||
for (size_t i = 0; i < model_input_dims_v.size(); ++i) { | |||
user_input_dims_v.push_back(model_input_dims_v[i]); | |||
} | |||
@@ -138,7 +138,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge: | |||
} | |||
Status TensorFlowDataParser::CheckInputShape(const std::string &name) { | |||
for (uint32_t i = 0; i < user_input_dims_v.size(); i++) { | |||
for (size_t i = 0; i < user_input_dims_v.size(); ++i) { | |||
// if input_shape has some placeholders, user should designate them. | |||
// dim i = 0, means empty tensor. | |||
// dim i = -1 or -2, means unknown shape. | |||
@@ -31,7 +31,7 @@ Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||
GE_CHECK_NOTNULL(op_desc); | |||
const std::string name = op_desc->GetName(); | |||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src); | |||
domi::tensorflow::AttrValue str_attr; | |||
if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { | |||
REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", | |||
@@ -38,7 +38,7 @@ node { | |||
} | |||
} | |||
*/ | |||
domi::Status ParseParams(const NodeDef *node, FillOperator *op) { | |||
domi::Status ParseParams(const domi::tensorflow::NodeDef *node, FillOperator *op) { | |||
GE_CHECK_NOTNULL(node); | |||
GE_CHECK_NOTNULL(op); | |||
op->Name(node->name()); | |||
@@ -31,7 +31,7 @@ namespace ge { | |||
Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | |||
GE_CHECK_NOTNULL(op_src); | |||
GE_CHECK_NOTNULL(op); | |||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
string type = node->op(); | |||
@@ -64,7 +64,7 @@ Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | |||
GE_IF_BOOL_EXEC(((type == "_Retval") && (TensorFlowUtil::FindAttrValue(node, ATTR_NAME_INDEX, index_attr_value))), | |||
op->Index(index_attr_value.i())); | |||
NodeDef *pkg_node = new (std::nothrow) NodeDef(); | |||
domi::tensorflow::NodeDef *pkg_node = new (std::nothrow) domi::tensorflow::NodeDef(); | |||
GE_CHECK_NOTNULL(pkg_node); | |||
pkg_node->CopyFrom(*node); | |||
@@ -44,7 +44,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowFusionOpParser : public TensorFlowOpParse | |||
* @return SUCCESS Parsing success | |||
* @return FAILED Parsing failed | |||
*/ | |||
virtual Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &node) const; | |||
virtual Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &op_dest) const; | |||
/** | |||
* @ingroup domi_omg | |||
@@ -31,7 +31,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||
GE_CHECK_NOTNULL(op_src); | |||
GE_CHECK_NOTNULL(op_desc); | |||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||
const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src); | |||
domi::tensorflow::AttrValue attr_num; | |||
if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { | |||
GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); | |||
@@ -26,7 +26,7 @@ using namespace ge::parser; | |||
namespace ge { | |||
Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | |||
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||
GE_CHECK_NOTNULL(node); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
NoOpOperator op; | |||
@@ -42,25 +42,6 @@ | |||
#include "proto/tensorflow/graph.pb.h" | |||
#include "proto/tensorflow/node_def.pb.h" | |||
using domi::tensorflow::NodeDef; | |||
using domi::tensorflow::TensorProto; | |||
using google::protobuf::int32; | |||
using google::protobuf::int64; | |||
using google::protobuf::Message; | |||
using std::string; | |||
using std::vector; | |||
using Status = domi::Status; | |||
using domi::tensorflow::AttrValue; | |||
using domi::tensorflow::DataType; | |||
using domi::tensorflow::DT_BOOL; | |||
using domi::tensorflow::DT_FLOAT; | |||
using domi::tensorflow::DT_INT32; | |||
using domi::tensorflow::DT_INT64; | |||
using domi::tensorflow::DT_INVALID; | |||
using domi::tensorflow::TensorShapeProto; | |||
using domi::tensorflow::TensorShapeProto_Dim; | |||
namespace ge { | |||
/** | |||
* @ingroup domi_omg | |||
@@ -40,7 +40,6 @@ | |||
#include "parser/common/op_parser_factory.h" | |||
#include "parser/common/parser_fp16_t.h" | |||
#include "parser/common/pass_manager.h" | |||
#include "parser/common/pre_checker.h" | |||
#include "parser/common/prototype_pass_manager.h" | |||
#include "parser/common/thread_pool.h" | |||
#include "parser/common/parser_utils.h" | |||
@@ -54,6 +53,7 @@ | |||
#include "register/register_utils.h" | |||
#include "register/scope/scope_pass_registry_impl.h" | |||
#include "parser/common/auto_mapping_subgraph_io_index_func.h" | |||
#include "graph/def_types.h" | |||
using ge::OpParserFactory; | |||
using ge::Pb2Json; | |||
@@ -507,8 +507,11 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, | |||
ge::NodePtr node = nullptr; | |||
node = graph->AddNode(op_desc); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed."); | |||
if (node == nullptr) { | |||
DeleteFuisonNodeDef(); | |||
GELOGE(FAILED, "add node failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
node_map_[node_name] = node; | |||
@@ -545,7 +548,11 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, | |||
// checkout op input number with IR | |||
GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def)); | |||
ge::NodePtr node = graph->AddNode(op); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed."); | |||
if (node == nullptr) { | |||
DeleteFuisonNodeDef(); | |||
GELOGE(FAILED, "add node failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
node_map_[node_name] = node; | |||
@@ -794,22 +801,24 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||
GE_CHECK_NOTNULL(out_archor_ptr); | |||
ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second); | |||
GE_CHECK_NOTNULL(in_archor_ptr); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||
src->GetName().c_str(), dest->GetName().c_str()); | |||
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", | |||
src->GetName().c_str(), dest->GetName().c_str()); | |||
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { | |||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||
src->GetName().c_str(), dest->GetName().c_str()); | |||
GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
} else { | |||
GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); | |||
ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); | |||
GE_CHECK_NOTNULL(in_archor_ptr); | |||
ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | |||
GE_CHECK_NOTNULL(out_archor_ptr); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||
src->GetName().c_str(), dest->GetName().c_str()); | |||
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", | |||
src->GetName().c_str(), dest->GetName().c_str()); | |||
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { | |||
REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||
src->GetName().c_str(), dest->GetName().c_str()); | |||
GELOGE(FAILED, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), dest->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
} | |||
dest_input_map.erase(input_iter); | |||
@@ -921,10 +930,11 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co | |||
} | |||
std::map<std::string, std::string>::const_iterator iterator = parser->adaptedOpTypeMap_.find(node_name); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
iterator == parser->adaptedOpTypeMap_.end(), | |||
REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str()); | |||
return FAILED, "get adapted op type failed, node name = %s", node_name.c_str()); | |||
if (iterator == parser->adaptedOpTypeMap_.cend()) { | |||
REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str()); | |||
GELOGE(FAILED, "get adapted op type failed, node name = %s", node_name.c_str()); | |||
return FAILED; | |||
} | |||
string op_type = iterator->second; | |||
// Log printing for determining operator type | |||
@@ -1017,10 +1027,12 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co | |||
node = graph->AddNode(op); | |||
} | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(node == nullptr), REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(), | |||
op->GetType().c_str(), graph->GetName().c_str()); | |||
return INTERNAL_ERROR, "add node failed."); | |||
if (node == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(), | |||
op->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(FAILED, "add node failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
if (needFusion) { | |||
shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type); | |||
@@ -1116,10 +1128,11 @@ Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &g | |||
for (size_t j = 0; j < op_node_list_size; j++) { | |||
const string op_node_name = op_node_name_list[j]; | |||
auto iterator = node_map_.find(op_node_name); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(iterator == node_map_.end()), | |||
REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str()); | |||
return INTERNAL_ERROR, "add node failed."); | |||
if (iterator == node_map_.end()) { | |||
REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str()); | |||
GELOGE(FAILED, "add node failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
GE_CHECK_NOTNULL(iterator->second); | |||
GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed"); | |||
graph->AddNode(iterator->second); | |||
@@ -1178,15 +1191,22 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||
domi::tensorflow::GraphDef OriDef; | |||
bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, REPORT_INNER_ERROR("E19999", "read graph proto from binary failed"); | |||
return INTERNAL_ERROR, "read_proto_from_binary failed."); | |||
if (!read) { | |||
REPORT_INNER_ERROR("E19999", "read graph proto from binary failed"); | |||
GELOGE(FAILED, "read_proto_from_binary failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
domi::tensorflow::GraphDef graph_def; | |||
if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { | |||
const bool is_empty_input = GetParserContext().input_dims.empty() && GetParserContext().out_nodes_map.empty(); | |||
if (is_empty_input) { | |||
graph_def = OriDef; | |||
} else { | |||
GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size()); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(OriDef, &graph_def), return INTERNAL_ERROR, "Trim Graph fail."); | |||
if (static_cast<bool>(TrimGraph(OriDef, &graph_def))) { | |||
GELOGE(FAILED, "Trim Graph fail."); | |||
return INTERNAL_ERROR; | |||
} | |||
GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); | |||
} | |||
@@ -1236,9 +1256,6 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||
// This function call affects the return value of prechecker::instance().Haserror() | |||
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); | |||
// Check the input validity of the node, the input attribute must have a corresponding node | |||
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def)); | |||
// Building input and input relationships for all OP nodes | |||
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); | |||
GELOGD("[TF ParseFromMemory] get op nodes context from graph success"); | |||
@@ -1269,9 +1286,12 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||
return INTERNAL_ERROR; | |||
} | |||
const string &node_op = node_def->op(); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((tensorflow_op_map.find(node_op) == tensorflow_op_map.end()), DeleteFuisonNodeDef(); | |||
REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str()); | |||
return INTERNAL_ERROR, "Unsupport op type %s", node_op.c_str()); | |||
if (tensorflow_op_map.find(node_op) == tensorflow_op_map.cend()) { | |||
DeleteFuisonNodeDef(); | |||
REPORT_INNER_ERROR("E19999", "Op type %s unsupport", node_op.c_str()); | |||
GELOGE(FAILED, "Unsupport op type %s", node_op.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
ret = AddNode(node_def, graph, scope_graph); | |||
if (ret != SUCCESS) { | |||
@@ -1339,7 +1359,10 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr | |||
// Store objects parsed from pb files | |||
domi::tensorflow::GraphDef ori_def; | |||
bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, return INTERNAL_ERROR, "read_proto_from_binary failed."); | |||
if (!read) { | |||
GELOGE(FAILED, "read tensorflow file failed when the inupt param value of --framework is 3."); | |||
return INTERNAL_ERROR; | |||
} | |||
// Trim graph by user input and output. | |||
domi::tensorflow::GraphDef graph_def; | |||
@@ -1347,7 +1370,10 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr | |||
graph_def = ori_def; | |||
} else { | |||
GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size()); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(ori_def, &graph_def), return INTERNAL_ERROR, "Trim Graph fail."); | |||
if (static_cast<bool>(TrimGraph(ori_def, &graph_def))) { | |||
GELOGE(FAILED, "Trim Graph fail."); | |||
return INTERNAL_ERROR; | |||
} | |||
GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size()); | |||
} | |||
@@ -1375,7 +1401,7 @@ Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr | |||
} | |||
} | |||
std::map<std::string, domi::tensorflow::GraphDef>::const_iterator | |||
const std::map<std::string, domi::tensorflow::GraphDef>::const_iterator | |||
iter = function_name_to_graphdef.find(arg.function_name); | |||
if (iter == function_name_to_graphdef.end()) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name}); | |||
@@ -1415,7 +1441,8 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||
GE_CHECK_NOTNULL(proto); | |||
GE_CHECK_NOTNULL(graph); | |||
const domi::tensorflow::GraphDef *ori_graph = reinterpret_cast<const domi::tensorflow::GraphDef *>(proto); | |||
const domi::tensorflow::GraphDef *ori_graph = | |||
ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto); | |||
// Make a copy for operation without modifying the original graph def. | |||
domi::tensorflow::GraphDef graph_def = *ori_graph; | |||
@@ -1441,12 +1468,16 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||
GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()), | |||
"Add node_def to PreChecker failed, node name: %s.", node.name().c_str()); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&node) != SUCCESS, return FAILED, | |||
"Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str()); | |||
GE_CHK_BOOL_EXEC_NOLOG( | |||
node.op() == TENSORFLOWF_NODE_OP_IDENTITY, | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&node, true) != SUCCESS, return FAILED, | |||
"Check op[%s]'s optype failed, type is not supported.", node.name().c_str());) | |||
if (PreChecker::Instance().CheckName(&node) != SUCCESS) { | |||
GELOGE(FAILED, "Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str()); | |||
return FAILED; | |||
} | |||
if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) { | |||
if (PreChecker::Instance().CheckType(&node, true) != SUCCESS) { | |||
GELOGE(FAILED, "Check op[%s]'s optype failed, type is not supported.", node.name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
} | |||
bool has_error = false; | |||
@@ -1471,10 +1502,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||
// This function call affects the return value of prechecker::instance().Haserror() | |||
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); | |||
// Check the input validity of the node, the input attribute must have a corresponding node | |||
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def)); | |||
GELOGD("[TF Parse] check graph success"); | |||
// Building input and input relationships for all OP nodes | |||
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); | |||
GELOGD("[TF Parse] get op nodes context from graph success"); | |||
@@ -1547,37 +1574,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||
return SUCCESS; | |||
} | |||
Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const { | |||
// Number of data nodes | |||
uint32_t data_node_count = 0; | |||
for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) { | |||
// Check that all input is valid | |||
for (const string &node_name : node_def.input()) { | |||
string tmp_node_name; | |||
GE_RETURN_IF_ERROR(CheckInputNodeName(node_name, &tmp_node_name, nullptr, nullptr)); | |||
if (nodedef_map_.find(tmp_node_name) == nodedef_map_.end()) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E12009", {"opname", "inputopname"}, | |||
{node_def.name(), node_name}); | |||
GELOGE(INTERNAL_ERROR, "Op[%s]'s input op[%s] is not exist in the graph_def.", node_def.name().c_str(), | |||
node_name.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::parser::ARG) { | |||
data_node_count++; | |||
} | |||
} | |||
if (data_node_count == 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E12010"); | |||
GELOGE(INTERNAL_ERROR, "Model has no Placeholder node."); | |||
return INTERNAL_ERROR; | |||
} | |||
return SUCCESS; | |||
} | |||
Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) { | |||
// Build the input relationship first | |||
for (auto &iter : op_node_context_map_) { | |||
@@ -1868,7 +1864,7 @@ Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> | |||
ge::ScopeFusionOpInfo info; | |||
if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) { | |||
// This node is a fusion operator | |||
std::map<std::string, OpNodeContext>::const_iterator | |||
const std::map<std::string, OpNodeContext>::const_iterator | |||
fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name); | |||
if (fusion_iter == tmp_fusion_op_node_context_map.end()) { | |||
OpNodeContext op_node_context; | |||
@@ -2108,10 +2104,10 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap( | |||
std::set<std::string> compare_set; | |||
for (auto &pair : pairs) { | |||
bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.end()) || | |||
(fusion_op_children_.find(iter->first) != fusion_op_children_.end()); | |||
bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.end()) || | |||
(fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.end()); | |||
bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.cend()) || | |||
(fusion_op_children_.find(iter->first) != fusion_op_children_.cend()); | |||
bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.cend()) || | |||
(fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.cend()); | |||
if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) && | |||
(is_fusion_child || is_fusion_op)) { | |||
// The edge will be cut off at the back, ignoring | |||
@@ -2119,7 +2115,7 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap( | |||
} | |||
string name = to_string(pair.first) + ":" + to_string(pair.second); | |||
std::set<std::string>::const_iterator compare_iter = compare_set.find(name); | |||
const std::set<std::string>::const_iterator compare_iter = compare_set.find(name); | |||
if (compare_iter != compare_set.end()) { | |||
// pair<from,to> repeat, ignore | |||
continue; | |||
@@ -2158,7 +2154,7 @@ void TensorFlowModelParser::SaveEdgesControlInfo(const string &node_name, const | |||
} | |||
void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) { | |||
std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name); | |||
const std::map<std::string, std::vector<int32_t>>::const_iterator iter = edges_control_map.find(info.node_name); | |||
if (iter != edges_control_map.end()) { | |||
// Delete the original fusion operator node information and add the fusion operator control edge information | |||
edges_control_map.erase(iter); | |||
@@ -2228,7 +2224,8 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||
GE_CHECK_NOTNULL(graph); | |||
ge::GetParserContext().train_flag = true; | |||
const domi::tensorflow::GraphDef *graph_def_in = reinterpret_cast<const domi::tensorflow::GraphDef *>(proto); | |||
const domi::tensorflow::GraphDef *graph_def_in = | |||
ge::PtrToPtr<google::protobuf::Message, domi::tensorflow::GraphDef>(proto); | |||
// Make a copy for operation without modifying the original graph def. | |||
domi::tensorflow::GraphDef graph_def_operation = *graph_def_in; | |||
domi::tensorflow::GraphDef *graph_def = &graph_def_operation; | |||
@@ -2415,7 +2412,7 @@ Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge | |||
GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed"); | |||
return FAILED; | |||
} | |||
return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph); | |||
return ParseProto(ge::PtrToPtr<domi::tensorflow::GraphDef, const google::protobuf::Message>(&graph_def), graph); | |||
} | |||
Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | |||
@@ -2472,95 +2469,18 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro | |||
return SUCCESS; | |||
} | |||
// For the identity operator whose output is "_retval", optimize it. | |||
Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | |||
const string &curr_node_name, bool &clear_input_flag) { | |||
auto context_iter = op_node_context_map_.find(curr_node_name); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(context_iter == op_node_context_map_.end()), | |||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||
return INTERNAL_ERROR, "Can't find op node context."); | |||
OpNodeContext op_node_context = context_iter->second; | |||
std::map<std::string, NodeDef *>::const_iterator node_def_iter = nodedef_map.find(curr_node_name); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(node_def_iter == nodedef_map.end()), | |||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map, check invalid", curr_node_name.c_str()); | |||
return INTERNAL_ERROR, "Can't find nodedef"); | |||
domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second; | |||
GE_CHECK_NOTNULL(curr_node_def); | |||
bool has_out_retval = false; | |||
// For the identity operator whose output is "_retval", optimize it | |||
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | |||
for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) { | |||
const string &output_node_name = output_iter->first; | |||
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | |||
GE_CHECK_NOTNULL(output_node_def); | |||
if (output_node_def->op() == "_Retval") { | |||
GELOGD("_Retval Identity need optimize."); | |||
output_node_def->set_input(0, curr_node_def->input(0).c_str()); | |||
has_out_retval = true; | |||
GELOGD("op %s set input(0):%s.", output_node_def->name().c_str(), curr_node_def->input(0).c_str()); | |||
} | |||
} | |||
// Deal with non _Retval output operator of Identity. | |||
if (has_out_retval) { | |||
for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) { | |||
const string &output_node_name = output_iter->first; | |||
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | |||
GE_CHECK_NOTNULL(output_node_def); | |||
GE_IF_BOOL_EXEC(output_node_def->op() == "_Retval", continue); | |||
for (int k = 0; k < output_node_def->input_size(); ++k) { | |||
GE_IF_BOOL_EXEC( | |||
output_node_def->input(k) == curr_node_name, output_node_def->set_input(k, curr_node_def->input(0).c_str()); | |||
GELOGD("%s op set input(%d):%s.", output_node_def->name().c_str(), k, curr_node_def->input(0).c_str());) | |||
} | |||
} | |||
clear_input_flag = true; | |||
} | |||
return SUCCESS; | |||
} | |||
Status TensorFlowModelParser::GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, | |||
map<string, NodeDef *> &nodedef_map, | |||
const vector<NodeDef *> &nodedef_to_optimize) { | |||
GE_CHECK_NOTNULL(graph_def); | |||
if (!nodedef_to_optimize.empty()) { | |||
// Building input and input relationships for all OP nodes | |||
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); | |||
} else { | |||
return SUCCESS; | |||
} | |||
for (auto &curr_node_def : nodedef_to_optimize) { | |||
GE_CHECK_NOTNULL(curr_node_def); | |||
bool clear_input_flag = false; | |||
const string &curr_node_name = curr_node_def->name(); | |||
GE_RETURN_IF_ERROR(OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag)); | |||
if (clear_input_flag) { | |||
curr_node_def->clear_input(); | |||
} | |||
} | |||
GELOGI("GraphDefOptimizeIdentity success."); | |||
return SUCCESS; | |||
} | |||
Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, | |||
map<string, NodeDef *> &nodedef_map, | |||
const std::pair<string, int> &input_data, | |||
const std::vector<string> &control_list) { | |||
GE_CHECK_NOTNULL(curr_mode_def); | |||
if (curr_mode_def == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "Param curr_mode_def is nullptr, check invalid"); | |||
GELOGE(FAILED, "input param is nullptr."); | |||
return PARAM_INVALID; | |||
} | |||
string curr_node_name = curr_mode_def->name(); | |||
auto context_iter = op_node_context_map_.find(curr_node_name); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
(context_iter == op_node_context_map_.end()), | |||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||
return INTERNAL_ERROR, "Can't find op node context."); | |||
if (context_iter == op_node_context_map_.end()) { | |||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||
GELOGE(FAILED, "Can't find op node context."); | |||
return INTERNAL_ERROR; | |||
} | |||
OpNodeContext op_node_context = context_iter->second; | |||
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | |||
@@ -2773,7 +2693,7 @@ struct DelTransposeInfo { | |||
int inputIdx; | |||
}; | |||
Status GetTransposeInfo(GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, | |||
Status GetTransposeInfo(domi::tensorflow::GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, | |||
std::map<std::string, DelTransposeInfo> &transposeInfo) { | |||
GE_CHECK_NOTNULL(graph_def); | |||
for (int i = 0; i < graph_def->node_size(); ++i) { | |||
@@ -2826,7 +2746,7 @@ Status EraseTransposeNode(std::map<std::string, std::string> &softmaxInfo, | |||
itTranspose->second.node_def->input(0).c_str()); | |||
itTranspose = transposeInfo.erase(itTranspose); | |||
} else { | |||
itTranspose++; | |||
++itTranspose; | |||
} | |||
} | |||
@@ -2847,7 +2767,7 @@ void TensorFlowModelParser::OptimizeTranspose(std::map<std::string, DelTranspose | |||
} | |||
} | |||
void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *const graph_def) { | |||
void TensorFlowModelParser::SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def) { | |||
// The caller guarantees that the pointer is not null | |||
for (int i = 0; i < graph_def->node_size(); ++i) { | |||
auto node_def = graph_def->mutable_node(i); | |||
@@ -2864,8 +2784,6 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph | |||
GE_CHECK_NOTNULL(graph_def); | |||
map<string, NodeDef *> nodedef_map; | |||
vector<string> op_node_name_list; | |||
// Save Identity and ReadVariableOp | |||
vector<NodeDef *> identity_to_optimize; | |||
// Save Snapshot | |||
vector<NodeDef *> snapshot_to_optimize; | |||
@@ -2875,16 +2793,12 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph | |||
const string &node_name = node_def->name(); | |||
Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list); | |||
GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); | |||
if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) { | |||
identity_to_optimize.push_back(node_def); | |||
} else if (node_def->op() == ge::parser::SNAPSHOT) { | |||
if (node_def->op() == ge::parser::SNAPSHOT) { | |||
snapshot_to_optimize.push_back(node_def); | |||
} | |||
nodedef_map[node_name] = node_def; | |||
} | |||
// Optimize for Identity/ReadVariableOp | |||
GE_RETURN_IF_ERROR(GraphDefOptimizeIdentity(graph_def, nodedef_map, identity_to_optimize)); | |||
// Optimize for Snapshot | |||
GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize)); | |||
@@ -3055,7 +2969,7 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, | |||
GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS, | |||
return FAILED); | |||
const TensorProto &tensor = attr_value.tensor(); | |||
const TensorShapeProto &tensor_shape = tensor.tensor_shape(); | |||
const domi::tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape(); | |||
GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE, | |||
return SUCCESS); | |||
GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS); | |||
@@ -3110,10 +3024,10 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||
std::set<string> next_inputs; | |||
for (const string ¤t_input : current_inputs) { | |||
delete_nodes.insert(current_input); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
{"input_shape", current_input}); | |||
return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | |||
GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U, | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
{"input_shape", current_input}); | |||
return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | |||
const NodeDef *current_node = node_lookup[current_input]; | |||
GE_CHECK_NOTNULL(current_node); | |||
for (const string &input_name : current_node->input()) { | |||
@@ -3128,7 +3042,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||
domi::tensorflow::GraphDef filtered_graph_def; | |||
filtered_graph_def.mutable_node()->Clear(); | |||
for (const NodeDef &node : input_graph_def.node()) { | |||
if (input_nodes.count(node.name())) { | |||
if (static_cast<bool>(input_nodes.count(node.name()))) { | |||
*(filtered_graph_def.mutable_node()->Add()) = node; | |||
} | |||
if (!delete_nodes.count(node.name())) { | |||
@@ -3137,12 +3051,12 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||
} | |||
output_graph_def->Clear(); | |||
for (const NodeDef &node : filtered_graph_def.node()) { | |||
if (input_nodes.count(node.name())) { | |||
if (static_cast<bool>(input_nodes.count(node.name()))) { | |||
NodeDef placeholder_node = node; | |||
placeholder_node.clear_input(); | |||
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); | |||
domi::tensorflow::AttrValue attr_value; | |||
TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||
GE_CHECK_NOTNULL(data_shape); | |||
const ge::ParserContext &ctx = ge::GetParserContext(); | |||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
@@ -3185,11 +3099,11 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef | |||
std::set<string> next_inputs; | |||
for (const string ¤t_input : current_inputs) { | |||
required_nodes.insert(current_input); | |||
GE_IF_BOOL_EXEC(input_nodes.count(current_input), continue); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
{"out_nodes", current_input}); | |||
return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | |||
GE_IF_BOOL_EXEC(static_cast<bool>(input_nodes.count(current_input)), continue); | |||
GE_CHK_BOOL_EXEC(node_lookup.count(current_input) > 0U, | |||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
{"out_nodes", current_input}); | |||
return FAILED, "op[%s] not found in graph.", current_input.c_str()); | |||
const NodeDef *current_node = node_lookup[current_input]; | |||
GE_CHECK_NOTNULL(current_node); | |||
for (const string &input_name : current_node->input()) { | |||
@@ -3204,18 +3118,18 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef | |||
domi::tensorflow::GraphDef filtered_graph_def; | |||
filtered_graph_def.mutable_node()->Clear(); | |||
for (const NodeDef &node : input_graph_def.node()) { | |||
if (required_nodes.count(node.name())) { | |||
if (static_cast<bool>(required_nodes.count(node.name()))) { | |||
*(filtered_graph_def.mutable_node()->Add()) = node; | |||
} | |||
} | |||
output_graph_def->Clear(); | |||
for (const NodeDef &node : filtered_graph_def.node()) { | |||
if (input_nodes.count(node.name())) { | |||
if (static_cast<bool>(input_nodes.count(node.name()))) { | |||
NodeDef placeholder_node = node; | |||
placeholder_node.clear_input(); | |||
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); | |||
domi::tensorflow::AttrValue attr_value; | |||
TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape(); | |||
GE_CHECK_NOTNULL(data_shape); | |||
const ge::ParserContext &ctx = ge::GetParserContext(); | |||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
@@ -3265,11 +3179,12 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par | |||
// Find all children of the fusion operator | |||
auto iter = fusion_op_nodedef_map_.find(node_def->name()); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
iter == fusion_op_nodedef_map_.end(), | |||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid", | |||
node_def->name().c_str()); | |||
return INTERNAL_ERROR, "FusionOp node %s has no children node!", node_def->name().c_str()); | |||
if (iter == fusion_op_nodedef_map_.end()) { | |||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid", | |||
node_def->name().c_str()); | |||
GELOGE(FAILED, "FusionOp node %s has no children node!", node_def->name().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
(void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op()); | |||
vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second; | |||
@@ -3325,7 +3240,7 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par | |||
* @return false optimize failed | |||
* | |||
*/ | |||
Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) { | |||
Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const { | |||
GE_CHECK_NOTNULL(graph_def); | |||
// 1. find all the nodes in the graph and save them to all_nodedef_map | |||
map<string, NodeDef *> all_nodedef_map; | |||
@@ -3336,7 +3251,7 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||
string node_name = current_node->name(); | |||
all_nodedef_map[node_name] = current_node; | |||
} | |||
GE_CHK_BOOL_EXEC_INFO(!all_nodedef_map.empty(), return SUCCESS, "all_nodedef_map is empty"); | |||
GELOGD("node size is: %zu", all_nodedef_map.size()); | |||
// 2. move input to attr. | |||
for (auto &it_node_map : all_nodedef_map) { | |||
@@ -3347,14 +3262,14 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||
// 2.1. check whether the current op is register for move to attr. | |||
const std::vector<domi::RemoveInputConfigure> &move_input_vec = | |||
domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name); | |||
GE_CHK_BOOL_EXEC_NOLOG(!move_input_vec.empty(), continue); | |||
GELOGD("Current op %s is registered for remove input.", current_op_name.c_str()); | |||
// 2.2 check whether the current op is a TVM op. | |||
GE_CHK_BOOL_EXEC_INFO( | |||
domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) == domi::ImplyType::TVM, continue, | |||
"op %s is not TVM op", current_op_name.c_str()); | |||
GELOGD("handle tvm op %s", current_op_name.c_str()); | |||
const bool is_unknown_custom_op = move_input_vec.empty() || | |||
(domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) != domi::ImplyType::TVM); | |||
if (is_unknown_custom_op) { | |||
GELOGI("op %s is not TVM op, move input size: %zu", current_op_name.c_str(), move_input_vec.size()); | |||
continue; | |||
} | |||
GELOGD("Current op %s is registered for remove input and tvm op", current_op_name.c_str()); | |||
// 2.3 copy input to attr | |||
set<uint32_t> unused_inputs; | |||
@@ -3401,7 +3316,8 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||
} | |||
for (size_t i = 0; i < it.input_order.size(); ++i) { | |||
int new_index = it.input_order[i]; | |||
if (new_index < 0 || new_index >= inputs.size()) { | |||
const bool is_input_invalid = (new_index < 0) || (new_index >= inputs.size()); | |||
if (is_input_invalid) { | |||
REPORT_INNER_ERROR("E19999", "New order of %s has invalid index %d, out of range(0, %d)", | |||
it_node_map.first.c_str(), new_index, inputs.size()); | |||
GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index); | |||
@@ -3443,7 +3359,10 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow:: | |||
if (input_node_def->op() == parser::SWITCH || input_node_def->op() == parser::REFSWITCH) { | |||
NodeDef *identity_node_def = graph_def->add_node(); | |||
GE_CHECK_NOTNULL(identity_node_def); | |||
input_node_name = input_node_name + "identity"; | |||
std::string remove_input_name = remove_input; | |||
remove_input_name = remove_input_name.find(":") == std::string::npos ? | |||
input_node_name : (remove_input_name.replace(remove_input_name.find(":"), 1, "_")); | |||
input_node_name = remove_input_name + "_identity"; | |||
identity_node_def->set_name(input_node_name); | |||
identity_node_def->set_op(parser::IDENTITY); | |||
identity_node_def->add_input(remove_input); | |||
@@ -3465,7 +3384,7 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow:: | |||
*/ | |||
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def, | |||
const set<uint32_t> &remove_index_set, | |||
const map<string, NodeDef *> &all_node_map) { | |||
const map<string, NodeDef *> &all_node_map) const { | |||
GE_CHECK_NOTNULL(node_def); | |||
if (remove_index_set.empty()) { | |||
GELOGI("The size of remove_index_set is zero."); | |||
@@ -3662,7 +3581,7 @@ Status TensorFlowModelParser::RecordFusionResult(const std::shared_ptr<ge::Scope | |||
return SUCCESS; | |||
} | |||
Status TensorFlowModelParser::SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context, | |||
Status TensorFlowModelParser::SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context, | |||
const std::vector<std::pair<std::string, int32_t>> &inputs, | |||
const std::vector<std::pair<std::string, int32_t>> &outputs) { | |||
int32_t in_index = 0; | |||
@@ -3752,7 +3671,7 @@ void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, Op | |||
++iter; | |||
} | |||
} | |||
op_node_context.input_map.insert(tmp_input_map.begin(), tmp_input_map.end()); | |||
op_node_context.input_map.insert(tmp_input_map.cbegin(), tmp_input_map.cend()); | |||
// update output map of pre node | |||
for (const auto &in_iter : op_node_context.input_map) { | |||
auto src_iter = op_node_context_map_.find(in_iter.first); | |||
@@ -3801,7 +3720,7 @@ void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, O | |||
++iter; | |||
} | |||
} | |||
op_node_context.output_map.insert(tmp_output_map.begin(), tmp_output_map.end()); | |||
op_node_context.output_map.insert(tmp_output_map.cbegin(), tmp_output_map.cend()); | |||
// update input map of pre node | |||
for (const auto &out_iter : op_node_context.output_map) { | |||
auto dst_iter = op_node_context_map_.find(out_iter.first); | |||
@@ -3902,7 +3821,7 @@ Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope | |||
DumpAllNodeContext("BeforeAddFusionNodeDef"); | |||
for (size_t i = 0; i < op_node_list_size; ++i) { | |||
const string op_node_name = node_name_list[i]; | |||
auto iter = fusion_op_nodedef_map_.find(op_node_name); | |||
std::map<string, vector<const NodeDef *>>::const_iterator iter = fusion_op_nodedef_map_.find(op_node_name); | |||
if (iter != fusion_op_nodedef_map_.end()) { | |||
vector<string> fusion_op_info = fusion_op_type_map_[op_node_name]; | |||
if (fusion_op_info[0] != ge::kScopeToMultiNodes) { | |||
@@ -3943,7 +3862,8 @@ Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope | |||
} | |||
Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | |||
std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def) { | |||
std::mutex *const graph_mutex, | |||
const domi::tensorflow::NodeDef *node_def) { | |||
// This is an internal function. The pointer input parameter is not empty when this function is invoked. | |||
string node_name = node_def->name(); | |||
string node_op = node_def->op(); | |||
@@ -4059,7 +3979,7 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph | |||
std::string model_data; | |||
if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) { | |||
ge::Model model; | |||
auto load_ret = ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.data()), model_data.size(), model); | |||
auto load_ret = ge::Model::Load(ge::PtrToPtr<char_t, const uint8_t>(model_data.data()), model_data.size(), model); | |||
if (load_ret != GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str()); | |||
REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str()); | |||
@@ -35,6 +35,7 @@ | |||
#include "omg/parser/model_parser.h" | |||
#include "omg/parser/op_parser.h" | |||
#include "omg/parser/weights_parser.h" | |||
#include "common/pre_checker.h" | |||
#include "parser/tensorflow/tensorflow_fusion_op_parser.h" | |||
#include "parser/tensorflow/tensorflow_util.h" | |||
#include "proto/om.pb.h" | |||
@@ -46,15 +47,6 @@ | |||
#include "scope/scope_pass_manager.h" | |||
#include "common/parser_utils.h" | |||
using ge::ScopePassManager; | |||
using domi::tensorflow::GraphDef; | |||
using domi::tensorflow::DT_HALF; | |||
using domi::tensorflow::NodeDef; | |||
using domi::tensorflow::GraphDef; | |||
using domi::tensorflow::AttrValue; | |||
using domi::tensorflow::DataType; | |||
using ge::OpParser; | |||
namespace ge { | |||
using std::string; | |||
using std::vector; | |||
@@ -130,7 +122,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, | |||
domi::GetGraphCallback callback, | |||
ge::ComputeGraphPtr &graph) override; | |||
ge::ComputeGraphPtr &root_graph) override; | |||
/* | |||
* @ingroup domi_omg | |||
@@ -163,6 +155,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
*/ | |||
Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | |||
ge::ComputeGraphPtr &root_graph) override; | |||
bool HasError() override { | |||
return PreChecker::Instance().HasError(); | |||
} | |||
Status Save(const string &file) override { | |||
return PreChecker::Instance().Save(file); | |||
} | |||
void Clear() override { | |||
PreChecker::Instance().Clear(); | |||
} | |||
private: | |||
Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph); | |||
@@ -241,15 +245,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
/** | |||
* @ingroup domi_omg | |||
* @brief Verifying the validity of graphdef object parsed by pb | |||
* @param [in] graph_def Parsed tensorflow:: graphdef object | |||
* @return SUCCESS check successfully | |||
* @return FAILED check failed | |||
*/ | |||
Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const; | |||
/** | |||
* @ingroup domi_omg | |||
* @brief whether const OP need to update context | |||
* @param const op name | |||
* @return true or false | |||
@@ -433,28 +428,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
* @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef | |||
*/ | |||
Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def); | |||
/** | |||
* @ingroup domi_omg | |||
* @brief Optimize for Identity/ReadVariableOp operator | |||
* @param [in] graph_def GraphDef to be optimized | |||
* @param [in] nodedef_map Map of all nodes in graph | |||
* @param [in] nodedef_to_optimize vector of NodeDef to be optimized | |||
* @return SUCCESS optimize successfully | |||
* @return others failed | |||
*/ | |||
Status GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | |||
const vector<NodeDef *> &nodedef_to_optimize); | |||
/** | |||
* @ingroup domi_omg | |||
* @brief For the identity operator whose output is "_retval", optimize it. | |||
* @param [in] nodedef_map Map of all nodes in graph | |||
* @param [in] curr_node_name Name of node to be optimized | |||
* @param [in] clear_input_flag Flag of whether to clear the input of the current node | |||
* @return SUCCESS optimize successfully | |||
* @return others failed | |||
*/ | |||
Status OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, const string &curr_node_name, | |||
bool &clear_input_flag); | |||
Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | |||
const vector<NodeDef *> &nodedef_to_optimize); | |||
Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | |||
@@ -469,7 +443,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def, | |||
domi::tensorflow::NodeDef *const nodeCurrent, bool &clearInputFlag) const; | |||
static void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo); | |||
static void SoftmaxAddAttr(GraphDef *const graph_def); | |||
static void SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def); | |||
/** | |||
* @ingroup domi_omg | |||
@@ -551,7 +525,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
* @return false optimize failed | |||
* | |||
*/ | |||
Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def); | |||
Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const; | |||
/** | |||
* @ingroup domi_omg | |||
@@ -565,7 +539,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
Status RemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||
domi::tensorflow::NodeDef *node_def, | |||
const set<uint32_t> &remove_index_set, | |||
const map<string, NodeDef *> &all_node_map); | |||
const map<string, NodeDef *> &all_node_map) const; | |||
Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||
domi::tensorflow::NodeDef *node_def, | |||
@@ -611,7 +585,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
static Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library); | |||
Status SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context, | |||
Status SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context, | |||
const std::vector<std::pair<std::string, int32_t>> &inputs, | |||
const std::vector<std::pair<std::string, int32_t>> &outputs); | |||
@@ -642,7 +616,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
Status AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &node_name_list); | |||
static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | |||
std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def); | |||
std::mutex *const graph_mutex, const domi::tensorflow::NodeDef *node_def); | |||
static void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase); | |||
void DumpAllNodeContext(const string &phase) const; | |||
@@ -725,6 +699,18 @@ class PARSER_FUNC_VISIBILITY TensorFlowWeightsParser : public domi::WeightsParse | |||
Status Parse(const char *file, ge::Graph &graph) override; | |||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | |||
bool HasError() override { | |||
return PreChecker::Instance().HasError(); | |||
} | |||
Status Save(const string &file) override { | |||
return PreChecker::Instance().Save(file); | |||
} | |||
void Clear() override { | |||
PreChecker::Instance().Clear(); | |||
} | |||
}; | |||
} // namespace domi | |||
#endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ |
@@ -30,8 +30,6 @@ | |||
#include "parser/tensorflow/tensorflow_op_parser.h" | |||
#include "proto/tensorflow/node_def.pb.h" | |||
using domi::tensorflow::NodeDef; | |||
namespace ge { | |||
class PARSER_FUNC_VISIBILITY TensorflowFinalizeable { | |||
public: | |||
@@ -20,8 +20,6 @@ | |||
#include "common/op_def/ref_switch_op.h" | |||
#include "parser/tensorflow/tensorflow_op_parser.h" | |||
using domi::tensorflow::NodeDef; | |||
namespace ge { | |||
class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpParser { | |||
// AUTO GEN PLEASE DO NOT MODIFY IT | |||
@@ -61,7 +61,7 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||
GE_CHECK_NOTNULL(op_src); | |||
GE_CHECK_NOTNULL(op); | |||
const NodeDef *node_src = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||
GE_CHECK_NOTNULL(node_src); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | |||
domi::tensorflow::AttrValue input_attr_value; | |||
@@ -34,7 +34,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser | |||
* @return FAILED parse failed | |||
* @author | |||
*/ | |||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||
}; | |||
} // namespace ge | |||
@@ -94,7 +94,7 @@ Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, Sha | |||
Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | |||
GE_CHECK_NOTNULL(op_dest); | |||
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||
GE_CHECK_NOTNULL(node); | |||
ShapeNOperator op; | |||
op.Name(node->name()); | |||
@@ -154,13 +154,13 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||
} | |||
// AUTO GEN PLEASE DO NOT MODIFY IT | |||
Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { | |||
Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) { | |||
(void)node; | |||
(void)op; | |||
return SUCCESS; | |||
} | |||
Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { | |||
Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) { | |||
(void)node; | |||
(void)op; | |||
return SUCCESS; | |||
@@ -20,8 +20,6 @@ | |||
#include "common/op_def/shape_n_op.h" | |||
#include "parser/tensorflow/tensorflow_op_parser.h" | |||
using domi::tensorflow::NodeDef; | |||
namespace ge { | |||
class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser { | |||
// AUTO GEN PLEASE DO NOT MODIFY IT | |||
@@ -29,8 +27,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser | |||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||
protected: | |||
Status PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||
Status PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||
Status PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op); | |||
Status PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op); | |||
static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||
static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||
@@ -66,7 +66,7 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||
GE_CHECK_NOTNULL(op_src); | |||
GE_CHECK_NOTNULL(op); | |||
const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); | |||
GE_CHECK_NOTNULL(node); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
bool has_axis = true; | |||
@@ -22,7 +22,7 @@ | |||
namespace ge { | |||
class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser { | |||
public: | |||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||
private: | |||
static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | |||
@@ -207,7 +207,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch | |||
} | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( | |||
const NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) { | |||
const domi::tensorflow::NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) { | |||
GE_CHECK_NOTNULL(node_src); | |||
std::string node_name = node_src->name(); | |||
@@ -301,7 +301,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||
} | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( | |||
const std::string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *const node_def) { | |||
GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | |||
if (node_def == nullptr) { | |||
GELOGI("input parameter is null."); | |||
return; | |||
} | |||
node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | |||
} | |||
} // namespace ge |
@@ -18,14 +18,11 @@ | |||
#define OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_ | |||
#include <map> | |||
#include <set> | |||
#include <string> | |||
#include <unordered_map> | |||
#include <vector> | |||
#include "parser/common/op_def/operator.h" | |||
#include "external/graph/attr_value.h" | |||
#include "external/graph/graph.h" | |||
#include "external/graph/operator.h" | |||
#include "framework/omg/parser/parser_types.h" | |||
#include "framework/omg/omg_inner_types.h" | |||
#include "graph/compute_graph.h" | |||
@@ -37,11 +34,6 @@ | |||
#include "graph/utils/tensor_utils.h" | |||
#include "proto/tensorflow/graph.pb.h" | |||
using domi::tensorflow::NodeDef; | |||
using domi::tensorflow::FunctionDef; | |||
using domi::tensorflow::AttrValue_ListValue; | |||
using domi::tensorflow::FunctionDefLibrary; | |||
namespace ge { | |||
/***************************TensorFlow attribute type, constant definition*******************************************/ | |||
extern const std::string TENSORFLOW_ATTR_TYPE_STRING; | |||
@@ -167,7 +159,7 @@ class TensorFlowUtil { | |||
* @return FAILED parsing failed | |||
* | |||
*/ | |||
static domi::Status ParseDataType(const NodeDef *node_src, | |||
static domi::Status ParseDataType(const domi::tensorflow::NodeDef *node_src, | |||
const std::string &attr_src, | |||
domi::tensorflow::DataType &data_type); | |||
@@ -25,7 +25,7 @@ using namespace ge::parser; | |||
namespace ge { | |||
Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *const op) { | |||
GE_CHECK_NOTNULL(op_src); | |||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||
const domi::tensorflow::NodeDef *node = ge::PtrToPtr<Message, domi::tensorflow::NodeDef>(op_src); | |||
GE_CHECK_NOTNULL(node); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
op->Name(node->name()); | |||
@@ -19,7 +19,6 @@ | |||
#include "graph/ge_attr_value.h" | |||
#include "graph/ge_tensor.h" | |||
#include "graph/op_desc.h" | |||
#include "graph/operator.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "parser/common/op_def/variable_op.h" | |||
@@ -253,7 +252,7 @@ static void ParseMemType(const domi::tensorflow::NodeDef *node, VariableOperator | |||
Status ParseParams(const Message *op_src, VariableOperator *op) { | |||
GE_CHECK_NOTNULL(op_src); | |||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||
const NodeDef *node = ge::PtrToPtr<Message, NodeDef>(op_src); | |||
GE_CHECK_NOTNULL(node); | |||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | |||
string node_op = node->op(); | |||
@@ -121,7 +121,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||
GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), | |||
REPORT_CALL_ERROR("E19999", "UnserializeOpDesc failed"); | |||
return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); | |||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
op_desc->ext_attrs_ = org_op_desc->ext_attrs_; | |||
// This function may be called by some passes of fusion engine, in this condition, do not need these attribute | |||
if (op_desc->impl_ == nullptr) { | |||
@@ -164,7 +164,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||
return nullptr; | |||
} | |||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
op_desc->ext_attrs_ = org_op_desc->ext_attrs_; | |||
if (op_desc->impl_ == nullptr) { | |||
REPORT_INNER_ERROR("E19999", "op desc impl is nullptr, check invalid"); | |||
@@ -158,6 +158,17 @@ mmTimespec mmGetTickCount() { | |||
return rts; | |||
} | |||
INT32 mmGetSystemTime(mmSystemTime_t *sysTime) { | |||
// Beijing olympics | |||
sysTime->wYear = 2008; | |||
sysTime->wMonth = 8; | |||
sysTime->wDay = 8; | |||
sysTime->wHour = 20; | |||
sysTime->wMinute = 8; | |||
sysTime->wSecond = 0; | |||
return 0; | |||
} | |||
INT32 mmGetTid() { | |||
INT32 ret = (INT32)syscall(SYS_gettid); | |||
@@ -61,7 +61,7 @@ target_link_libraries(st_parser_proto PRIVATE | |||
################################################################################ | |||
set(DUPLICATE_PROTO_LIST | |||
"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto" | |||
) | |||
protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) | |||
@@ -118,7 +118,10 @@ set(MATEDEF_SRC_FILES | |||
"${PARSER_DIR}/metadef/graph/resource_context_mgr.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/constant_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/file_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/connection_matrix.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/cycle_detector.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/node_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
@@ -126,6 +129,8 @@ set(MATEDEF_SRC_FILES | |||
"${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/type_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/trace/trace_manager.cc" | |||
"${PARSER_DIR}/metadef/graph/common/large_bm.cc" | |||
"${PARSER_DIR}/metadef/ops/op_imp.cpp" | |||
"${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" | |||
"${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" | |||
@@ -21,7 +21,9 @@ | |||
#include <google/protobuf/io/zero_copy_stream_impl.h> | |||
#include <google/protobuf/text_format.h> | |||
#include <fstream> | |||
#include <sys/types.h> | |||
#include <sys/stat.h> | |||
#include <fcntl.h> | |||
namespace ge { | |||
void ParerSTestsUtils::ClearParserInnerCtx() { | |||
@@ -131,4 +133,14 @@ void ParerSTestsUtils::WriteProtoToBinaryFile(const google::protobuf::Message &p | |||
out.close(); | |||
delete[] buf; | |||
} | |||
void ParerSTestsUtils::WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename) { | |||
const int32_t fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 384U); | |||
if (fd >= 0) { | |||
google::protobuf::io::FileOutputStream output(fd); | |||
google::protobuf::TextFormat::Print(proto, &output); | |||
output.Close(); | |||
close(fd); | |||
} | |||
} | |||
} // namespace ge |
@@ -31,6 +31,7 @@ class ParerSTestsUtils { | |||
static MemBuffer* MemBufferFromFile(const char *path); | |||
static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | |||
static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); | |||
static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename); | |||
}; | |||
} // namespace ge | |||
@@ -36,6 +36,7 @@ | |||
#include "parser/caffe/caffe_op_parser.h" | |||
#include "graph/operator_reg.h" | |||
#include "parser/common/acl_graph_parser_util.h" | |||
#include "common/op_map.h" | |||
#undef protected | |||
#undef private | |||
@@ -173,7 +174,7 @@ void STestCaffeParser::RegisterCustomOp() { | |||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
for (auto reg_data : reg_datas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
domi::OpRegistry::Instance()->Register(reg_data); | |||
} | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -223,6 +224,29 @@ TEST_F(STestCaffeParser, acl_caffe_parser) { | |||
EXPECT_EQ(ret, GRAPH_FAILED); | |||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); | |||
EXPECT_EQ(ret, GRAPH_FAILED); | |||
caffe_op_map.clear(); | |||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), parser_params, graph); | |||
EXPECT_EQ(ret, GRAPH_FAILED); | |||
{ | |||
proto.set_name("empty_layer"); | |||
auto &layers = *proto.add_layers(); | |||
layers.set_name("layers"); | |||
proto.clear_layer(); | |||
const std::string empty_layer = case_dir + "/origin_models/empty_layer.pbtxt"; | |||
ParerSTestsUtils::WriteProtoToTextFile(proto, empty_layer.c_str()); | |||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layer.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||
proto.clear_layers(); | |||
const std::string empty_layers = case_dir + "/origin_models/empty_layers.pbtxt"; | |||
ParerSTestsUtils::WriteProtoToTextFile(proto, empty_layers.c_str()); | |||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layers.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||
unlink(empty_layer.c_str()); | |||
unlink(empty_layers.c_str()); | |||
} | |||
} | |||
TEST_F(STestCaffeParser, modelparser_parsefrommemory_success) | |||
@@ -24,6 +24,7 @@ | |||
#include "st/parser_st_utils.h" | |||
#include "external/ge/ge_api_types.h" | |||
#include "tests/depends/ops_stub/ops_stub.h" | |||
#include "framework/omg/parser/parser_factory.h" | |||
#include "parser/onnx/onnx_parser.h" | |||
namespace ge { | |||
@@ -96,7 +97,7 @@ void STestOnnxParser::RegisterCustomOp() { | |||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
for (auto reg_data : reg_datas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
domi::OpRegistry::Instance()->Register(reg_data); | |||
} | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -64,6 +64,7 @@ | |||
#include "parser/common/data_op_parser.h" | |||
#include "parser/common/model_saver.h" | |||
#include "framework/omg/parser/parser_api.h" | |||
#include "framework/omg/parser/parser_factory.h" | |||
#include "parser/common/parser_fp16_t.h" | |||
#include "parser/common/op_parser_factory.h" | |||
#include "parser/common/prototype_pass_manager.h" | |||
@@ -151,7 +152,7 @@ void STestTensorflowParser::RegisterCustomOp() { | |||
.ParseParamsFn(ParseParams); | |||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
for (auto reg_data : reg_datas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
domi::OpRegistry::Instance()->Register(reg_data); | |||
} | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -584,7 +585,7 @@ namespace { | |||
void register_tbe_op() { | |||
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; | |||
for (OpRegistrationData reg_data : registrationDatas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
OpRegistry::Instance()->Register(reg_data); | |||
} | |||
OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -1124,7 +1125,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserfrommemory_failed) | |||
ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | |||
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
ret = modelParser.ParseFromMemory(data, size, compute_graph); | |||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||
EXPECT_NE(ret, SUCCESS); | |||
} | |||
TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success) | |||
@@ -1259,7 +1260,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed) | |||
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
TensorFlowModelParser tensorflow_parser; | |||
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | |||
EXPECT_EQ(INTERNAL_ERROR, ret); | |||
ASSERT_NE(ret, SUCCESS); | |||
} | |||
TEST_F(STestTensorflowParser, test_parse_acl_output_nodes) | |||
@@ -1913,6 +1914,7 @@ TEST_F(STestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test) | |||
EXPECT_EQ(ret, SUCCESS); | |||
op_dest->SetType(ge::parser::SHAPE); | |||
op_dest->AddOutputDesc(GeTensorDesc()); | |||
ret = autoMappingParser.ParseParams(node_def, op_dest); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
@@ -2648,29 +2650,6 @@ TEST_F(STestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test) | |||
model_parser.UpdateEdgesControlInfo(info); | |||
} | |||
TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) | |||
{ | |||
TensorFlowModelParser model_parser; | |||
NodeDef *node_def = new NodeDef(); | |||
node_def->set_name("Placeholder"); | |||
node_def->set_op("Placeholder_0"); | |||
std::map<string, NodeDef *> nodedef_map; | |||
nodedef_map.emplace("Placeholder", node_def); | |||
std::string curr_node_name = "Placeholder"; | |||
bool clear_input_flag = true; | |||
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||
GraphDef graph; | |||
curr_node_name = "pre_node_a"; | |||
nodedef_map.emplace("pre_node_a", node_def); | |||
node_def->set_op("pre_node_a"); | |||
GenOriginContext(&model_parser, curr_node_name); | |||
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||
EXPECT_EQ(ret, SUCCESS); | |||
delete node_def; | |||
} | |||
TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test) | |||
{ | |||
TensorFlowModelParser model_parser; | |||
@@ -2831,27 +2810,17 @@ TEST_F(STestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test) | |||
removed_inputs_vec.emplace_back("Add0"); | |||
Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); | |||
EXPECT_EQ(ret, SUCCESS); | |||
tensorflow::NodeDef *node_swith = initNodeDef(); | |||
node_swith->set_name("switch_op"); | |||
node_swith->set_op(parser::SWITCH); | |||
all_node_map.emplace("switch_op", node_swith); | |||
removed_inputs_vec.clear(); | |||
removed_inputs_vec.emplace_back("switch_op"); | |||
ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_swith, all_node_map, removed_inputs_vec); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||
{ | |||
tensorflow::GraphDef graph_def; | |||
TensorFlowModelParser tensorflow_parser; | |||
tensorflow::NodeDef *node_def = initNodeDef(); | |||
node_def->set_name("post_node_d"); | |||
std::map<string, NodeDef *> nodedef_map; | |||
nodedef_map.emplace("post_node_d", node_def); | |||
nodedef_map.emplace("post_node_a", node_def); | |||
nodedef_map.emplace("post_node_b", node_def); | |||
std::vector<NodeDef *> nodedef_to_optimize; | |||
nodedef_to_optimize.emplace_back(node_def); | |||
std::string curr_node_name = "post_node_b"; | |||
GenOriginContext(&tensorflow_parser, curr_node_name); | |||
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | |||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||
} | |||
TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | |||
std::string caseDir = __FILE__; | |||
std::size_t idx = caseDir.find_last_of("/"); | |||
@@ -3534,7 +3503,8 @@ TEST_F(STestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test) | |||
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); | |||
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); | |||
mess2Op.ParseField(reflection, node_def, field, depth, ops); | |||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str); | |||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1); | |||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5); | |||
delete field; | |||
} | |||
@@ -62,7 +62,7 @@ target_link_libraries(ut_parser_proto PRIVATE | |||
################################################################################ | |||
set(DUPLICATE_PROTO_LIST | |||
"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto" | |||
) | |||
protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) | |||
@@ -119,14 +119,19 @@ set(MATEDEF_SRC_FILES | |||
"${PARSER_DIR}/metadef/graph/tensor.cc" | |||
"${PARSER_DIR}/metadef/graph/types.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/file_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/connection_matrix.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/cycle_detector.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/node_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/type_utils.cc" | |||
"${PARSER_DIR}/metadef/graph/utils/trace/trace_manager.cc" | |||
"${PARSER_DIR}/metadef/graph/common/large_bm.cc" | |||
"${PARSER_DIR}/metadef/ops/op_imp.cpp" | |||
"${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" | |||
"${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" | |||
@@ -21,6 +21,9 @@ | |||
#include <google/protobuf/io/zero_copy_stream_impl.h> | |||
#include <google/protobuf/text_format.h> | |||
#include <limits.h> | |||
#include <sys/types.h> | |||
#include <sys/stat.h> | |||
#include <fcntl.h> | |||
namespace ge { | |||
void ParerUTestsUtils::ClearParserInnerCtx() { | |||
@@ -131,6 +134,16 @@ void ParerUTestsUtils::WriteProtoToBinaryFile(const google::protobuf::Message &p | |||
delete[] buf; | |||
} | |||
void ParerUTestsUtils::WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename) { | |||
const int32_t fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 384U); | |||
if (fd >= 0) { | |||
google::protobuf::io::FileOutputStream output(fd); | |||
google::protobuf::TextFormat::Print(proto, &output); | |||
output.Close(); | |||
close(fd); | |||
} | |||
} | |||
namespace ut { | |||
NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, | |||
DataType data_type, std::vector<int64_t> shape) { | |||
@@ -32,6 +32,7 @@ class ParerUTestsUtils { | |||
static MemBuffer* MemBufferFromFile(const char *path); | |||
static bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | |||
static void WriteProtoToBinaryFile(const google::protobuf::Message &proto, const char *filename); | |||
static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *filename); | |||
}; | |||
namespace ut { | |||
@@ -39,6 +39,7 @@ | |||
#include "graph/operator_reg.h" | |||
#include "parser/common/acl_graph_parser_util.h" | |||
#include "parser/caffe/caffe_reshape_parser.h" | |||
#include "common/op_map.h" | |||
#undef protected | |||
#undef private | |||
@@ -162,7 +163,7 @@ static ge::NodePtr GenNodeFromOpDesc(ge::OpDescPtr opDesc){ | |||
void UtestCaffeParser::RegisterCustomOp() { | |||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
for (auto reg_data : reg_datas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
domi::OpRegistry::Instance()->Register(reg_data); | |||
} | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -266,6 +267,29 @@ TEST_F(UtestCaffeParser, acl_caffe_parser) { | |||
EXPECT_EQ(ret, GRAPH_FAILED); | |||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), graph); | |||
EXPECT_EQ(ret, GRAPH_FAILED); | |||
caffe_op_map.clear(); | |||
ret = ge::aclgrphParseCaffe(model_file.c_str(), weight_file.c_str(), parser_params, graph); | |||
EXPECT_EQ(ret, GRAPH_FAILED); | |||
{ | |||
proto.set_name("empty_layer"); | |||
auto &layers = *proto.add_layers(); | |||
layers.set_name("layers"); | |||
proto.clear_layer(); | |||
const std::string empty_layer = case_dir + "/caffe_model/empty_layer.pbtxt"; | |||
ParerUTestsUtils::WriteProtoToTextFile(proto, empty_layer.c_str()); | |||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layer.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||
proto.clear_layers(); | |||
const std::string empty_layers = case_dir + "/caffe_model/empty_layers.pbtxt"; | |||
ParerUTestsUtils::WriteProtoToTextFile(proto, empty_layers.c_str()); | |||
EXPECT_EQ(ge::aclgrphParseCaffe(empty_layers.c_str(), weight_file.c_str(), parser_params, graph), FAILED); | |||
unlink(empty_layer.c_str()); | |||
unlink(empty_layers.c_str()); | |||
} | |||
} | |||
TEST_F(UtestCaffeParser, ParseFromMemory_success) | |||
@@ -34,6 +34,7 @@ | |||
#include "parser/common/pass_manager.h" | |||
#include "parser/common/tbe_plugin_loader.h" | |||
#include "parser/common/parser_fp16_t.h" | |||
#include "parser/common/pre_checker.h" | |||
#undef protected | |||
#undef private | |||
@@ -342,4 +343,15 @@ TEST_F(UtestAclGraphParser, test_operatoreq) | |||
int8 = fp16; | |||
} | |||
TEST_F(UtestAclGraphParser, test_pre_checker) { | |||
PreChecker::Instance().fmk_op_types_ = nullptr; | |||
const char* str = "iiii"; | |||
PreChecker::OpId id = str; | |||
std::string type("ddd"); | |||
std::string name("lll"); | |||
Status ret = PreChecker::Instance().CheckTypeSupported(id, type, name, false); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = PreChecker::Instance().CheckTypeSupported(id, type, name, true); | |||
EXPECT_EQ(ret, FAILED); | |||
} | |||
} // namespace ge |
@@ -24,6 +24,7 @@ | |||
#include "external/parser/onnx_parser.h" | |||
#include "ut/parser/parser_ut_utils.h" | |||
#include "external/ge/ge_api_types.h" | |||
#include "framework/omg/parser/parser_factory.h" | |||
#include "tests/depends/ops_stub/ops_stub.h" | |||
#define protected public | |||
@@ -103,7 +104,7 @@ void UtestOnnxParser::RegisterCustomOp() { | |||
.ParseParamsFn(ParseParams); | |||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
for (auto reg_data : reg_datas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
domi::OpRegistry::Instance()->Register(reg_data); | |||
} | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -176,7 +176,7 @@ void UtestTensorflowParser::RegisterCustomOp() { | |||
.ParseParamsFn(ParseParams); | |||
std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
for (auto reg_data : reg_datas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
domi::OpRegistry::Instance()->Register(reg_data); | |||
} | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -599,7 +599,7 @@ namespace { | |||
void register_tbe_op() { | |||
std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas; | |||
for (OpRegistrationData reg_data : registrationDatas) { | |||
OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data); | |||
OpRegistry::Instance()->Register(reg_data); | |||
} | |||
OpRegistry::Instance()->registrationDatas.clear(); | |||
@@ -1288,7 +1288,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserfrommemory_failed) | |||
ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph); | |||
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
ret = modelParser.ParseFromMemory(data, size, compute_graph); | |||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||
EXPECT_NE(ret, SUCCESS); | |||
} | |||
TEST_F(UtestTensorflowParser, modelparser_parsefrommemory_success) | |||
@@ -1419,7 +1419,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed) | |||
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
TensorFlowModelParser tensorflow_parser; | |||
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | |||
EXPECT_EQ(INTERNAL_ERROR, ret); | |||
ASSERT_NE(ret, SUCCESS); | |||
} | |||
TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes) | |||
@@ -2082,6 +2082,7 @@ TEST_F(UtestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test) | |||
EXPECT_EQ(ret, SUCCESS); | |||
op_dest->SetType(ge::parser::SHAPE); | |||
op_dest->AddOutputDesc(GeTensorDesc()); | |||
ret = autoMappingParser.ParseParams(node_def, op_dest); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
@@ -2824,29 +2825,6 @@ TEST_F(UtestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test) | |||
model_parser.UpdateEdgesControlInfo(info); | |||
} | |||
TEST_F(UtestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) | |||
{ | |||
TensorFlowModelParser model_parser; | |||
NodeDef *node_def = new NodeDef(); | |||
node_def->set_name("Placeholder"); | |||
node_def->set_op("Placeholder_0"); | |||
std::map<string, NodeDef *> nodedef_map; | |||
nodedef_map.emplace("Placeholder", node_def); | |||
std::string curr_node_name = "Placeholder"; | |||
bool clear_input_flag = true; | |||
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||
GraphDef graph; | |||
curr_node_name = "pre_node_a"; | |||
nodedef_map.emplace("pre_node_a", node_def); | |||
node_def->set_op("pre_node_a"); | |||
GenOriginContext(&model_parser, curr_node_name); | |||
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||
EXPECT_EQ(ret, SUCCESS); | |||
delete node_def; | |||
} | |||
TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test) | |||
{ | |||
TensorFlowModelParser model_parser; | |||
@@ -3007,27 +2985,18 @@ TEST_F(UtestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test) | |||
removed_inputs_vec.emplace_back("Add0"); | |||
Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(UtestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||
{ | |||
tensorflow::GraphDef graph_def; | |||
TensorFlowModelParser tensorflow_parser; | |||
tensorflow::NodeDef *node_def = initNodeDef(); | |||
node_def->set_name("post_node_d"); | |||
tensorflow::NodeDef *node_swith = initNodeDef(); | |||
node_swith->set_name("switch_op"); | |||
node_swith->set_op(parser::SWITCH); | |||
all_node_map.emplace("switch_op", node_swith); | |||
removed_inputs_vec.clear(); | |||
removed_inputs_vec.emplace_back("switch_op"); | |||
ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_swith, all_node_map, removed_inputs_vec); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
std::map<string, NodeDef *> nodedef_map; | |||
nodedef_map.emplace("post_node_d", node_def); | |||
nodedef_map.emplace("post_node_a", node_def); | |||
nodedef_map.emplace("post_node_b", node_def); | |||
std::vector<NodeDef *> nodedef_to_optimize; | |||
nodedef_to_optimize.emplace_back(node_def); | |||
std::string curr_node_name = "post_node_b"; | |||
GenOriginContext(&tensorflow_parser, curr_node_name); | |||
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | |||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||
} | |||
TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | |||
std::string caseDir = __FILE__; | |||
std::size_t idx = caseDir.find_last_of("/"); | |||
@@ -3696,7 +3665,8 @@ TEST_F(UtestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test) | |||
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); | |||
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); | |||
mess2Op.ParseField(reflection, node_def, field, depth, ops); | |||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str); | |||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1); | |||
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5); | |||
delete field; | |||
} | |||