Merge pull request !716 from 韩健/hanjianpull/713/MERGE
@@ -117,6 +117,7 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
op_def.SetAttr(kFileConstantPath, attrs); | op_def.SetAttr(kFileConstantPath, attrs); | ||||
GELOGD("The weight file of Op[%s] is: [%s].", tensor_proto.name().c_str(), attrs.GetName().c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -366,6 +366,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||||
*attribute_t = it.second; | *attribute_t = it.second; | ||||
if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | ||||
const_node->set_op_type(kFileConstant); | const_node->set_op_type(kFileConstant); | ||||
GELOGD("Initializer const node [%s], the weight was stored in the file.", const_node->name().c_str()); | |||||
} else { | } else { | ||||
const_node->set_op_type(ge::kOpTypeConstant); | const_node->set_op_type(ge::kOpTypeConstant); | ||||
} | } | ||||
@@ -374,7 +375,21 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const { | |||||
void OnnxModelParser::UpdateConstantOpType(ge::onnx::NodeProto *node) const { | |||||
// If weight in file, Marker Constant(not Initializer) as file constant | |||||
for (auto it : node->attribute()) { | |||||
if (it.name() == ge::kAttrNameValue) { | |||||
const ::ge::onnx::TensorProto tensor_proto = it.t(); | |||||
if (tensor_proto.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
node->set_op_type(kFileConstant); | |||||
GELOGD("Const node [%s], the weight was stored in the file.", node->name().c_str()); | |||||
} | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
void OnnxModelParser::UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const { | |||||
int index = 0; | int index = 0; | ||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | ||||
@@ -382,6 +397,9 @@ void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const | |||||
std::string node_name = node->op_type() + "_" + to_string(index++); | std::string node_name = node->op_type() + "_" + to_string(index++); | ||||
node->set_name(node_name); | node->set_name(node_name); | ||||
} | } | ||||
if (node->op_type() == kOpTypeConstant) { | |||||
UpdateConstantOpType(node); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -966,7 +984,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||||
} | } | ||||
GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | ||||
// 3. Parse Constant from graph. | |||||
// 3. Parse Constant(initializer) from graph. | |||||
ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "[Parse][Initializer] for onnx failed."); | GELOGE(ret, "[Parse][Initializer] for onnx failed."); | ||||
@@ -980,8 +998,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||||
return ret; | return ret; | ||||
} | } | ||||
// 5. Update node name for node do not has name. | |||||
UpdateAllNodeName(onnx_graph); | |||||
// 5. Update node name for node do not has name, update const op type | |||||
UpdateNodeNameAndOpType(onnx_graph); | |||||
// 6 Precheck. | // 6 Precheck. | ||||
ret = Prechecker(onnx_graph); | ret = Prechecker(onnx_graph); | ||||
@@ -105,7 +105,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | ||||
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const; | std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const; | ||||
void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const; | |||||
void UpdateConstantOpType(ge::onnx::NodeProto *node) const; | |||||
void UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const; | |||||
Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | ||||
@@ -25,7 +25,10 @@ | |||||
#include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
#include "tests/depends/ops_stub/ops_stub.h" | #include "tests/depends/ops_stub/ops_stub.h" | ||||
#include "framework/omg/parser/parser_factory.h" | #include "framework/omg/parser/parser_factory.h" | ||||
#include "parser/onnx/onnx_util.h" | |||||
#define private public | |||||
#include "parser/onnx/onnx_parser.h" | #include "parser/onnx/onnx_parser.h" | ||||
#undef private | |||||
namespace ge { | namespace ge { | ||||
class STestOnnxParser : public testing::Test { | class STestOnnxParser : public testing::Test { | ||||
@@ -103,6 +106,31 @@ void STestOnnxParser::RegisterCustomOp() { | |||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
} | } | ||||
ge::onnx::GraphProto CreateOnnxGraph() { | |||||
ge::onnx::GraphProto onnx_graph; | |||||
(void)onnx_graph.add_input(); | |||||
(void)onnx_graph.add_output(); | |||||
::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||||
::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||||
::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||||
node_const1->set_op_type(kOpTypeConstant); | |||||
node_const2->set_op_type(kOpTypeConstant); | |||||
node_add->set_op_type("Add"); | |||||
::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||||
attr->set_name(ge::kAttrNameValue); | |||||
::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||||
attr = node_const1->add_attribute(); | |||||
attr = node_const2->add_attribute(); | |||||
attr->set_name(ge::kAttrNameValue); | |||||
tensor_proto = attr->mutable_t(); | |||||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||||
return onnx_graph; | |||||
} | |||||
TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | ||||
std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
@@ -184,4 +212,15 @@ TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | EXPECT_EQ(ret, GRAPH_SUCCESS); | ||||
} | } | ||||
TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph) | |||||
{ | |||||
OnnxModelParser modelParser; | |||||
ge::onnx::ModelProto model_proto; | |||||
auto onnx_graph = model_proto.mutable_graph(); | |||||
*onnx_graph = CreateOnnxGraph(); | |||||
ge::Graph graph; | |||||
Status ret = modelParser.ModelParseToGraph(model_proto, graph); | |||||
EXPECT_EQ(ret, FAILED); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -111,6 +111,29 @@ void UtestOnnxParser::RegisterCustomOp() { | |||||
domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
} | } | ||||
ge::onnx::GraphProto CreateOnnxGraph() { | |||||
ge::onnx::GraphProto onnx_graph; | |||||
::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||||
::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||||
::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||||
node_const1->set_op_type(kOpTypeConstant); | |||||
node_const2->set_op_type(kOpTypeConstant); | |||||
node_add->set_op_type("Add"); | |||||
::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||||
attr->set_name(ge::kAttrNameValue); | |||||
::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||||
attr = node_const1->add_attribute(); | |||||
attr = node_const2->add_attribute(); | |||||
attr->set_name(ge::kAttrNameValue); | |||||
tensor_proto = attr->mutable_t(); | |||||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||||
return onnx_graph; | |||||
} | |||||
TEST_F(UtestOnnxParser, onnx_parser_if_node) { | TEST_F(UtestOnnxParser, onnx_parser_if_node) { | ||||
std::string case_dir = __FILE__; | std::string case_dir = __FILE__; | ||||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | ||||
@@ -575,6 +598,16 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | |||||
EXPECT_EQ(ret, domi::FAILED); | EXPECT_EQ(ret, domi::FAILED); | ||||
} | } | ||||
TEST_F(UtestOnnxParser, OnnxModelParser_ParseConstant_test) | |||||
{ | |||||
OnnxModelParser model_parser; | |||||
ge::onnx::GraphProto onnx_graph = CreateOnnxGraph(); | |||||
model_parser.UpdateNodeNameAndOpType(onnx_graph); | |||||
std::string type = onnx_graph.mutable_node(0)->op_type(); | |||||
EXPECT_EQ(type, kFileConstant); | |||||
} | |||||
TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) | TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) | ||||
{ | { | ||||
ge::onnx::ModelProto model_proto; | ge::onnx::ModelProto model_proto; | ||||