Browse Source

!716 大于2G的onnx模型,权重和图分离,非initializer的const算子,也会有权重文件。

Merge pull request !716 from 韩健/hanjian
pull/713/MERGE
韩健 i-robot 2 years ago
parent
commit
ca35f8b50d
5 changed files with 98 additions and 5 deletions
  1. +1
    -0
      parser/onnx/onnx_file_constant_parser.cc
  2. +22
    -4
      parser/onnx/onnx_parser.cc
  3. +3
    -1
      parser/onnx/onnx_parser.h
  4. +39
    -0
      tests/st/testcase/test_onnx_parser.cc
  5. +33
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 1
- 0
parser/onnx/onnx_file_constant_parser.cc View File

@@ -117,6 +117,7 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro
return FAILED;
}
op_def.SetAttr(kFileConstantPath, attrs);
GELOGD("The weight file of Op[%s] is: [%s].", tensor_proto.name().c_str(), attrs.GetName().c_str());
return SUCCESS;
}



+ 22
- 4
parser/onnx/onnx_parser.cc View File

@@ -366,6 +366,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
*attribute_t = it.second;
if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) {
const_node->set_op_type(kFileConstant);
GELOGD("Initializer const node [%s], the weight was stored in the file.", const_node->name().c_str());
} else {
const_node->set_op_type(ge::kOpTypeConstant);
}
@@ -374,7 +375,21 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
return SUCCESS;
}

void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const {
void OnnxModelParser::UpdateConstantOpType(ge::onnx::NodeProto *node) const {
// If weight in file, Marker Constant(not Initializer) as file constant
for (auto it : node->attribute()) {
if (it.name() == ge::kAttrNameValue) {
const ::ge::onnx::TensorProto tensor_proto = it.t();
if (tensor_proto.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) {
node->set_op_type(kFileConstant);
GELOGD("Const node [%s], the weight was stored in the file.", node->name().c_str());
}
break;
}
}
}

void OnnxModelParser::UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const {
int index = 0;
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
@@ -382,6 +397,9 @@ void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const
std::string node_name = node->op_type() + "_" + to_string(index++);
node->set_name(node_name);
}
if (node->op_type() == kOpTypeConstant) {
UpdateConstantOpType(node);
}
}
}

@@ -966,7 +984,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP
}
GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size());

// 3. Parse Constant from graph.
// 3. Parse Constant(initializer) from graph.
ret = ParseInitializer(onnx_graph, initializer_name_tensor);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Initializer] for onnx failed.");
@@ -980,8 +998,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP
return ret;
}

// 5. Update node name for node do not has name.
UpdateAllNodeName(onnx_graph);
// 5. Update node name for node do not has name, update const op type
UpdateNodeNameAndOpType(onnx_graph);

// 6 Precheck.
ret = Prechecker(onnx_graph);


+ 3
- 1
parser/onnx/onnx_parser.h View File

@@ -105,7 +105,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
Status ParseInitializer(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const;

void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const;
void UpdateConstantOpType(ge::onnx::NodeProto *node) const;

void UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const;

Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type);



+ 39
- 0
tests/st/testcase/test_onnx_parser.cc View File

@@ -25,7 +25,10 @@
#include "external/ge/ge_api_types.h"
#include "tests/depends/ops_stub/ops_stub.h"
#include "framework/omg/parser/parser_factory.h"
#include "parser/onnx/onnx_util.h"
#define private public
#include "parser/onnx/onnx_parser.h"
#undef private

namespace ge {
class STestOnnxParser : public testing::Test {
@@ -103,6 +106,31 @@ void STestOnnxParser::RegisterCustomOp() {
domi::OpRegistry::Instance()->registrationDatas.clear();
}

ge::onnx::GraphProto CreateOnnxGraph() {
ge::onnx::GraphProto onnx_graph;
(void)onnx_graph.add_input();
(void)onnx_graph.add_output();
::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node();
::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node();
::ge::onnx::NodeProto* node_add = onnx_graph.add_node();
node_const1->set_op_type(kOpTypeConstant);
node_const2->set_op_type(kOpTypeConstant);
node_add->set_op_type("Add");

::ge::onnx::AttributeProto* attr = node_const1->add_attribute();
attr->set_name(ge::kAttrNameValue);
::ge::onnx::TensorProto* tensor_proto = attr->mutable_t();
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL);
attr = node_const1->add_attribute();

attr = node_const2->add_attribute();
attr->set_name(ge::kAttrNameValue);
tensor_proto = attr->mutable_t();
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT);

return onnx_graph;
}

TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
@@ -184,4 +212,15 @@ TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) {
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph)
{
OnnxModelParser modelParser;
ge::onnx::ModelProto model_proto;
auto onnx_graph = model_proto.mutable_graph();
*onnx_graph = CreateOnnxGraph();
ge::Graph graph;

Status ret = modelParser.ModelParseToGraph(model_proto, graph);
EXPECT_EQ(ret, FAILED);
}
} // namespace ge

+ 33
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -111,6 +111,29 @@ void UtestOnnxParser::RegisterCustomOp() {
domi::OpRegistry::Instance()->registrationDatas.clear();
}

ge::onnx::GraphProto CreateOnnxGraph() {
ge::onnx::GraphProto onnx_graph;
::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node();
::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node();
::ge::onnx::NodeProto* node_add = onnx_graph.add_node();
node_const1->set_op_type(kOpTypeConstant);
node_const2->set_op_type(kOpTypeConstant);
node_add->set_op_type("Add");

::ge::onnx::AttributeProto* attr = node_const1->add_attribute();
attr->set_name(ge::kAttrNameValue);
::ge::onnx::TensorProto* tensor_proto = attr->mutable_t();
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL);
attr = node_const1->add_attribute();

attr = node_const2->add_attribute();
attr->set_name(ge::kAttrNameValue);
tensor_proto = attr->mutable_t();
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT);

return onnx_graph;
}

TEST_F(UtestOnnxParser, onnx_parser_if_node) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
@@ -575,6 +598,16 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test)
EXPECT_EQ(ret, domi::FAILED);
}

TEST_F(UtestOnnxParser, OnnxModelParser_ParseConstant_test)
{
OnnxModelParser model_parser;
ge::onnx::GraphProto onnx_graph = CreateOnnxGraph();

model_parser.UpdateNodeNameAndOpType(onnx_graph);
std::string type = onnx_graph.mutable_node(0)->op_type();
EXPECT_EQ(type, kFileConstant);
}

TEST_F(UtestOnnxParser, onnx_test_ConstructOriType)
{
ge::onnx::ModelProto model_proto;


Loading…
Cancel
Save