@@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param"; | |||
const char *const kFieldDim = "dim"; | |||
const char *const kFieldBiasTerm = "bias_term"; | |||
const char *const kDevNull = "/dev/null"; | |||
const std::string kMessage = "message"; | |||
const std::string kLayerParameter = "LayerParameter"; | |||
const std::string kCloseBrace = "}"; | |||
const std::string kOptional = "optional"; | |||
const std::string kRepeated = "repeated"; | |||
const std::string kRequired = "required"; | |||
const std::string kCustom = "custom"; | |||
const std::string kBuiltin = "built-in"; | |||
const char *const kCustom = "custom"; | |||
const char *const kBuiltin = "built-in"; | |||
std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, | |||
ge::parser::NETOUTPUT}; | |||
const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; | |||
@@ -284,104 +278,104 @@ const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW | |||
"Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; | |||
Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { | |||
if (proto_message.input_size() > 0) { | |||
GELOGI("This net exsit input."); | |||
if (proto_message.input_size() <= 0) { | |||
return SUCCESS; | |||
} | |||
GELOGI("This net exsit input."); | |||
if (proto_message.input_dim_size() > 0) { | |||
if (proto_message.input_shape_size() > 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11001"); | |||
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | |||
return FAILED; | |||
} | |||
if (proto_message.input_dim_size() > 0) { | |||
if (proto_message.input_shape_size() > 0) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11001"); | |||
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); | |||
return FAILED; | |||
} | |||
const int32_t input_dim_size = proto_message.input_dim_size(); | |||
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||
((input_dim_size % proto_message.input_size()) != 0)); | |||
if (is_input_invalid) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||
input_dim_size, proto_message.input_size()); | |||
return FAILED; | |||
} | |||
const int32_t input_dim_size = proto_message.input_dim_size(); | |||
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || | |||
((input_dim_size % proto_message.input_size()) != 0)); | |||
if (is_input_invalid) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, | |||
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); | |||
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", | |||
input_dim_size, proto_message.input_size()); | |||
return FAILED; | |||
} | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
GE_CHECK_NOTNULL(layer); | |||
layer->set_name(proto_message.input(i)); | |||
layer->set_type(ge::parser::INPUT_TYPE); | |||
layer->add_top(proto_message.input(i)); | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
GE_CHECK_NOTNULL(layer); | |||
layer->set_name(proto_message.input(i)); | |||
layer->set_type(ge::parser::INPUT_TYPE); | |||
layer->add_top(proto_message.input(i)); | |||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
GE_CHECK_NOTNULL(input_param); | |||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
GE_CHECK_NOTNULL(shape); | |||
for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { | |||
// Can guarantee that it will not cross the border | |||
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); | |||
} | |||
input_data_flag = true; | |||
} | |||
} else if (proto_message.input_shape_size() > 0) { | |||
if (proto_message.input_shape_size() != proto_message.input_size()) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||
{std::to_string(proto_message.input_shape_size()), | |||
std::to_string(proto_message.input_size())}); | |||
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||
proto_message.input_shape_size(), proto_message.input_size()); | |||
return FAILED; | |||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
GE_CHECK_NOTNULL(input_param); | |||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
GE_CHECK_NOTNULL(shape); | |||
for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { | |||
// Can guarantee that it will not cross the border | |||
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); | |||
} | |||
input_data_flag = true; | |||
} | |||
} else if (proto_message.input_shape_size() > 0) { | |||
if (proto_message.input_shape_size() != proto_message.input_size()) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, | |||
{std::to_string(proto_message.input_shape_size()), | |||
std::to_string(proto_message.input_size())}); | |||
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", | |||
proto_message.input_shape_size(), proto_message.input_size()); | |||
return FAILED; | |||
} | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
int dim_size = proto_message.input_shape(i).dim_size(); | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
int dim_size = proto_message.input_shape(i).dim_size(); | |||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
GE_CHECK_NOTNULL(layer); | |||
layer->set_name(proto_message.input(i)); | |||
layer->set_type(ge::parser::INPUT_TYPE); | |||
layer->add_top(proto_message.input(i)); | |||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
GE_CHECK_NOTNULL(layer); | |||
layer->set_name(proto_message.input(i)); | |||
layer->set_type(ge::parser::INPUT_TYPE); | |||
layer->add_top(proto_message.input(i)); | |||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
GE_CHECK_NOTNULL(input_param); | |||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
GE_CHECK_NOTNULL(shape); | |||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
GE_CHECK_NOTNULL(input_param); | |||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
GE_CHECK_NOTNULL(shape); | |||
for (int j = 0; j < dim_size; j++) { | |||
// Can guarantee that it will not cross the border | |||
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); | |||
} | |||
input_data_flag = true; | |||
for (int j = 0; j < dim_size; j++) { | |||
// Can guarantee that it will not cross the border | |||
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); | |||
} | |||
} else { | |||
const ge::ParserContext &ctx = ge::GetParserContext(); | |||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
string name = proto_message.input(i); | |||
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input | |||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); | |||
GELOGE(FAILED, "[Find][Dim]Model has no input shape."); | |||
return FAILED; | |||
} | |||
std::vector<int64_t> dims = input_dims.at(name); | |||
size_t dim_size = dims.size(); | |||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
GE_CHECK_NOTNULL(layer); | |||
layer->set_name(name); | |||
layer->set_type(ge::parser::INPUT_TYPE); | |||
layer->add_top(proto_message.input(i)); | |||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
GE_CHECK_NOTNULL(input_param); | |||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
GE_CHECK_NOTNULL(shape); | |||
for (size_t j = 0; j < dim_size; j++) { | |||
shape->add_dim(dims.at(j)); | |||
} | |||
input_data_flag = true; | |||
input_data_flag = true; | |||
} | |||
} else { | |||
const ge::ParserContext &ctx = ge::GetParserContext(); | |||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||
for (int i = 0; i < proto_message.input_size(); i++) { | |||
string name = proto_message.input(i); | |||
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input | |||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); | |||
GELOGE(FAILED, "[Find][Dim]Model has no input shape."); | |||
return FAILED; | |||
} | |||
std::vector<int64_t> dims = input_dims.at(name); | |||
size_t dim_size = dims.size(); | |||
domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||
GE_CHECK_NOTNULL(layer); | |||
layer->set_name(name); | |||
layer->set_type(ge::parser::INPUT_TYPE); | |||
layer->add_top(proto_message.input(i)); | |||
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | |||
GE_CHECK_NOTNULL(input_param); | |||
domi::caffe::BlobShape *shape = input_param->add_shape(); | |||
GE_CHECK_NOTNULL(shape); | |||
for (size_t j = 0; j < dim_size; j++) { | |||
shape->add_dim(dims.at(j)); | |||
} | |||
input_data_flag = true; | |||
} | |||
} | |||
return SUCCESS; | |||
@@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons | |||
return FAILED; | |||
} | |||
if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) { | |||
if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) { | |||
delete message; | |||
GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); | |||
return FAILED; | |||
@@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google:: | |||
return SUCCESS; | |||
} | |||
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
const google::protobuf::Message *message, | |||
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
const google::protobuf::Message &message, | |||
vector<ge::Operator> &operators) const { | |||
auto field_name = layer_descriptor->FindFieldByName(kFieldName); | |||
auto field_name = layer_descriptor.FindFieldByName(kFieldName); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | |||
auto field_type = layer_descriptor->FindFieldByName(kFieldType); | |||
auto field_type = layer_descriptor.FindFieldByName(kFieldType); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
reflection->ListFields(message, &field_desc); | |||
for (auto &field : field_desc) { | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); | |||
// Only care about layers | |||
@@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor | |||
return FAILED; | |||
} | |||
int field_size = reflection->FieldSize(*message, field); | |||
int field_size = reflection->FieldSize(message, field); | |||
GELOGI("Total Layer num of model file is %d", field_size); | |||
for (int i = 0; i < field_size; ++i) { | |||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); | |||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); | |||
const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | |||
GE_CHECK_NOTNULL(layer_reflection); | |||
@@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||
layer_name_map[layer.name()]++; | |||
// Set the name in proto and layer | |||
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | |||
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) | |||
duplicate_name_layer->set_name(new_name); | |||
layer.set_name(new_name);) | |||
// Insert the new operator name, the number of times of duplicate name is recorded as 1 | |||
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | |||
@@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
layer_name_map[layer.name()]++; | |||
// Set the name in proto and layer | |||
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); | |||
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) | |||
duplicate_name_layer->set_name(new_name); | |||
layer.set_name(new_name);) | |||
// Insert the new operator name, the number of times of duplicate name is recorded as 1 | |||
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); | |||
@@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con | |||
return FAILED; | |||
} | |||
if (CheckLayersSize(message) != SUCCESS) { | |||
if (CheckLayersSize(*message) != SUCCESS) { | |||
delete message; | |||
message = nullptr; | |||
return FAILED; | |||
} | |||
if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) { | |||
if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) { | |||
delete message; | |||
message = nullptr; | |||
REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); | |||
@@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
const google::protobuf::Message *message, | |||
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
const google::protobuf::Message &message, | |||
ge::ComputeGraphPtr &graph) { | |||
auto field_name = layer_descriptor->FindFieldByName(kFieldName); | |||
auto field_name = layer_descriptor.FindFieldByName(kFieldName); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | |||
auto field_type = layer_descriptor->FindFieldByName(kFieldType); | |||
auto field_type = layer_descriptor.FindFieldByName(kFieldType); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
reflection->ListFields(message, &field_desc); | |||
NetParameter tmp_net; | |||
for (auto &field : field_desc) { | |||
@@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto | |||
return FAILED; | |||
} | |||
int field_size = reflection->FieldSize(*message, field); | |||
int field_size = reflection->FieldSize(message, field); | |||
GELOGI("Total Layer num of model file is %d", field_size); | |||
for (int i = 0; i < field_size; ++i) { | |||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); | |||
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); | |||
LayerParameter *layer = tmp_net.add_layer(); | |||
if (ConvertLayerProto(&layer_message, layer) != SUCCESS) { | |||
if (ConvertLayerProto(layer_message, layer) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); | |||
return FAILED; | |||
} | |||
@@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message, | |||
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *layer) { | |||
const google::protobuf::Reflection *layer_reflection = message->GetReflection(); | |||
const google::protobuf::Reflection *layer_reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
layer_reflection->ListFields(*message, &field_desc); | |||
layer_reflection->ListFields(message, &field_desc); | |||
for (auto &field : field_desc) { | |||
GE_CHECK_NOTNULL(field); | |||
if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) { | |||
if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
@@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, | |||
const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, | |||
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection, | |||
const google::protobuf::Message &message, | |||
const google::protobuf::FieldDescriptor &field, | |||
google::protobuf::Message *layer) const { | |||
GELOGD("Start to parse field: %s.", field->name().c_str()); | |||
GELOGD("Start to parse field: %s.", field.name().c_str()); | |||
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | |||
string filed_name = field->name(); | |||
#define CASE_FIELD_NAME(kName, method) \ | |||
string filed_name = field.name(); | |||
#define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \ | |||
if (filed_name == kField##kName) { \ | |||
string value = reflection->GetString(*message, field); \ | |||
string value = reflection.GetString(inner_message, field_ptr); \ | |||
GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ | |||
layer_proto->set_##method(value); \ | |||
return SUCCESS; \ | |||
} | |||
CASE_FIELD_NAME(Name, name); | |||
CASE_FIELD_NAME(Type, type); | |||
CASE_FIELD_NAME(Name, name, message, &field); | |||
CASE_FIELD_NAME(Type, type, message, &field); | |||
#undef CASE_FIELD_NAME | |||
#define CASE_FIELD_NAME_REPEATED(kName, method) \ | |||
if (filed_name == kField##kName) { \ | |||
int field_size = reflection->FieldSize(*message, field); \ | |||
for (int i = 0; i < field_size; ++i) { \ | |||
auto value = reflection->GetRepeatedString(*message, field, i); \ | |||
layer_proto->add_##method(value); \ | |||
} \ | |||
return SUCCESS; \ | |||
} | |||
CASE_FIELD_NAME_REPEATED(Bottom, bottom); | |||
CASE_FIELD_NAME_REPEATED(Top, top); | |||
#define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \ | |||
if (filed_name == kField##kName) { \ | |||
int field_size = reflection.FieldSize(inner_message, field_ptr); \ | |||
for (int i = 0; i < field_size; ++i) { \ | |||
auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \ | |||
layer_proto->add_##method(value); \ | |||
} \ | |||
return SUCCESS; \ | |||
} | |||
CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field); | |||
CASE_FIELD_NAME_REPEATED(Top, top, message, &field); | |||
#undef CASE_FIELD_NAME_REPEATED | |||
if (filed_name == kFieldBlobs) { | |||
int field_size = reflection->FieldSize(*message, field); | |||
int field_size = reflection.FieldSize(message, &field); | |||
for (int i = 0; i < field_size; ++i) { | |||
domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); | |||
const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); | |||
if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); | |||
const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i); | |||
if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
if (filed_name == kFieldConvParam) { | |||
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); | |||
ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); | |||
ConvertConvParamProto(&sub_message, conv_param); | |||
ConvertConvParamProto(sub_message, conv_param); | |||
} | |||
if (filed_name == kFieldInnerPro) { | |||
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); | |||
InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); | |||
ConvertInnerProdcutProto(&sub_message, inner_product); | |||
ConvertInnerProdcutProto(sub_message, inner_product); | |||
} | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, | |||
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *blobs) const { | |||
const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); | |||
const google::protobuf::Reflection *blobs_reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
blobs_reflection->ListFields(*message, &field_desc); | |||
blobs_reflection->ListFields(message, &field_desc); | |||
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); | |||
for (auto &field : field_desc) { | |||
GE_CHECK_NOTNULL(field); | |||
string feild_name = field->name(); | |||
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \ | |||
if (feild_name == #kName) { \ | |||
int field_size = blobs_reflection->FieldSize(*message, field); \ | |||
for (int i = 0; i < field_size; ++i) { \ | |||
valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \ | |||
blobs_proto->add_##name(value); \ | |||
} \ | |||
continue; \ | |||
} | |||
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data); | |||
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \ | |||
if (feild_name == #kName) { \ | |||
int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \ | |||
for (int i = 0; i < field_size; ++i) { \ | |||
valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \ | |||
blobs_proto->add_##name(value); \ | |||
} \ | |||
continue; \ | |||
} | |||
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field); | |||
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field); | |||
#undef CASE_BLOBS_FIELD_NAME_REPEATED | |||
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \ | |||
if (feild_name == #kName) { \ | |||
valuetype value = blobs_reflection->Get##method(*message, field); \ | |||
blobs_proto->set_##name(value); \ | |||
continue; \ | |||
} | |||
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data); | |||
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num); | |||
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels); | |||
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height); | |||
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width); | |||
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \ | |||
if (feild_name == #kName) { \ | |||
valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \ | |||
blobs_proto->set_##name(value); \ | |||
continue; \ | |||
} | |||
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field); | |||
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field); | |||
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field); | |||
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field); | |||
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field); | |||
#undef CASE_BLOBS_FIELD_NAME | |||
if (feild_name == kFieldShape) { | |||
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field); | |||
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field); | |||
domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); | |||
ConvertBlobShapeProto(&sub_message, blob_shape); | |||
ConvertBlobShapeProto(sub_message, blob_shape); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message, | |||
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *dest_message) const { | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
reflection->ListFields(message, &field_desc); | |||
domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); | |||
@@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message | |||
if (field->name() != kFieldDim) { | |||
continue; | |||
} | |||
int field_size = reflection->FieldSize(*message, field); | |||
int field_size = reflection->FieldSize(message, field); | |||
for (int i = 0; i < field_size; ++i) { | |||
int64_t value = reflection->GetRepeatedInt64(*message, field, i); | |||
int64_t value = reflection->GetRepeatedInt64(message, field, i); | |||
shape_proto->add_dim(value); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message, | |||
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *dest_message) const { | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
reflection->ListFields(message, &field_desc); | |||
domi::caffe::ConvolutionParameter *conv_param_proto = | |||
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); | |||
@@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message | |||
if (field->name() != kFieldBiasTerm) { | |||
continue; | |||
} | |||
bool value = reflection->GetBool(*message, field); | |||
bool value = reflection->GetBool(message, field); | |||
conv_param_proto->set_bias_term(value); | |||
} | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message, | |||
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *dest_message) const { | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
reflection->ListFields(message, &field_desc); | |||
domi::caffe::InnerProductParameter *inner_product_proto = | |||
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); | |||
@@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess | |||
if (field->name() != kFieldBiasTerm) { | |||
continue; | |||
} | |||
bool value = reflection->GetBool(*message, field); | |||
bool value = reflection->GetBool(message, field); | |||
inner_product_proto->set_bias_term(value); | |||
} | |||
return SUCCESS; | |||
} | |||
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const { | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const { | |||
const google::protobuf::Reflection *reflection = message.GetReflection(); | |||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
reflection->ListFields(message, &field_desc); | |||
int num_layer = 0; | |||
int num_layers = 0; | |||
@@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess | |||
return FAILED; | |||
} | |||
int field_size = reflection->FieldSize(*message, field); | |||
int field_size = reflection->FieldSize(message, field); | |||
if (field->name() == kLayerName) { | |||
num_layer = field_size; | |||
} else { | |||
@@ -212,8 +212,8 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||
* @return SUCCESS parse layer successfully | |||
* @return FAILED parse layer failed | |||
*/ | |||
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
const google::protobuf::Message *message, std::vector<ge::Operator> &operators) const; | |||
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
const google::protobuf::Message &message, std::vector<ge::Operator> &operators) const; | |||
/* | |||
* @ingroup domi_omg | |||
@@ -386,33 +386,33 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||
Status ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, | |||
const string &fusion_proto_name, ge::ComputeGraphPtr &graph); | |||
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||
const google::protobuf::Message *message, | |||
Status ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, | |||
const google::protobuf::Message &message, | |||
ge::ComputeGraphPtr &graph); | |||
Status ConvertLayerParameter(const google::protobuf::Message *layer_message, | |||
ge::ComputeGraphPtr &graph); | |||
Status CheckLayersSize(const google::protobuf::Message *message) const; | |||
Status CheckLayersSize(const google::protobuf::Message &message) const; | |||
Status ConvertLayerProto(const google::protobuf::Message *message, | |||
Status ConvertLayerProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *layer); | |||
Status ParseLayerField(const google::protobuf::Reflection *reflection, | |||
const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, | |||
Status ParseLayerField(const google::protobuf::Reflection &reflection, | |||
const google::protobuf::Message &message, | |||
const google::protobuf::FieldDescriptor &field, | |||
google::protobuf::Message *layer) const; | |||
Status ConvertBlobsProto(const google::protobuf::Message *message, | |||
Status ConvertBlobsProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *blobs) const; | |||
Status ConvertBlobShapeProto(const google::protobuf::Message *message, | |||
Status ConvertBlobShapeProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *dest_message) const; | |||
Status ConvertInnerProdcutProto(const google::protobuf::Message *message, | |||
Status ConvertInnerProdcutProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *dest_message) const; | |||
Status ConvertConvParamProto(const google::protobuf::Message *message, | |||
Status ConvertConvParamProto(const google::protobuf::Message &message, | |||
google::protobuf::Message *dest_message) const; | |||
/** | |||
* @ingroup domi_omg | |||
@@ -69,16 +69,14 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||
break; | |||
} | |||
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \ | |||
case dt_type: \ | |||
{ \ | |||
case dt_type: { \ | |||
unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \ | |||
GE_CHECK_NOTNULL(addr_trans); \ | |||
for (int32_t i = 0; i < (count); i++) { \ | |||
*(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \ | |||
} \ | |||
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \ | |||
break; \ | |||
} \ | |||
break; } \ | |||
CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) | |||
CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) | |||
@@ -89,7 +87,7 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||
#undef CASE_SET_DATA | |||
default: | |||
{ | |||
tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(T)); | |||
tensor.SetData(PtrToPtr<T, uint8_t>(addr.get()), count * sizeof(T)); | |||
break; | |||
} | |||
} | |||
@@ -31,11 +31,11 @@ using GeTensorDesc = ge::GeTensorDesc; | |||
using namespace ge::parser; | |||
namespace { | |||
const std::string kAttrShape = "shape"; | |||
const std::string kAttrDataType = "dtype"; | |||
const std::string kFileConstantPath = "file_constant_path"; | |||
const std::string kLocation = "location"; | |||
const std::string kOffset = "offset"; | |||
const char *const kAttrShape = "shape"; | |||
const char *const kAttrDataType = "dtype"; | |||
const char *const kFileConstantPath = "file_constant_path"; | |||
const char *const kLocation = "location"; | |||
const char *const kOffset = "offset"; | |||
const int64_t kOffsetCoefficient = 4096; | |||
const char *const kFileConstant = "FileConstant"; | |||
} | |||
@@ -46,7 +46,7 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & | |||
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | |||
ge::onnx::TensorProto tensor_proto; | |||
if (GetTensorProto(node, tensor_proto) != SUCCESS) { | |||
if (GetTensorProto(*node, tensor_proto) != SUCCESS) { | |||
REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); | |||
GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); | |||
return FAILED; | |||
@@ -65,29 +65,29 @@ Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator & | |||
return SUCCESS; | |||
} | |||
Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto, | |||
ge::onnx::TensorProto &tensor_proto) { | |||
for (const auto &it : node_proto->attribute()) { | |||
Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto, | |||
ge::onnx::TensorProto &tensor_proto) const { | |||
for (const auto &it : node_proto.attribute()) { | |||
if (it.name() != ge::kAttrNameValue) { | |||
continue; | |||
} | |||
tensor_proto = it.t(); | |||
return SUCCESS; | |||
} | |||
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str()); | |||
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str()); | |||
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str()); | |||
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str()); | |||
return FAILED; | |||
} | |||
void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||
void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||
std::vector<int64_t> tmp_shape; | |||
for (int i = 0; i < tensor_proto.dims_size(); i++) { | |||
tmp_shape.push_back(tensor_proto.dims(i)); | |||
} | |||
op_def.SetAttr(kAttrShape.c_str(), tmp_shape); | |||
op_def.SetAttr(kAttrShape, tmp_shape); | |||
} | |||
Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||
Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||
int64_t data_type = tensor_proto.data_type(); | |||
ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); | |||
if (type >= ge::DataType::DT_UNDEFINED) { | |||
@@ -96,11 +96,11 @@ Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor | |||
return FAILED; | |||
} | |||
op_def.SetAttr(kAttrDataType.c_str(), type); | |||
op_def.SetAttr(kAttrDataType, type); | |||
return SUCCESS; | |||
} | |||
Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||
Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||
ge::NamedAttrs attrs; | |||
for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | |||
const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | |||
@@ -116,12 +116,12 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro | |||
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||
return FAILED; | |||
} | |||
op_def.SetAttr(kFileConstantPath.c_str(), attrs); | |||
op_def.SetAttr(kFileConstantPath, attrs); | |||
return SUCCESS; | |||
} | |||
Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, | |||
ge::NamedAttrs &attrs) { | |||
ge::NamedAttrs &attrs) const { | |||
if (string_proto.key() == kLocation) { | |||
AttrUtils::SetStr(attrs, kLocation, string_proto.value()); | |||
} else { | |||
@@ -134,7 +134,7 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt | |||
return FAILED; | |||
} | |||
if (string_proto.key() == kOffset) { | |||
if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) { | |||
if (value > (std::numeric_limits<int64_t>::max() / kOffsetCoefficient)) { | |||
REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||
GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||
return FAILED; | |||
@@ -26,11 +26,11 @@ class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { | |||
Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | |||
private: | |||
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||
Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto); | |||
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs); | |||
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||
Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const; | |||
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const; | |||
}; | |||
} // namespace ge | |||
@@ -232,7 +232,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra | |||
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); | |||
if (post_func == nullptr) { | |||
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); | |||
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { | |||
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || | |||
parse_func_v2 == nullptr) { | |||
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -522,9 +523,9 @@ Status OnnxModelParser::SetOperatorInputs() { | |||
auto src_op = output_op_iter->second; | |||
int dst_index = input_node_index.second; | |||
int src_index = out_node_index.second; | |||
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index, | |||
ParserUtils::GetOperatorName(src_op).c_str(), dst_index, | |||
ParserUtils::GetOperatorName(dst_op).c_str()); | |||
GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", | |||
src_index, ParserUtils::GetOperatorName(src_op).c_str(), | |||
dst_index, ParserUtils::GetOperatorName(dst_op).c_str()); | |||
auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); | |||
GE_CHECK_NOTNULL(dst_op_desc); | |||
auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); | |||
@@ -689,7 +690,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||
return PARAM_INVALID; | |||
} | |||
input_ops.emplace_back(in_op->second); | |||
GELOGI("Model assigned input node name: %s", ParserUtils::GetOperatorName(in_op->second).c_str()); | |||
GELOGI("Model assigned input node name: %s", | |||
ParserUtils::GetOperatorName(in_op->second).c_str()); | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -717,7 +719,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec | |||
int index = node_name_index.second; | |||
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); | |||
out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); | |||
GELOGI("out node index %d, node:%s", index, node_name.c_str()); | |||
GELOGI("Out node index %d, node:%s", index, node_name.c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
@@ -934,16 +936,13 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||
cur_compute_graph->GetName().c_str()); | |||
return ret; | |||
} | |||
} | |||
UpdateDataFormat(root_graph); | |||
return SUCCESS; | |||
} | |||
Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | |||
ClearMembers(); | |||
GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), | |||
"Run ProtoType Pass Failed"); | |||
// 1. Get all inializer. | |||
@@ -1174,7 +1173,8 @@ Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_o | |||
default_out_nodes.emplace_back(output_node_info); | |||
output_tensor_names.emplace_back(tensor_name); | |||
GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", | |||
output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str()); | |||
output_node_info.first.c_str(), | |||
output_node_info.second, tensor_name.c_str()); | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -17,6 +17,8 @@ | |||
#ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||
#define PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||
#include <string> | |||
#include <cstdint> | |||
#include "external/graph/types.h" | |||
namespace OnnxDataType { | |||
@@ -59,4 +61,4 @@ class OnnxUtil { | |||
}; | |||
} // namespace ge | |||
#endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_ | |||
#endif // PARSER_ONNX_ONNX_UTIL_PARSER_H_ |
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -25,6 +25,7 @@ using parser::IF; | |||
namespace { | |||
const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; | |||
const int kIfNodeAttrSize = 2; | |||
const char *kIf = "If"; | |||
} // namespace | |||
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
@@ -33,7 +34,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | |||
parent_node->op_type().c_str()); | |||
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||
auto ret = ParseIfNodeSubgraphs(*parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "[Parse][Node] Parse if node failed."); | |||
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | |||
@@ -44,19 +45,19 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||
} | |||
domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const { | |||
if (parent_node->attribute_size() != kIfNodeAttrSize) { | |||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||
if (parent_node.attribute_size() != kIfNodeAttrSize) { | |||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); | |||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node.attribute_size()); | |||
return FAILED; | |||
} | |||
GELOGD("node attribute size:%d.", parent_node->attribute_size()); | |||
GELOGD("node attribute size:%d.", parent_node.attribute_size()); | |||
std::set<std::string> all_inputs; | |||
// for onnx graph, the first attribute may be else branch and the second attribute may be then branch | |||
for (int i = 0; i < parent_node->attribute_size(); i++) { | |||
ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); | |||
for (int i = 0; i < parent_node.attribute_size(); i++) { | |||
ge::onnx::AttributeProto *attribute = parent_node.mutable_attribute(i); | |||
GE_CHECK_NOTNULL(attribute); | |||
std::string attr_name = attribute->name(); | |||
auto itr = kAttrNameToIndex.find(attr_name); | |||
@@ -68,7 +69,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
return FAILED; | |||
} | |||
std::string unique_subgraph_name; | |||
std::string node_name = parent_node->name(); | |||
std::string node_name = parent_node.name(); | |||
if (!parent_graph_name.empty()) { | |||
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); | |||
} | |||
@@ -90,7 +91,7 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||
AddInputNodeForGraph(all_inputs, *onnx_graph); | |||
} | |||
AddInputForParentNode(all_inputs, *parent_node); | |||
AddInputForParentNode(all_inputs, parent_node); | |||
return SUCCESS; | |||
} | |||
@@ -135,5 +136,5 @@ void IfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_i | |||
parent_node.add_input(input_name); | |||
} | |||
} | |||
REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter); | |||
REGISTER_SUBGRAPH_ADAPTER_CREATOR(kIf, IfSubgraphAdapter); | |||
} // namespace ge |
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -20,6 +20,7 @@ | |||
#include <set> | |||
#include <string> | |||
#include "subgraph_adapter.h" | |||
#include "parser/onnx/onnx_util.h" | |||
namespace ge { | |||
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||
@@ -30,7 +31,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||
const std::string &parent_graph_name = "") override; | |||
private: | |||
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto &parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | |||
const std::string &parent_graph_name) const; | |||
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -36,8 +36,6 @@ | |||
#include "proto/onnx/ge_onnx.pb.h" | |||
#include "external/register/register_error_codes.h" | |||
#include "framework/omg/parser/parser_types.h" | |||
#include "parser/onnx/onnx_util.h" | |||
namespace ge { | |||
class PARSER_FUNC_VISIBILITY SubgraphAdapter { | |||
public: | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -62,7 +62,6 @@ protected: | |||
* @brief SubgraphAdapter creation function | |||
* @return Created SubgraphAdapter | |||
*/ | |||
// typedef shared_ptr<SubgraphAdapter> (*CREATOR_FUN)(void); | |||
using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>; | |||
/** | |||
@@ -105,7 +104,7 @@ public: | |||
* @param [in] op_type Op type | |||
* @param [in] clazz SubgraphAdapter implementation class | |||
*/ | |||
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ | |||
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ | |||
std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \ | |||
std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \ | |||
if (ptr == nullptr) { \ | |||
@@ -167,6 +167,33 @@ class H2CC(object): | |||
del self.stack_template | |||
del self.func_list_exist | |||
@staticmethod | |||
def implement_function(func): | |||
function_def = '' | |||
function_def += '{\n' | |||
all_items = func.split() | |||
start = 0 | |||
return_type = all_items[start] | |||
if return_type == "const": | |||
start += 1 | |||
return_type = all_items[start] | |||
if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||
return_type = "std::map" | |||
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||
return_type = "Ptr" | |||
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||
return_type += "&" | |||
if RETURN_STATEMENTS.__contains__(return_type): | |||
function_def += RETURN_STATEMENTS[return_type] | |||
else: | |||
logging.warning("Unhandled return type[%s]", return_type) | |||
function_def += '\n' | |||
function_def += '}\n' | |||
function_def += '\n' | |||
return function_def | |||
def just_skip(self): | |||
# skip blank line or comment | |||
if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( | |||
@@ -263,6 +290,7 @@ class H2CC(object): | |||
logging.info('Added %s functions', len(self.func_list_exist)) | |||
logging.info('Successfully converted,please see ' + self.output_file) | |||
def handle_func1(self, line): | |||
""" | |||
:param line: | |||
@@ -461,12 +489,6 @@ class H2CC(object): | |||
logging.info("func_name[%s]", func_name) | |||
return line, func_name | |||
def write_func_content(self, content, func_name, need_generate): | |||
if not (func_name in self.func_list_exist) and need_generate: | |||
self.output_fd.write(content) | |||
self.func_list_exist.append(func_name) | |||
logging.info('add func:[%s]', func_name) | |||
def gen_comment(self, start_i): | |||
comment_line = '' | |||
# Function comments are on top of function declarations, copy them over | |||
@@ -488,32 +510,11 @@ class H2CC(object): | |||
break | |||
return comment_line | |||
@staticmethod | |||
def implement_function(func): | |||
function_def = '' | |||
function_def += '{\n' | |||
all_items = func.split() | |||
start = 0 | |||
return_type = all_items[start] | |||
if return_type == "const": | |||
start += 1 | |||
return_type = all_items[start] | |||
if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||
return_type = "std::map" | |||
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||
return_type = "Ptr" | |||
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||
return_type += "&" | |||
if RETURN_STATEMENTS.__contains__(return_type): | |||
function_def += RETURN_STATEMENTS[return_type] | |||
else: | |||
logging.warning("Unhandled return type[%s]", return_type) | |||
function_def += '\n' | |||
function_def += '}\n' | |||
function_def += '\n' | |||
return function_def | |||
def write_func_content(self, content, func_name, need_generate): | |||
if not (func_name in self.func_list_exist) and need_generate: | |||
self.output_fd.write(content) | |||
self.func_list_exist.append(func_name) | |||
logging.info('add func:[%s]', func_name) | |||
def collect_header_files(path): | |||
@@ -765,7 +765,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) | |||
layer->set_name("Abs"); | |||
layer->set_type("AbsVal"); | |||
Status ret = weightParser.CheckLayersSize(layer); | |||
Status ret = weightParser.CheckLayersSize(*layer); | |||
EXPECT_EQ(ret, FAILED); | |||
} | |||
@@ -809,8 +809,54 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) | |||
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | |||
const google::protobuf::Message *message = proto->New(); | |||
Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); | |||
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); | |||
delete message; | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_DimSize_0) | |||
{ | |||
CaffeModelParser modelParser; | |||
domi::caffe::NetParameter net; | |||
net.add_input("111"); | |||
net.add_input_shape(); | |||
bool input_data_flag = true; | |||
Status ret = modelParser.ParseInput(net, input_data_flag); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(STestCaffeParser, CaffeModelParser_ParseInput_test_Err1) | |||
{ | |||
CaffeModelParser modelParser; | |||
domi::caffe::NetParameter net; | |||
net.add_input("111"); | |||
net.add_input("222"); | |||
net.add_input_shape(); | |||
bool input_data_flag = true; | |||
Status ret = modelParser.ParseInput(net, input_data_flag); | |||
EXPECT_EQ(ret, FAILED); | |||
} | |||
TEST_F(STestCaffeParser, CaffeModelParser_ParserLayerParameter_Succ) | |||
{ | |||
CaffeModelParser modelParser; | |||
std::string case_dir = __FILE__; | |||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
std::string model_file = case_dir + "/origin_models/"; | |||
const char *model_path = model_file.c_str(); | |||
std::string custom_proto = model_file; | |||
std::string caffe_proto = model_file; | |||
std::vector<ge::Operator> operators; | |||
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Data", "Input"); | |||
ge::Operator op_src = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||
operators.emplace_back(op_src); | |||
model_file = case_dir + "/origin_models/caffe_add.pbtxt"; | |||
custom_proto = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; | |||
model_path = model_file.c_str(); | |||
std::string caffe_proto_path = case_dir + "/../../../metadef/proto/caffe/caffe.proto"; | |||
auto ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto_path, operators); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
} // namespace ge |
@@ -739,6 +739,9 @@ TEST_F(UtestCaffeParser, CaffeModelParser_CustomProtoParse_test) | |||
Status ret = modelParser.CustomProtoParse(model_path, custom_proto, caffe_proto, operators); | |||
EXPECT_EQ(ret, PARAM_INVALID); | |||
ret = modelParser.CustomProtoParse("", custom_proto, caffe_proto, operators); | |||
EXPECT_EQ(ret, FAILED); | |||
model_file = case_dir + "/caffe_model/caffe_add.pbtxt"; | |||
custom_proto = case_dir + "/../../../../../metadef/proto/caffe/caffe.proto"; | |||
model_path = model_file.c_str(); | |||
@@ -890,7 +893,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CheckLayersSize_test) | |||
layer->set_name("Abs"); | |||
layer->set_type("AbsVal"); | |||
Status ret = weightParser.CheckLayersSize(layer); | |||
Status ret = weightParser.CheckLayersSize(*layer); | |||
EXPECT_EQ(ret, FAILED); | |||
} | |||
@@ -902,7 +905,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) | |||
layer->set_name("Abs"); | |||
layer->set_type("AbsVal"); | |||
Status ret = weightParser.ConvertLayerProto(&net, &net); | |||
Status ret = weightParser.ConvertLayerProto(net, &net); | |||
EXPECT_EQ(ret, SUCCESS); | |||
BlobProto* blob = layer->add_blobs(); | |||
@@ -911,16 +914,16 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ConvertLayerProto_test) | |||
BlobShape* shap = blob->mutable_shape(); | |||
shap->add_dim(1); | |||
shap->add_dim(2); | |||
ret = weightParser.ConvertBlobsProto(&net, &net); | |||
ret = weightParser.ConvertBlobsProto(net, &net); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = weightParser.ConvertBlobShapeProto(&net, &net); | |||
ret = weightParser.ConvertBlobShapeProto(net, &net); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = weightParser.ConvertConvParamProto(&net, &net); | |||
ret = weightParser.ConvertConvParamProto(net, &net); | |||
EXPECT_EQ(ret, SUCCESS); | |||
ret = weightParser.ConvertInnerProdcutProto(&net, &net); | |||
ret = weightParser.ConvertInnerProdcutProto(net, &net); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
@@ -1133,7 +1136,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseLayerParameter_test) | |||
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | |||
const google::protobuf::Message *message = proto->New(); | |||
Status ret = weightParser.ParseLayerParameter(descriptor, message, compute_graph); | |||
Status ret = weightParser.ParseLayerParameter(*descriptor, *message, compute_graph); | |||
delete message; | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
@@ -1163,7 +1166,7 @@ TEST_F(UtestCaffeParser, CaffeModelParser_ParseLayerParameter_test) | |||
google::protobuf::DynamicMessageFactory factory; | |||
const google::protobuf::Message *proto = factory.GetPrototype(descriptor); | |||
google::protobuf::Message *message = proto->New(); | |||
Status ret = modelParser.ParseLayerParameter(descriptor, message, operators); | |||
Status ret = modelParser.ParseLayerParameter(*descriptor, *message, operators); | |||
EXPECT_EQ(ret, SUCCESS); | |||
const domi::FrameworkType fmk_type = domi::TENSORFLOW; | |||
@@ -381,7 +381,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||
OnnxFileConstantParser parser; | |||
ge::onnx::NodeProto input_node; | |||
ge::onnx::TensorProto tensor_proto; | |||
Status ret = parser.GetTensorProto(&input_node, tensor_proto); | |||
Status ret = parser.GetTensorProto(input_node, tensor_proto); | |||
EXPECT_EQ(ret, FAILED); | |||
ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||
@@ -391,7 +391,7 @@ TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||
*attribute_tensor = tensor_proto; | |||
ret = parser.GetTensorProto(&input_node, tensor_proto); | |||
ret = parser.GetTensorProto(input_node, tensor_proto); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||