From ca35f8b50dcd755389a69f4b440716f32c092cc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E5=81=A5?= Date: Thu, 27 Oct 2022 13:20:45 +0000 Subject: [PATCH] =?UTF-8?q?!716=20=E5=A4=A7=E4=BA=8E2G=E7=9A=84onnx?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=EF=BC=8C=E6=9D=83=E9=87=8D=E5=92=8C=E5=9B=BE?= =?UTF-8?q?=E5=88=86=E7=A6=BB=EF=BC=8C=E9=9D=9Einitializer=E7=9A=84const?= =?UTF-8?q?=E7=AE=97=E5=AD=90=EF=BC=8C=E4=B9=9F=E4=BC=9A=E6=9C=89=E6=9D=83?= =?UTF-8?q?=E9=87=8D=E6=96=87=E4=BB=B6=E3=80=82=20Merge=20pull=20request?= =?UTF-8?q?=20!716=20from=20=E9=9F=A9=E5=81=A5/hanjian?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- parser/onnx/onnx_file_constant_parser.cc | 1 + parser/onnx/onnx_parser.cc | 26 ++++++++++++--- parser/onnx/onnx_parser.h | 4 ++- tests/st/testcase/test_onnx_parser.cc | 39 ++++++++++++++++++++++ .../onnx_parser_testcase/onnx_parser_unittest.cc | 33 ++++++++++++++++++ 5 files changed, 98 insertions(+), 5 deletions(-) diff --git a/parser/onnx/onnx_file_constant_parser.cc b/parser/onnx/onnx_file_constant_parser.cc index 36abbbe..0271d11 100644 --- a/parser/onnx/onnx_file_constant_parser.cc +++ b/parser/onnx/onnx_file_constant_parser.cc @@ -117,6 +117,7 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro return FAILED; } op_def.SetAttr(kFileConstantPath, attrs); + GELOGD("The weight file of Op[%s] is: [%s].", tensor_proto.name().c_str(), attrs.GetName().c_str()); return SUCCESS; } diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index f3769e8..6ceaa51 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -366,6 +366,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, *attribute_t = it.second; if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { const_node->set_op_type(kFileConstant); + GELOGD("Initializer const node [%s], the weight was stored in the file.", const_node->name().c_str()); } else { const_node->set_op_type(ge::kOpTypeConstant); } @@ -374,7 +375,21 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, return SUCCESS; } -void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const { +void OnnxModelParser::UpdateConstantOpType(ge::onnx::NodeProto *node) const { + // If weight in file, Marker Constant(not Initializer) as file constant + for (auto it : node->attribute()) { + if (it.name() == ge::kAttrNameValue) { + const ::ge::onnx::TensorProto tensor_proto = it.t(); + if (tensor_proto.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { + node->set_op_type(kFileConstant); + GELOGD("Const node [%s], the weight was stored in the file.", node->name().c_str()); + } + break; + } + } +} + +void OnnxModelParser::UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const { int index = 0; for (int i = 0; i < onnx_graph.node_size(); i++) { ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); @@ -382,6 +397,9 @@ void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const std::string node_name = node->op_type() + "_" + to_string(index++); node->set_name(node_name); } + if (node->op_type() == kOpTypeConstant) { + UpdateConstantOpType(node); + } } } @@ -966,7 +984,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP } GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); - // 3. Parse Constant from graph. + // 3. Parse Constant(initializer) from graph. ret = ParseInitializer(onnx_graph, initializer_name_tensor); if (ret != SUCCESS) { GELOGE(ret, "[Parse][Initializer] for onnx failed."); @@ -980,8 +998,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP return ret; } - // 5. Update node name for node do not has name. - UpdateAllNodeName(onnx_graph); + // 5. Update node name for node do not has name, update const op type + UpdateNodeNameAndOpType(onnx_graph); // 6 Precheck. ret = Prechecker(onnx_graph); diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index 90b7397..394ff23 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -105,7 +105,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, std::map &initializer_name_tensor) const; - void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const; + void UpdateConstantOpType(ge::onnx::NodeProto *node) const; + + void UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const; Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); diff --git a/tests/st/testcase/test_onnx_parser.cc b/tests/st/testcase/test_onnx_parser.cc index 05b4207..69a335c 100644 --- a/tests/st/testcase/test_onnx_parser.cc +++ b/tests/st/testcase/test_onnx_parser.cc @@ -25,7 +25,10 @@ #include "external/ge/ge_api_types.h" #include "tests/depends/ops_stub/ops_stub.h" #include "framework/omg/parser/parser_factory.h" +#include "parser/onnx/onnx_util.h" +#define private public #include "parser/onnx/onnx_parser.h" +#undef private namespace ge { class STestOnnxParser : public testing::Test { @@ -103,6 +106,31 @@ void STestOnnxParser::RegisterCustomOp() { domi::OpRegistry::Instance()->registrationDatas.clear(); } +ge::onnx::GraphProto CreateOnnxGraph() { + ge::onnx::GraphProto onnx_graph; + (void)onnx_graph.add_input(); + (void)onnx_graph.add_output(); + ::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); + ::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); + ::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); + node_const1->set_op_type(kOpTypeConstant); + node_const2->set_op_type(kOpTypeConstant); + node_add->set_op_type("Add"); + + ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); + attr->set_name(ge::kAttrNameValue); + ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); + tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); + attr = node_const1->add_attribute(); + + attr = node_const2->add_attribute(); + attr->set_name(ge::kAttrNameValue); + tensor_proto = attr->mutable_t(); + tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); + + return onnx_graph; +} + TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { std::string case_dir = __FILE__; case_dir = case_dir.substr(0, case_dir.find_last_of("/")); @@ -184,4 +212,15 @@ TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { EXPECT_EQ(ret, GRAPH_SUCCESS); } +TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph) +{ + OnnxModelParser modelParser; + ge::onnx::ModelProto model_proto; + auto onnx_graph = model_proto.mutable_graph(); + *onnx_graph = CreateOnnxGraph(); + ge::Graph graph; + + Status ret = modelParser.ModelParseToGraph(model_proto, graph); + EXPECT_EQ(ret, FAILED); +} } // namespace ge diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index b0338ce..b944b55 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -111,6 +111,29 @@ void UtestOnnxParser::RegisterCustomOp() { domi::OpRegistry::Instance()->registrationDatas.clear(); } +ge::onnx::GraphProto CreateOnnxGraph() { + ge::onnx::GraphProto onnx_graph; + ::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); + ::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); + ::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); + node_const1->set_op_type(kOpTypeConstant); + node_const2->set_op_type(kOpTypeConstant); + node_add->set_op_type("Add"); + + ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); + attr->set_name(ge::kAttrNameValue); + ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); + tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); + attr = node_const1->add_attribute(); + + attr = node_const2->add_attribute(); + attr->set_name(ge::kAttrNameValue); + tensor_proto = attr->mutable_t(); + tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); + + return onnx_graph; +} + TEST_F(UtestOnnxParser, onnx_parser_if_node) { std::string case_dir = __FILE__; case_dir = case_dir.substr(0, case_dir.find_last_of("/")); @@ -575,6 +598,16 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) EXPECT_EQ(ret, domi::FAILED); } +TEST_F(UtestOnnxParser, OnnxModelParser_ParseConstant_test) +{ + OnnxModelParser model_parser; + ge::onnx::GraphProto onnx_graph = CreateOnnxGraph(); + + model_parser.UpdateNodeNameAndOpType(onnx_graph); + std::string type = onnx_graph.mutable_node(0)->op_type(); + EXPECT_EQ(type, kFileConstant); +} + TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) { ge::onnx::ModelProto model_proto;