Merge pull request !716 from 韩健/hanjianpull/713/MERGE
@@ -117,6 +117,7 @@ Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_pro | |||
return FAILED; | |||
} | |||
op_def.SetAttr(kFileConstantPath, attrs); | |||
GELOGD("The weight file of Op[%s] is: [%s].", tensor_proto.name().c_str(), attrs.GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -366,6 +366,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||
*attribute_t = it.second; | |||
if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||
const_node->set_op_type(kFileConstant); | |||
GELOGD("Initializer const node [%s], the weight was stored in the file.", const_node->name().c_str()); | |||
} else { | |||
const_node->set_op_type(ge::kOpTypeConstant); | |||
} | |||
@@ -374,7 +375,21 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||
return SUCCESS; | |||
} | |||
void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const { | |||
void OnnxModelParser::UpdateConstantOpType(ge::onnx::NodeProto *node) const { | |||
// If weight in file, Marker Constant(not Initializer) as file constant | |||
for (auto it : node->attribute()) { | |||
if (it.name() == ge::kAttrNameValue) { | |||
const ::ge::onnx::TensorProto tensor_proto = it.t(); | |||
if (tensor_proto.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||
node->set_op_type(kFileConstant); | |||
GELOGD("Const node [%s], the weight was stored in the file.", node->name().c_str()); | |||
} | |||
break; | |||
} | |||
} | |||
} | |||
void OnnxModelParser::UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const { | |||
int index = 0; | |||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||
@@ -382,6 +397,9 @@ void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const | |||
std::string node_name = node->op_type() + "_" + to_string(index++); | |||
node->set_name(node_name); | |||
} | |||
if (node->op_type() == kOpTypeConstant) { | |||
UpdateConstantOpType(node); | |||
} | |||
} | |||
} | |||
@@ -966,7 +984,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
} | |||
GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | |||
// 3. Parse Constant from graph. | |||
// 3. Parse Constant(initializer) from graph. | |||
ret = ParseInitializer(onnx_graph, initializer_name_tensor); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "[Parse][Initializer] for onnx failed."); | |||
@@ -980,8 +998,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
return ret; | |||
} | |||
// 5. Update node name for node do not has name. | |||
UpdateAllNodeName(onnx_graph); | |||
// 5. Update node name for node do not has name, update const op type | |||
UpdateNodeNameAndOpType(onnx_graph); | |||
// 6 Precheck. | |||
ret = Prechecker(onnx_graph); | |||
@@ -105,7 +105,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const; | |||
void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const; | |||
void UpdateConstantOpType(ge::onnx::NodeProto *node) const; | |||
void UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const; | |||
Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | |||
@@ -25,7 +25,10 @@ | |||
#include "external/ge/ge_api_types.h" | |||
#include "tests/depends/ops_stub/ops_stub.h" | |||
#include "framework/omg/parser/parser_factory.h" | |||
#include "parser/onnx/onnx_util.h" | |||
#define private public | |||
#include "parser/onnx/onnx_parser.h" | |||
#undef private | |||
namespace ge { | |||
class STestOnnxParser : public testing::Test { | |||
@@ -103,6 +106,31 @@ void STestOnnxParser::RegisterCustomOp() { | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
} | |||
ge::onnx::GraphProto CreateOnnxGraph() { | |||
ge::onnx::GraphProto onnx_graph; | |||
(void)onnx_graph.add_input(); | |||
(void)onnx_graph.add_output(); | |||
::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||
::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||
::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||
node_const1->set_op_type(kOpTypeConstant); | |||
node_const2->set_op_type(kOpTypeConstant); | |||
node_add->set_op_type("Add"); | |||
::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||
attr->set_name(ge::kAttrNameValue); | |||
::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||
attr = node_const1->add_attribute(); | |||
attr = node_const2->add_attribute(); | |||
attr->set_name(ge::kAttrNameValue); | |||
tensor_proto = attr->mutable_t(); | |||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||
return onnx_graph; | |||
} | |||
TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | |||
std::string case_dir = __FILE__; | |||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
@@ -184,4 +212,15 @@ TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { | |||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||
} | |||
TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph) | |||
{ | |||
OnnxModelParser modelParser; | |||
ge::onnx::ModelProto model_proto; | |||
auto onnx_graph = model_proto.mutable_graph(); | |||
*onnx_graph = CreateOnnxGraph(); | |||
ge::Graph graph; | |||
Status ret = modelParser.ModelParseToGraph(model_proto, graph); | |||
EXPECT_EQ(ret, FAILED); | |||
} | |||
} // namespace ge |
@@ -111,6 +111,29 @@ void UtestOnnxParser::RegisterCustomOp() { | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
} | |||
ge::onnx::GraphProto CreateOnnxGraph() { | |||
ge::onnx::GraphProto onnx_graph; | |||
::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||
::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||
::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||
node_const1->set_op_type(kOpTypeConstant); | |||
node_const2->set_op_type(kOpTypeConstant); | |||
node_add->set_op_type("Add"); | |||
::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||
attr->set_name(ge::kAttrNameValue); | |||
::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||
attr = node_const1->add_attribute(); | |||
attr = node_const2->add_attribute(); | |||
attr->set_name(ge::kAttrNameValue); | |||
tensor_proto = attr->mutable_t(); | |||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||
return onnx_graph; | |||
} | |||
TEST_F(UtestOnnxParser, onnx_parser_if_node) { | |||
std::string case_dir = __FILE__; | |||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
@@ -575,6 +598,16 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | |||
EXPECT_EQ(ret, domi::FAILED); | |||
} | |||
TEST_F(UtestOnnxParser, OnnxModelParser_ParseConstant_test) | |||
{ | |||
OnnxModelParser model_parser; | |||
ge::onnx::GraphProto onnx_graph = CreateOnnxGraph(); | |||
model_parser.UpdateNodeNameAndOpType(onnx_graph); | |||
std::string type = onnx_graph.mutable_node(0)->op_type(); | |||
EXPECT_EQ(type, kFileConstant); | |||
} | |||
TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) | |||
{ | |||
ge::onnx::ModelProto model_proto; | |||