@@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param"; | |||||
const char *const kFieldDim = "dim"; | const char *const kFieldDim = "dim"; | ||||
const char *const kFieldBiasTerm = "bias_term"; | const char *const kFieldBiasTerm = "bias_term"; | ||||
const char *const kDevNull = "/dev/null"; | const char *const kDevNull = "/dev/null"; | ||||
const std::string kMessage = "message"; | |||||
const std::string kLayerParameter = "LayerParameter"; | |||||
const std::string kCloseBrace = "}"; | |||||
const std::string kOptional = "optional"; | |||||
const std::string kRepeated = "repeated"; | |||||
const std::string kRequired = "required"; | |||||
const std::string kCustom = "custom"; | |||||
const std::string kBuiltin = "built-in"; | |||||
const char *const kCustom = "custom"; | |||||
const char *const kBuiltin = "built-in"; | |||||
std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, | std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, | ||||
ge::parser::NETOUTPUT}; | ge::parser::NETOUTPUT}; | ||||
const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; | const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; | ||||
@@ -284,104 +278,104 @@ const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW | |||||
"Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; | "Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; | ||||
Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { | Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { | ||||
if (proto_message.input_size() > 0) { | |||||
GELOGI("This net exsit input."); | |||||
if (proto_message.input_size() <= 0) { | |||||
return SUCCESS; | |||||
} | |||||
GELOGI("This net exsit input."); | |||||
if (proto_message.input_dim_size() > 0) { | |||||
if (proto_message.input_shape_size() > 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11001"); | |||||
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | |||||
return FAILED; | |||||
} | |||||
if (proto_message.input_dim_size() > 0) { | |||||
if (proto_message.input_shape_size() > 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11001"); | |||||
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | |||||
return FAILED; | |||||
} | |||||
const int32_t input_dim_size = proto_message.input_dim_size(); | |||||
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||||
((input_dim_size % proto_message.input_size()) != 0)); | |||||
if (is_input_invalid) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||||
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||||
input_dim_size, proto_message.input_size()); | |||||
return FAILED; | |||||
} | |||||
const int32_t input_dim_size = proto_message.input_dim_size(); | |||||
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||||
((input_dim_size % proto_message.input_size()) != 0)); | |||||
if (is_input_invalid) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||||
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||||
input_dim_size, proto_message.input_size()); | |||||
return FAILED; | |||||
} | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | |||||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
GE_CHECK_NOTNULL(layer); | |||||
layer->set_name(proto_message.input(i)); | |||||
layer->set_type(ge::parser::INPUT_TYPE); | |||||
layer->add_top(proto_message.input(i)); | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | |||||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
GE_CHECK_NOTNULL(layer); | |||||
layer->set_name(proto_message.input(i)); | |||||
layer->set_type(ge::parser::INPUT_TYPE); | |||||
layer->add_top(proto_message.input(i)); | |||||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||||
GE_CHECK_NOTNULL(input_param); | |||||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||||
GE_CHECK_NOTNULL(shape); | |||||
for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { | |||||
// Can guarantee that it will not cross the border | |||||
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); | |||||
} | |||||
input_data_flag = true; | |||||
} | |||||
} else if (proto_message.input_shape_size() > 0) { | |||||
if (proto_message.input_shape_size() != proto_message.input_size()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||||
{std::to_string(proto_message.input_shape_size()), | |||||
std::to_string(proto_message.input_size())}); | |||||
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||||
proto_message.input_shape_size(), proto_message.input_size()); | |||||
return FAILED; | |||||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||||
GE_CHECK_NOTNULL(input_param); | |||||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||||
GE_CHECK_NOTNULL(shape); | |||||
for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { | |||||
// Can guarantee that it will not cross the border | |||||
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); | |||||
} | } | ||||
input_data_flag = true; | |||||
} | |||||
} else if (proto_message.input_shape_size() > 0) { | |||||
if (proto_message.input_shape_size() != proto_message.input_size()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||||
{std::to_string(proto_message.input_shape_size()), | |||||
std::to_string(proto_message.input_size())}); | |||||
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||||
proto_message.input_shape_size(), proto_message.input_size()); | |||||
return FAILED; | |||||
} | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | |||||
int dim_size = proto_message.input_shape(i).dim_size(); | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | |||||
int dim_size = proto_message.input_shape(i).dim_size(); | |||||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
GE_CHECK_NOTNULL(layer); | |||||
layer->set_name(proto_message.input(i)); | |||||
layer->set_type(ge::parser::INPUT_TYPE); | |||||
layer->add_top(proto_message.input(i)); | |||||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
GE_CHECK_NOTNULL(layer); | |||||
layer->set_name(proto_message.input(i)); | |||||
layer->set_type(ge::parser::INPUT_TYPE); | |||||
layer->add_top(proto_message.input(i)); | |||||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||||
GE_CHECK_NOTNULL(input_param); | |||||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||||
GE_CHECK_NOTNULL(shape); | |||||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||||
GE_CHECK_NOTNULL(input_param); | |||||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||||
GE_CHECK_NOTNULL(shape); | |||||
for (int j = 0; j < dim_size; j++) { | |||||
// Can guarantee that it will not cross the border | |||||
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); | |||||
} | |||||
input_data_flag = true; | |||||
for (int j = 0; j < dim_size; j++) { | |||||
// Can guarantee that it will not cross the border | |||||
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); | |||||
} | } | ||||
} else { | |||||
const ge::ParserContext &ctx = ge::GetParserContext(); | |||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | |||||
string name = proto_message.input(i); | |||||
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input | |||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); | |||||
GELOGE(FAILED, "[Find][Dim]Model has no input shape."); | |||||
return FAILED; | |||||
} | |||||
std::vector<int64_t> dims = input_dims.at(name); | |||||
size_t dim_size = dims.size(); | |||||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
GE_CHECK_NOTNULL(layer); | |||||
layer->set_name(name); | |||||
layer->set_type(ge::parser::INPUT_TYPE); | |||||
layer->add_top(proto_message.input(i)); | |||||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||||
GE_CHECK_NOTNULL(input_param); | |||||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||||
GE_CHECK_NOTNULL(shape); | |||||
for (size_t j = 0; j < dim_size; j++) { | |||||
shape->add_dim(dims.at(j)); | |||||
} | |||||
input_data_flag = true; | |||||
input_data_flag = true; | |||||
} | |||||
} else { | |||||
const ge::ParserContext &ctx = ge::GetParserContext(); | |||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||||
for (int i = 0; i < proto_message.input_size(); i++) { | |||||
string name = proto_message.input(i); | |||||
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input | |||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); | |||||
GELOGE(FAILED, "[Find][Dim]Model has no input shape."); | |||||
return FAILED; | |||||
} | |||||
std::vector<int64_t> dims = input_dims.at(name); | |||||
size_t dim_size = dims.size(); | |||||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
GE_CHECK_NOTNULL(layer); | |||||
layer->set_name(name); | |||||
layer->set_type(ge::parser::INPUT_TYPE); | |||||
layer->add_top(proto_message.input(i)); | |||||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||||
GE_CHECK_NOTNULL(input_param); | |||||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||||
GE_CHECK_NOTNULL(shape); | |||||
for (size_t j = 0; j < dim_size; j++) { | |||||
shape->add_dim(dims.at(j)); | |||||
} | } | ||||
input_data_flag = true; | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) { | |||||
if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) { | |||||
delete message; | delete message; | ||||
GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); | GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); | ||||
return FAILED; | return FAILED; | ||||
@@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google:: | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||||
const google::protobuf::Message *message, | |||||
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||||
const google::protobuf::Message &message, | |||||
vector<ge::Operator> &operators) const { | vector<ge::Operator> &operators) const { | ||||
auto field_name = layer_descriptor->FindFieldByName(kFieldName); | |||||
auto field_name = layer_descriptor.FindFieldByName(kFieldName); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | ||||
auto field_type = layer_descriptor->FindFieldByName(kFieldType); | |||||
auto field_type = layer_descriptor.FindFieldByName(kFieldType); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | ||||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
reflection->ListFields(*message, &field_desc); | |||||
reflection->ListFields(message, &field_desc); | |||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); | ||||
// Only care about layers | // Only care about layers | ||||
@@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
int field_size = reflection->FieldSize(*message, field); | |||||
int field_size = reflection->FieldSize(message, field); | |||||
GELOGI("Total Layer num of model file is %d", field_size); | GELOGI("Total Layer num of model file is %d", field_size); | ||||
for (int i = 0; i < field_size; ++i) { | for (int i = 0; i < field_size; ++i) { | ||||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); | |||||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); | |||||
const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); | const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); | ||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
GE_CHECK_NOTNULL(layer_reflection); | GE_CHECK_NOTNULL(layer_reflection); | ||||
@@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||||
layer_name_map[layer.name()]++; | layer_name_map[layer.name()]++; | ||||
// Set the name in proto and layer | // Set the name in proto and layer | ||||
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | ||||
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) | |||||
duplicate_name_layer->set_name(new_name); | |||||
layer.set_name(new_name);) | |||||
// Insert the new operator name, the number of times of duplicate name is recorded as 1 | // Insert the new operator name, the number of times of duplicate name is recorded as 1 | ||||
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | ||||
@@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
layer_name_map[layer.name()]++; | layer_name_map[layer.name()]++; | ||||
// Set the name in proto and layer | // Set the name in proto and layer | ||||
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | ||||
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) | |||||
duplicate_name_layer->set_name(new_name); | |||||
layer.set_name(new_name);) | |||||
// Insert the new operator name, the number of times of duplicate name is recorded as 1 | // Insert the new operator name, the number of times of duplicate name is recorded as 1 | ||||
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | ||||
@@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (CheckLayersSize(message) != SUCCESS) { | |||||
if (CheckLayersSize(*message) != SUCCESS) { | |||||
delete message; | delete message; | ||||
message = nullptr; | message = nullptr; | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) { | |||||
if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) { | |||||
delete message; | delete message; | ||||
message = nullptr; | message = nullptr; | ||||
REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); | REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); | ||||
@@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||||
const google::protobuf::Message *message, | |||||
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||||
const google::protobuf::Message &message, | |||||
ge::ComputeGraphPtr &graph) { | ge::ComputeGraphPtr &graph) { | ||||
auto field_name = layer_descriptor->FindFieldByName(kFieldName); | |||||
auto field_name = layer_descriptor.FindFieldByName(kFieldName); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | ||||
auto field_type = layer_descriptor->FindFieldByName(kFieldType); | |||||
auto field_type = layer_descriptor.FindFieldByName(kFieldType); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | ||||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
reflection->ListFields(*message, &field_desc); | |||||
reflection->ListFields(message, &field_desc); | |||||
NetParameter tmp_net; | NetParameter tmp_net; | ||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
@@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
int field_size = reflection->FieldSize(*message, field); | |||||
int field_size = reflection->FieldSize(message, field); | |||||
GELOGI("Total Layer num of model file is %d", field_size); | GELOGI("Total Layer num of model file is %d", field_size); | ||||
for (int i = 0; i < field_size; ++i) { | for (int i = 0; i < field_size; ++i) { | ||||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); | |||||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); | |||||
LayerParameter *layer = tmp_net.add_layer(); | LayerParameter *layer = tmp_net.add_layer(); | ||||
if (ConvertLayerProto(&layer_message, layer) != SUCCESS) { | |||||
if (ConvertLayerProto(layer_message, layer) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); | GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message, | |||||
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *layer) { | google::protobuf::Message *layer) { | ||||
const google::protobuf::Reflection *layer_reflection = message->GetReflection(); | |||||
const google::protobuf::Reflection *layer_reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
layer_reflection->ListFields(*message, &field_desc); | |||||
layer_reflection->ListFields(message, &field_desc); | |||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
GE_CHECK_NOTNULL(field); | GE_CHECK_NOTNULL(field); | ||||
if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) { | |||||
if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); | GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, | |||||
const google::protobuf::Message *message, | |||||
const google::protobuf::FieldDescriptor *field, | |||||
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection, | |||||
const google::protobuf::Message &message, | |||||
const google::protobuf::FieldDescriptor &field, | |||||
google::protobuf::Message *layer) const { | google::protobuf::Message *layer) const { | ||||
GELOGD("Start to parse field: %s.", field->name().c_str()); | |||||
GELOGD("Start to parse field: %s.", field.name().c_str()); | |||||
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | ||||
string filed_name = field->name(); | |||||
#define CASE_FIELD_NAME(kName, method) \ | |||||
string filed_name = field.name(); | |||||
#define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \ | |||||
if (filed_name == kField##kName) { \ | if (filed_name == kField##kName) { \ | ||||
string value = reflection->GetString(*message, field); \ | |||||
string value = reflection.GetString(inner_message, field_ptr); \ | |||||
GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ | GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ | ||||
layer_proto->set_##method(value); \ | layer_proto->set_##method(value); \ | ||||
return SUCCESS; \ | return SUCCESS; \ | ||||
} | } | ||||
CASE_FIELD_NAME(Name, name); | |||||
CASE_FIELD_NAME(Type, type); | |||||
CASE_FIELD_NAME(Name, name, message, &field); | |||||
CASE_FIELD_NAME(Type, type, message, &field); | |||||
#undef CASE_FIELD_NAME | #undef CASE_FIELD_NAME | ||||
#define CASE_FIELD_NAME_REPEATED(kName, method) \ | |||||
if (filed_name == kField##kName) { \ | |||||
int field_size = reflection->FieldSize(*message, field); \ | |||||
for (int i = 0; i < field_size; ++i) { \ | |||||
auto value = reflection->GetRepeatedString(*message, field, i); \ | |||||
layer_proto->add_##method(value); \ | |||||
} \ | |||||
return SUCCESS; \ | |||||
} | |||||
CASE_FIELD_NAME_REPEATED(Bottom, bottom); | |||||
CASE_FIELD_NAME_REPEATED(Top, top); | |||||
#define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \ | |||||
if (filed_name == kField##kName) { \ | |||||
int field_size = reflection.FieldSize(inner_message, field_ptr); \ | |||||
for (int i = 0; i < field_size; ++i) { \ | |||||
auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \ | |||||
layer_proto->add_##method(value); \ | |||||
} \ | |||||
return SUCCESS; \ | |||||
} | |||||
CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field); | |||||
CASE_FIELD_NAME_REPEATED(Top, top, message, &field); | |||||
#undef CASE_FIELD_NAME_REPEATED | #undef CASE_FIELD_NAME_REPEATED | ||||
if (filed_name == kFieldBlobs) { | if (filed_name == kFieldBlobs) { | ||||
int field_size = reflection->FieldSize(*message, field); | |||||
int field_size = reflection.FieldSize(message, &field); | |||||
for (int i = 0; i < field_size; ++i) { | for (int i = 0; i < field_size; ++i) { | ||||
domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); | domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); | ||||
const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); | |||||
if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); | |||||
const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i); | |||||
if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) { | |||||
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (filed_name == kFieldConvParam) { | if (filed_name == kFieldConvParam) { | ||||
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||||
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); | |||||
ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); | ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); | ||||
ConvertConvParamProto(&sub_message, conv_param); | |||||
ConvertConvParamProto(sub_message, conv_param); | |||||
} | } | ||||
if (filed_name == kFieldInnerPro) { | if (filed_name == kFieldInnerPro) { | ||||
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||||
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); | |||||
InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); | InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); | ||||
ConvertInnerProdcutProto(&sub_message, inner_product); | |||||
ConvertInnerProdcutProto(sub_message, inner_product); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, | |||||
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *blobs) const { | google::protobuf::Message *blobs) const { | ||||
const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); | |||||
const google::protobuf::Reflection *blobs_reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
blobs_reflection->ListFields(*message, &field_desc); | |||||
blobs_reflection->ListFields(message, &field_desc); | |||||
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); | domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); | ||||
for (auto &field : field_desc) { | for (auto &field : field_desc) { | ||||
GE_CHECK_NOTNULL(field); | GE_CHECK_NOTNULL(field); | ||||
string feild_name = field->name(); | string feild_name = field->name(); | ||||
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \ | |||||
if (feild_name == #kName) { \ | |||||
int field_size = blobs_reflection->FieldSize(*message, field); \ | |||||
for (int i = 0; i < field_size; ++i) { \ | |||||
valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \ | |||||
blobs_proto->add_##name(value); \ | |||||
} \ | |||||
continue; \ | |||||
} | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data); | |||||
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \ | |||||
if (feild_name == #kName) { \ | |||||
int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \ | |||||
for (int i = 0; i < field_size; ++i) { \ | |||||
valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \ | |||||
blobs_proto->add_##name(value); \ | |||||
} \ | |||||
continue; \ | |||||
} | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field); | |||||
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field); | |||||
#undef CASE_BLOBS_FIELD_NAME_REPEATED | #undef CASE_BLOBS_FIELD_NAME_REPEATED | ||||
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \ | |||||
if (feild_name == #kName) { \ | |||||
valuetype value = blobs_reflection->Get##method(*message, field); \ | |||||
blobs_proto->set_##name(value); \ | |||||
continue; \ | |||||
} | |||||
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data); | |||||
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num); | |||||
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels); | |||||
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height); | |||||
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width); | |||||
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \ | |||||
if (feild_name == #kName) { \ | |||||
valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \ | |||||
blobs_proto->set_##name(value); \ | |||||
continue; \ | |||||
} | |||||
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field); | |||||
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field); | |||||
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field); | |||||
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field); | |||||
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field); | |||||
#undef CASE_BLOBS_FIELD_NAME | #undef CASE_BLOBS_FIELD_NAME | ||||
if (feild_name == kFieldShape) { | if (feild_name == kFieldShape) { | ||||
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field); | |||||
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field); | |||||
domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); | domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); | ||||
ConvertBlobShapeProto(&sub_message, blob_shape); | |||||
ConvertBlobShapeProto(sub_message, blob_shape); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message, | |||||
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *dest_message) const { | google::protobuf::Message *dest_message) const { | ||||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
reflection->ListFields(*message, &field_desc); | |||||
reflection->ListFields(message, &field_desc); | |||||
domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); | domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); | ||||
@@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message | |||||
if (field->name() != kFieldDim) { | if (field->name() != kFieldDim) { | ||||
continue; | continue; | ||||
} | } | ||||
int field_size = reflection->FieldSize(*message, field); | |||||
int field_size = reflection->FieldSize(message, field); | |||||
for (int i = 0; i < field_size; ++i) { | for (int i = 0; i < field_size; ++i) { | ||||
int64_t value = reflection->GetRepeatedInt64(*message, field, i); | |||||
int64_t value = reflection->GetRepeatedInt64(message, field, i); | |||||
shape_proto->add_dim(value); | shape_proto->add_dim(value); | ||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message, | |||||
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *dest_message) const { | google::protobuf::Message *dest_message) const { | ||||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
reflection->ListFields(*message, &field_desc); | |||||
reflection->ListFields(message, &field_desc); | |||||
domi::caffe::ConvolutionParameter *conv_param_proto = | domi::caffe::ConvolutionParameter *conv_param_proto = | ||||
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); | PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); | ||||
@@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message | |||||
if (field->name() != kFieldBiasTerm) { | if (field->name() != kFieldBiasTerm) { | ||||
continue; | continue; | ||||
} | } | ||||
bool value = reflection->GetBool(*message, field); | |||||
bool value = reflection->GetBool(message, field); | |||||
conv_param_proto->set_bias_term(value); | conv_param_proto->set_bias_term(value); | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message, | |||||
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *dest_message) const { | google::protobuf::Message *dest_message) const { | ||||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
reflection->ListFields(*message, &field_desc); | |||||
reflection->ListFields(message, &field_desc); | |||||
domi::caffe::InnerProductParameter *inner_product_proto = | domi::caffe::InnerProductParameter *inner_product_proto = | ||||
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); | PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); | ||||
@@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess | |||||
if (field->name() != kFieldBiasTerm) { | if (field->name() != kFieldBiasTerm) { | ||||
continue; | continue; | ||||
} | } | ||||
bool value = reflection->GetBool(*message, field); | |||||
bool value = reflection->GetBool(message, field); | |||||
inner_product_proto->set_bias_term(value); | inner_product_proto->set_bias_term(value); | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const { | |||||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||||
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const { | |||||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
reflection->ListFields(*message, &field_desc); | |||||
reflection->ListFields(message, &field_desc); | |||||
int num_layer = 0; | int num_layer = 0; | ||||
int num_layers = 0; | int num_layers = 0; | ||||
@@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
int field_size = reflection->FieldSize(*message, field); | |||||
int field_size = reflection->FieldSize(message, field); | |||||
if (field->name() == kLayerName) { | if (field->name() == kLayerName) { | ||||
num_layer = field_size; | num_layer = field_size; | ||||
} else { | } else { | ||||
@@ -212,8 +212,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
* @return SUCCESS parse layer successfully | * @return SUCCESS parse layer successfully | ||||
* @return FAILED parse layer failed | * @return FAILED parse layer failed | ||||
*/ | */ | ||||
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||||
const google::protobuf::Message *message, std::vector<ge::Operator> &operators) const; | |||||
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||||
const google::protobuf::Message &message, std::vector<ge::Operator> &operators) const; | |||||
/* | /* | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -386,33 +386,33 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||||
Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, | Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, | ||||
const string &fusion_proto_name, ge::ComputeGraphPtr &graph); | const string &fusion_proto_name, ge::ComputeGraphPtr &graph); | ||||
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||||
const google::protobuf::Message *message, | |||||
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||||
const google::protobuf::Message &message, | |||||
ge::ComputeGraphPtr &graph); | ge::ComputeGraphPtr &graph); | ||||
Status ConvertLayerParameter(const google::protobuf::Message *layer_message, | Status ConvertLayerParameter(const google::protobuf::Message *layer_message, | ||||
ge::ComputeGraphPtr &graph); | ge::ComputeGraphPtr &graph); | ||||
Status CheckLayersSize(const google::protobuf::Message *message) const; | |||||
Status CheckLayersSize(const google::protobuf::Message &message) const; | |||||
Status ConvertLayerProto(const google::protobuf::Message *message, | |||||
Status ConvertLayerProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *layer); | google::protobuf::Message *layer); | ||||
Status ParseLayerField(const google::protobuf::Reflection *reflection, | |||||
const google::protobuf::Message *message, | |||||
const google::protobuf::FieldDescriptor *field, | |||||
Status ParseLayerField(const google::protobuf::Reflection &reflection, | |||||
const google::protobuf::Message &message, | |||||
const google::protobuf::FieldDescriptor &field, | |||||
google::protobuf::Message *layer) const; | google::protobuf::Message *layer) const; | ||||
Status ConvertBlobsProto(const google::protobuf::Message *message, | |||||
Status ConvertBlobsProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *blobs) const; | google::protobuf::Message *blobs) const; | ||||
Status ConvertBlobShapeProto(const google::protobuf::Message *message, | |||||
Status ConvertBlobShapeProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *dest_message) const; | google::protobuf::Message *dest_message) const; | ||||
Status ConvertInnerProdcutProto(const google::protobuf::Message *message, | |||||
Status ConvertInnerProdcutProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *dest_message) const; | google::protobuf::Message *dest_message) const; | ||||
Status ConvertConvParamProto(const google::protobuf::Message *message, | |||||
Status ConvertConvParamProto(const google::protobuf::Message &message, | |||||
google::protobuf::Message *dest_message) const; | google::protobuf::Message *dest_message) const; | ||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -69,16 +69,14 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||||
break; | break; | ||||
} | } | ||||
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ | #define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ | ||||
case dt_type: \ | |||||
{ \ | |||||
case dt_type: { \ | |||||
unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \ | unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \ | ||||
GE_CHECK_NOTNULL(addr_trans); \ | GE_CHECK_NOTNULL(addr_trans); \ | ||||
for (int32_t i = 0; i < (count); i++) { \ | for (int32_t i = 0; i < (count); i++) { \ | ||||
*(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \ | *(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \ | ||||
} \ | } \ | ||||
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \ | (tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \ | ||||
break; \ | |||||
} \ | |||||
break; } \ | |||||
CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) | CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) | ||||
CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) | CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) | ||||
@@ -89,7 +87,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||||
#undef CASE_SET_DATA | #undef CASE_SET_DATA | ||||
default: | default: | ||||
{ | { | ||||
tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(T)); | |||||
tensor.SetData(PtrToPtr<T, uint8_t>(addr.get()), count * sizeof(T)); | |||||
break; | break; | ||||
} | } | ||||
} | } | ||||
@@ -31,11 +31,11 @@ using GeTensorDesc = ge::GeTensorDesc; | |||||
using namespace ge::parser; | using namespace ge::parser; | ||||
namespace { | namespace { | ||||
const std::string kAttrShape = "shape"; | |||||
const std::string kAttrDataType = "dtype"; | |||||
const std::string kFileConstantPath = "file_constant_path"; | |||||
const std::string kLocation = "location"; | |||||
const std::string kOffset = "offset"; | |||||
const char *const kAttrShape = "shape"; | |||||
const char *const kAttrDataType = "dtype"; | |||||
const char *const kFileConstantPath = "file_constant_path"; | |||||
const char *const kLocation = "location"; | |||||
const char *const kOffset = "offset"; | |||||
const int64_t kOffsetCoefficient = 4096; | const int64_t kOffsetCoefficient = 4096; | ||||
const char *const kFileConstant = "FileConstant"; | const char *const kFileConstant = "FileConstant"; | ||||
} | } | ||||
@@ -46,7 +46,7 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & | |||||
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | ||||
ge::onnx::TensorProto tensor_proto; | ge::onnx::TensorProto tensor_proto; | ||||
if (GetTensorProto(node, tensor_proto) != SUCCESS) { | |||||
if (GetTensorProto(*node, tensor_proto) != SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); | REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); | ||||
GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); | GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); | ||||
return FAILED; | return FAILED; | ||||
@@ -65,29 +65,29 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto, | |||||
ge::onnx::TensorProto &tensor_proto) { | |||||
for (const auto &it : node_proto->attribute()) { | |||||
Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto, | |||||
ge::onnx::TensorProto &tensor_proto) const { | |||||
for (const auto &it : node_proto.attribute()) { | |||||
if (it.name() != ge::kAttrNameValue) { | if (it.name() != ge::kAttrNameValue) { | ||||
continue; | continue; | ||||
} | } | ||||
tensor_proto = it.t(); | tensor_proto = it.t(); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str()); | |||||
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str()); | |||||
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||||
std::vector<int64_t> tmp_shape; | std::vector<int64_t> tmp_shape; | ||||
for (int i = 0; i < tensor_proto.dims_size(); i++) { | for (int i = 0; i < tensor_proto.dims_size(); i++) { | ||||
tmp_shape.push_back(tensor_proto.dims(i)); | tmp_shape.push_back(tensor_proto.dims(i)); | ||||
} | } | ||||
op_def.SetAttr(kAttrShape.c_str(), tmp_shape); | |||||
op_def.SetAttr(kAttrShape, tmp_shape); | |||||
} | } | ||||
Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||||
int64_t data_type = tensor_proto.data_type(); | int64_t data_type = tensor_proto.data_type(); | ||||
ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); | ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); | ||||
if (type >= ge::DataType::DT_UNDEFINED) { | if (type >= ge::DataType::DT_UNDEFINED) { | ||||
@@ -96,11 +96,11 @@ Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
op_def.SetAttr(kAttrDataType.c_str(), type); | |||||
op_def.SetAttr(kAttrDataType, type); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||||
ge::NamedAttrs attrs; | ge::NamedAttrs attrs; | ||||
for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | ||||
const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | ||||
@@ -116,12 +116,12 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro | |||||
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
op_def.SetAttr(kFileConstantPath.c_str(), attrs); | |||||
op_def.SetAttr(kFileConstantPath, attrs); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, | Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, | ||||
ge::NamedAttrs &attrs) { | |||||
ge::NamedAttrs &attrs) const { | |||||
if (string_proto.key() == kLocation) { | if (string_proto.key() == kLocation) { | ||||
AttrUtils::SetStr(attrs, kLocation, string_proto.value()); | AttrUtils::SetStr(attrs, kLocation, string_proto.value()); | ||||
} else { | } else { | ||||
@@ -134,7 +134,7 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (string_proto.key() == kOffset) { | if (string_proto.key() == kOffset) { | ||||
if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) { | |||||
if (value > (std::numeric_limits<int64_t>::max() / kOffsetCoefficient)) { | |||||
REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | ||||
GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | ||||
return FAILED; | return FAILED; | ||||
@@ -26,11 +26,11 @@ class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { | |||||
Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | ||||
private: | private: | ||||
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto); | |||||
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs); | |||||
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||||
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||||
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||||
Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const; | |||||
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -232,7 +232,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra | |||||
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); | domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); | ||||
if (post_func == nullptr) { | if (post_func == nullptr) { | ||||
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); | GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); | ||||
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { | |||||
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || | |||||
parse_func_v2 == nullptr) { | |||||
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); | GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -522,9 +523,9 @@ Status OnnxModelParser::SetOperatorInputs() { | |||||
auto src_op = output_op_iter->second; | auto src_op = output_op_iter->second; | ||||
int dst_index = input_node_index.second; | int dst_index = input_node_index.second; | ||||
int src_index = out_node_index.second; | int src_index = out_node_index.second; | ||||
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index, | |||||
ParserUtils::GetOperatorName(src_op).c_str(), dst_index, | |||||
ParserUtils::GetOperatorName(dst_op).c_str()); | |||||
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", | |||||
src_index, ParserUtils::GetOperatorName(src_op).c_str(), | |||||
dst_index, ParserUtils::GetOperatorName(dst_op).c_str()); | |||||
auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); | auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); | ||||
GE_CHECK_NOTNULL(dst_op_desc); | GE_CHECK_NOTNULL(dst_op_desc); | ||||
auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); | auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); | ||||
@@ -689,7 +690,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
input_ops.emplace_back(in_op->second); | input_ops.emplace_back(in_op->second); | ||||
GELOGI("Model assigned input node name: %s", ParserUtils::GetOperatorName(in_op->second).c_str()); | |||||
GELOGI("Model assigned input node name: %s", | |||||
ParserUtils::GetOperatorName(in_op->second).c_str()); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -717,7 +719,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec | |||||
int index = node_name_index.second; | int index = node_name_index.second; | ||||
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); | output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); | ||||
out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); | out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); | ||||
GELOGI("out node index %d, node:%s", index, node_name.c_str()); | |||||
GELOGI("Out node index %d, node:%s", index, node_name.c_str()); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -934,16 +936,13 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
cur_compute_graph->GetName().c_str()); | cur_compute_graph->GetName().c_str()); | ||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
UpdateDataFormat(root_graph); | UpdateDataFormat(root_graph); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | ||||
ClearMembers(); | ClearMembers(); | ||||
GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), | GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), | ||||
"Run ProtoType Pass Failed"); | "Run ProtoType Pass Failed"); | ||||
// 1. Get all inializer. | // 1. Get all inializer. | ||||
@@ -1174,7 +1173,8 @@ Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_o | |||||
default_out_nodes.emplace_back(output_node_info); | default_out_nodes.emplace_back(output_node_info); | ||||
output_tensor_names.emplace_back(tensor_name); | output_tensor_names.emplace_back(tensor_name); | ||||
GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", | GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", | ||||
output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str()); | |||||
output_node_info.first.c_str(), | |||||
output_node_info.second, tensor_name.c_str()); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -17,6 +17,8 @@ | |||||
#ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ | #ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ | ||||
#define PARSER_ONNX_ONNX_UTIL_PARSER_H_ | #define PARSER_ONNX_ONNX_UTIL_PARSER_H_ | ||||
#include <string> | |||||
#include <cstdint> | |||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
namespace OnnxDataType { | namespace OnnxDataType { | ||||
@@ -59,4 +61,4 @@ class OnnxUtil { | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||||
#endif // PARSER_ONNX_ONNX_UTIL_PARSER_H_ |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -25,6 +25,7 @@ using parser::IF; | |||||
namespace { | namespace { | ||||
const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; | const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; | ||||
const int kIfNodeAttrSize = 2; | const int kIfNodeAttrSize = 2; | ||||
const char *kIf = "If"; | |||||
} // namespace | } // namespace | ||||
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | ||||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
@@ -33,7 +34,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||||
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | ||||
parent_node->op_type().c_str()); | parent_node->op_type().c_str()); | ||||
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||||
auto ret = ParseIfNodeSubgraphs(*parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "[Parse][Node] Parse if node failed."); | GELOGE(ret, "[Parse][Node] Parse if node failed."); | ||||
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | ||||
@@ -44,19 +45,19 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||||
} | } | ||||
domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | ||||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const { | std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const { | ||||
if (parent_node->attribute_size() != kIfNodeAttrSize) { | |||||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||||
if (parent_node.attribute_size() != kIfNodeAttrSize) { | |||||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); | |||||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GELOGD("node attribute size:%d.", parent_node->attribute_size()); | |||||
GELOGD("node attribute size:%d.", parent_node.attribute_size()); | |||||
std::set<std::string> all_inputs; | std::set<std::string> all_inputs; | ||||
// for onnx graph, the first attribute may be else branch and the second attribute may be then branch | // for onnx graph, the first attribute may be else branch and the second attribute may be then branch | ||||
for (int i = 0; i < parent_node->attribute_size(); i++) { | |||||
ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); | |||||
for (int i = 0; i < parent_node.attribute_size(); i++) { | |||||
ge::onnx::AttributeProto *attribute = parent_node.mutable_attribute(i); | |||||
GE_CHECK_NOTNULL(attribute); | GE_CHECK_NOTNULL(attribute); | ||||
std::string attr_name = attribute->name(); | std::string attr_name = attribute->name(); | ||||
auto itr = kAttrNameToIndex.find(attr_name); | auto itr = kAttrNameToIndex.find(attr_name); | ||||
@@ -68,7 +69,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
std::string unique_subgraph_name; | std::string unique_subgraph_name; | ||||
std::string node_name = parent_node->name(); | |||||
std::string node_name = parent_node.name(); | |||||
if (!parent_graph_name.empty()) { | if (!parent_graph_name.empty()) { | ||||
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); | node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); | ||||
} | } | ||||
@@ -90,7 +91,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||||
AddInputNodeForGraph(all_inputs, *onnx_graph); | AddInputNodeForGraph(all_inputs, *onnx_graph); | ||||
} | } | ||||
AddInputForParentNode(all_inputs, *parent_node); | |||||
AddInputForParentNode(all_inputs, parent_node); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -135,5 +136,5 @@ void IfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_i | |||||
parent_node.add_input(input_name); | parent_node.add_input(input_name); | ||||
} | } | ||||
} | } | ||||
REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter); | |||||
REGISTER_SUBGRAPH_ADAPTER_CREATOR(kIf, IfSubgraphAdapter); | |||||
} // namespace ge | } // namespace ge |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -20,6 +20,7 @@ | |||||
#include <set> | #include <set> | ||||
#include <string> | #include <string> | ||||
#include "subgraph_adapter.h" | #include "subgraph_adapter.h" | ||||
#include "parser/onnx/onnx_util.h" | |||||
namespace ge { | namespace ge { | ||||
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | ||||
@@ -30,7 +31,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||||
const std::string &parent_graph_name = "") override; | const std::string &parent_graph_name = "") override; | ||||
private: | private: | ||||
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | ||||
const std::string &parent_graph_name) const; | const std::string &parent_graph_name) const; | ||||
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -36,8 +36,6 @@ | |||||
#include "proto/onnx/ge_onnx.pb.h" | #include "proto/onnx/ge_onnx.pb.h" | ||||
#include "external/register/register_error_codes.h" | #include "external/register/register_error_codes.h" | ||||
#include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
#include "parser/onnx/onnx_util.h" | |||||
namespace ge { | namespace ge { | ||||
class PARSER_FUNC_VISIBILITY SubgraphAdapter { | class PARSER_FUNC_VISIBILITY SubgraphAdapter { | ||||
public: | public: | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -62,7 +62,6 @@ protected: | |||||
* @brief SubgraphAdapter creation function | * @brief SubgraphAdapter creation function | ||||
* @return Created SubgraphAdapter | * @return Created SubgraphAdapter | ||||
*/ | */ | ||||
// typedef shared_ptr<SubgraphAdapter> (*CREATOR_FUN)(void); | |||||
using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>; | using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>; | ||||
/** | /** | ||||
@@ -105,7 +104,7 @@ public: | |||||
* @param [in] op_type Op type | * @param [in] op_type Op type | ||||
* @param [in] clazz SubgraphAdapter implementation class | * @param [in] clazz SubgraphAdapter implementation class | ||||
*/ | */ | ||||
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ | |||||
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ | |||||
std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \ | std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \ | ||||
std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \ | std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \ | ||||
if (ptr == nullptr) { \ | if (ptr == nullptr) { \ | ||||
@@ -167,6 +167,33 @@ class H2CC(object): | |||||
del self.stack_template | del self.stack_template | ||||
del self.func_list_exist | del self.func_list_exist | ||||
@staticmethod | |||||
def implement_function(func): | |||||
function_def = '' | |||||
function_def += '{\n' | |||||
all_items = func.split() | |||||
start = 0 | |||||
return_type = all_items[start] | |||||
if return_type == "const": | |||||
start += 1 | |||||
return_type = all_items[start] | |||||
if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||||
return_type = "std::map" | |||||
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||||
return_type = "Ptr" | |||||
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||||
return_type += "&" | |||||
if RETURN_STATEMENTS.__contains__(return_type): | |||||
function_def += RETURN_STATEMENTS[return_type] | |||||
else: | |||||
logging.warning("Unhandled return type[%s]", return_type) | |||||
function_def += '\n' | |||||
function_def += '}\n' | |||||
function_def += '\n' | |||||
return function_def | |||||
def just_skip(self): | def just_skip(self): | ||||
# skip blank line or comment | # skip blank line or comment | ||||
if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( | if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( | ||||
@@ -263,6 +290,7 @@ class H2CC(object): | |||||
logging.info('Added %s functions', len(self.func_list_exist)) | logging.info('Added %s functions', len(self.func_list_exist)) | ||||
logging.info('Successfully converted,please see ' + self.output_file) | logging.info('Successfully converted,please see ' + self.output_file) | ||||
def handle_func1(self, line): | def handle_func1(self, line): | ||||
""" | """ | ||||
:param line: | :param line: | ||||
@@ -461,12 +489,6 @@ class H2CC(object): | |||||
logging.info("func_name[%s]", func_name) | logging.info("func_name[%s]", func_name) | ||||
return line, func_name | return line, func_name | ||||
def write_func_content(self, content, func_name, need_generate): | |||||
if not (func_name in self.func_list_exist) and need_generate: | |||||
self.output_fd.write(content) | |||||
self.func_list_exist.append(func_name) | |||||
logging.info('add func:[%s]', func_name) | |||||
def gen_comment(self, start_i): | def gen_comment(self, start_i): | ||||
comment_line = '' | comment_line = '' | ||||
# Function comments are on top of function declarations, copy them over | # Function comments are on top of function declarations, copy them over | ||||
@@ -488,32 +510,11 @@ class H2CC(object): | |||||
break | break | ||||
return comment_line | return comment_line | ||||
@staticmethod | |||||
def implement_function(func): | |||||
function_def = '' | |||||
function_def += '{\n' | |||||
all_items = func.split() | |||||
start = 0 | |||||
return_type = all_items[start] | |||||
if return_type == "const": | |||||
start += 1 | |||||
return_type = all_items[start] | |||||
if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||||
return_type = "std::map" | |||||
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||||
return_type = "Ptr" | |||||
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||||
return_type += "&" | |||||
if RETURN_STATEMENTS.__contains__(return_type): | |||||
function_def += RETURN_STATEMENTS[return_type] | |||||
else: | |||||
logging.warning("Unhandled return type[%s]", return_type) | |||||
function_def += '\n' | |||||
function_def += '}\n' | |||||
function_def += '\n' | |||||
return function_def | |||||
def write_func_content(self, content, func_name, need_generate): | |||||
if not (func_name in self.func_list_exist) and need_generate: | |||||
self.output_fd.write(content) | |||||
self.func_list_exist.append(func_name) | |||||
logging.info('add func:[%s]', func_name) | |||||
def collect_header_files(path): | def collect_header_files(path): | ||||
@@ -765,7 +765,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) | |||||
layer->set_name("Abs"); | layer->set_name("Abs"); | ||||
layer->set_type("AbsVal"); | layer->set_type("AbsVal"); | ||||
Status ret = weightParser.CheckLayersSize(layer); | |||||
Status ret = weightParser.CheckLayersSize(*layer); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
} | } | ||||
@@ -809,8 +809,54 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) | |||||
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | ||||
const google::protobuf::Message *message = proto->New(); | const google::protobuf::Message *message = proto->New(); | ||||
Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); | |||||
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); | |||||
delete message; | delete message; | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_DimSize_0) | |||||
{ | |||||
CaffeModelParser modelParser; | |||||
domi::caffe::NetParameter net; | |||||
net.add_input("111"); | |||||
net.add_input_shape(); | |||||
bool input_data_flag = true; | |||||
Status ret = modelParser.ParseInput(net, input_data_flag); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_Err1) | |||||
{ | |||||
CaffeModelParser modelParser; | |||||
domi::caffe::NetParameter net; | |||||
net.add_input("111"); | |||||
net.add_input("222"); | |||||
net.add_input_shape(); | |||||
bool input_data_flag = true; | |||||
Status ret = modelParser.ParseInput(net, input_data_flag); | |||||
EXPECT_EQ(ret, FAILED); | |||||
} | |||||
TEST_F(STestCaffeParser, CaffeModelParser_ParserLayerParameter_Succ) | |||||
{ | |||||
CaffeModelParser modelParser; | |||||
std::string case_dir = __FILE__; | |||||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
std::string model_file = case_dir + "/origin_models/"; | |||||
const char *model_path = model_file.c_str(); | |||||
std::string custom_proto = model_file; | |||||
std::string caffe_proto = model_file; | |||||
std::vector<ge::Operator> operators; | |||||
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Data", "Input"); | |||||
ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
operators.emplace_back(op_src); | |||||
model_file = case_dir + "/origin_models/caffe_add.pbtxt"; | |||||
custom_proto = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; | |||||
model_path = model_file.c_str(); | |||||
std::string caffe_proto_path = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; | |||||
auto ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto_path, operators); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -739,6 +739,9 @@ TEST_F(UtestCaffeParser, CaffeModelParser_CustomProtoParse_test) | |||||
Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators); | Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators); | ||||
EXPECT_EQ(ret, PARAM_INVALID); | EXPECT_EQ(ret, PARAM_INVALID); | ||||
ret = modelParser.CustomProtoParse("", custom_proto, caffe_proto, operators); | |||||
EXPECT_EQ(ret, FAILED); | |||||
model_file = case_dir + "/caffe_model/caffe_add.pbtxt"; | model_file = case_dir + "/caffe_model/caffe_add.pbtxt"; | ||||
custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto"; | custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto"; | ||||
model_path = model_file.c_str(); | model_path = model_file.c_str(); | ||||
@@ -890,7 +893,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) | |||||
layer->set_name("Abs"); | layer->set_name("Abs"); | ||||
layer->set_type("AbsVal"); | layer->set_type("AbsVal"); | ||||
Status ret = weightParser.CheckLayersSize(layer); | |||||
Status ret = weightParser.CheckLayersSize(*layer); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
} | } | ||||
@@ -902,7 +905,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) | |||||
layer->set_name("Abs"); | layer->set_name("Abs"); | ||||
layer->set_type("AbsVal"); | layer->set_type("AbsVal"); | ||||
Status ret = weightParser.ConvertLayerProto(&net, &net); | |||||
Status ret = weightParser.ConvertLayerProto(net, &net); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
BlobProto* blob = layer->add_blobs(); | BlobProto* blob = layer->add_blobs(); | ||||
@@ -911,16 +914,16 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) | |||||
BlobShape* shap = blob->mutable_shape(); | BlobShape* shap = blob->mutable_shape(); | ||||
shap->add_dim(1); | shap->add_dim(1); | ||||
shap->add_dim(2); | shap->add_dim(2); | ||||
ret = weightParser.ConvertBlobsProto(&net, &net); | |||||
ret = weightParser.ConvertBlobsProto(net, &net); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
ret = weightParser.ConvertBlobShapeProto(&net, &net); | |||||
ret = weightParser.ConvertBlobShapeProto(net, &net); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
ret = weightParser.ConvertConvParamProto(&net, &net); | |||||
ret = weightParser.ConvertConvParamProto(net, &net); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
ret = weightParser.ConvertInnerProdcutProto(&net, &net); | |||||
ret = weightParser.ConvertInnerProdcutProto(net, &net); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
@@ -1133,7 +1136,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) | |||||
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | ||||
const google::protobuf::Message *message = proto->New(); | const google::protobuf::Message *message = proto->New(); | ||||
Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); | |||||
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); | |||||
delete message; | delete message; | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
@@ -1163,7 +1166,7 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test) | |||||
google::protobuf::DynamicMessageFactory factory; | google::protobuf::DynamicMessageFactory factory; | ||||
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | ||||
google::protobuf::Message *message = proto->New(); | google::protobuf::Message *message = proto->New(); | ||||
Status ret = modelParser.ParseLayerParameter(descriptor, message, operators); | |||||
Status ret = modelParser.ParseLayerParameter(*descriptor, *message, operators); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
const domi::FrameworkType fmk_type = domi::TENSORFLOW; | const domi::FrameworkType fmk_type = domi::TENSORFLOW; | ||||
@@ -381,7 +381,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||||
OnnxFileConstantParser parser; | OnnxFileConstantParser parser; | ||||
ge::onnx::NodeProto input_node; | ge::onnx::NodeProto input_node; | ||||
ge::onnx::TensorProto tensor_proto; | ge::onnx::TensorProto tensor_proto; | ||||
Status ret = parser.GetTensorProto(&input_node, tensor_proto); | |||||
Status ret = parser.GetTensorProto(input_node, tensor_proto); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | ||||
@@ -391,7 +391,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||||
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | ||||
*attribute_tensor = tensor_proto; | *attribute_tensor = tensor_proto; | ||||
ret = parser.GetTensorProto(&input_node, tensor_proto); | |||||
ret = parser.GetTensorProto(input_node, tensor_proto); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||