Merge pull request !622 from lipeiyang/ge_devpull/625/MERGE
@@ -21,6 +21,7 @@ | |||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/debug/ge_op_types.h" | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
@@ -143,6 +144,49 @@ Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AutoMappingSubgraphDataFormat(const NodePtr &parent_node, const ge::Graph &graph) { | |||||
GE_CHECK_NOTNULL(parent_node); | |||||
const auto &parent_op_desc = parent_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(parent_op_desc); | |||||
const auto &compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||||
GE_CHECK_NOTNULL(compute_graph); | |||||
const auto data_nodes = FindNodesByType(compute_graph, DATA); | |||||
for (size_t i = 0U; i < data_nodes.size(); ++i) { | |||||
const auto &data_op_desc = data_nodes[i]->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(data_op_desc); | |||||
int32_t index = -1; | |||||
// when this function has been called, PARENT_INDEX has not been set | |||||
if (!ge::AttrUtils::GetInt(data_op_desc, ge::ATTR_NAME_INDEX, index)) { | |||||
REPORT_INNER_ERROR("E19999", "Get attr:index failed, op_name:%s", data_nodes[i]->GetName().c_str()); | |||||
GELOGE(FAILED, "[Get][Attr] Get attr:index failed, op_name:%s", data_nodes[i]->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_BOOL_RET_STATUS(static_cast<size_t>(index) < parent_op_desc->GetAllInputsSize(), PARAM_INVALID, | |||||
"[Check][Index] failed, index=%d should less than %zu.", index, | |||||
parent_op_desc->GetAllInputsSize()); | |||||
// set data format by node input desc | |||||
const auto input_format = parent_op_desc->GetInputDesc(index).GetFormat(); | |||||
const auto input_original_format = parent_op_desc->GetInputDesc(index).GetOriginFormat(); | |||||
const auto input_desc = data_op_desc->MutableInputDesc(0U); | |||||
const auto output_desc = data_op_desc->MutableOutputDesc(0U); | |||||
GE_CHECK_NOTNULL(input_desc); | |||||
GE_CHECK_NOTNULL(output_desc); | |||||
input_desc->SetFormat(input_format); | |||||
input_desc->SetOriginFormat(input_original_format); | |||||
output_desc->SetFormat(input_format); | |||||
output_desc->SetOriginFormat(input_original_format); | |||||
GELOGD("Set index %d of data[%zu], node:%s, format:%d->%d, original " | |||||
"format:%d->%d, from parent node:%s, node_type:%s", | |||||
index, i, data_nodes[i]->GetName().c_str(), | |||||
static_cast<int32_t>(output_desc->GetFormat()), | |||||
static_cast<int32_t>(input_format), | |||||
static_cast<int32_t>(output_desc->GetOriginFormat()), | |||||
static_cast<int32_t>(input_original_format), | |||||
parent_node->GetName().c_str(), parent_node->GetType().c_str()); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge | ||||
namespace domi { | namespace domi { | ||||
@@ -20,11 +20,15 @@ | |||||
#include <functional> | #include <functional> | ||||
#include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
#include "external/register/register_error_codes.h" | #include "external/register/register_error_codes.h" | ||||
#include "graph/node.h" | |||||
namespace ge { | namespace ge { | ||||
domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | ||||
const ge::Graph &graph, | const ge::Graph &graph, | ||||
const std::function<domi::Status(int data_index, int &parent_input_index)> &input, | const std::function<domi::Status(int data_index, int &parent_input_index)> &input, | ||||
const std::function<domi::Status(int netoutput_index, int &parent_output_index)> &output); | const std::function<domi::Status(int netoutput_index, int &parent_output_index)> &output); | ||||
// only data node may set default NHWC/NCHW format by parser when call NpuOnnxGraphOp | |||||
// this function should be called before AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo | |||||
domi::Status AutoMappingSubgraphDataFormat(const NodePtr &parent_node, const ge::Graph &graph); | |||||
} // namespace ge | } // namespace ge | ||||
#endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ | #endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ |
@@ -318,6 +318,12 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { | |||||
} | } | ||||
Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) { | Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) { | ||||
// Inner function, data format need be set by parant node | |||||
GE_CHK_STATUS_RET(AutoMappingSubgraphDataFormat(node, graph), | |||||
"[Call][AutoMappingSubgraphDataFormat] failed, node:%s, " | |||||
"root graph:%s, graph:%s", | |||||
node->GetName().c_str(), root_graph->GetName().c_str(), | |||||
ParserUtils::GetGraphName(graph).c_str()); | |||||
// Inner function, input params have been checked by caller | // Inner function, input params have been checked by caller | ||||
Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | ||||
graph, | graph, | ||||
@@ -733,6 +733,17 @@ namespace { | |||||
ge::GraphUtils::AddEdge(node_h->GetOutControlAnchor(), node_j->GetInControlAnchor()); | ge::GraphUtils::AddEdge(node_h->GetOutControlAnchor(), node_j->GetInControlAnchor()); | ||||
} | } | ||||
void MakeGraph(const ComputeGraphPtr &root_graph, const string &name) { | |||||
root_graph->SetName(name); | |||||
ge::NodePtr data1 = AddNode(root_graph, name + "_input1", parser::DATA, 1, 1); | |||||
ge::NodePtr data2 = AddNode(root_graph, name + "_input2", parser::DATA, 1, 1); | |||||
ge::NodePtr add = AddNode(root_graph, name + "_add", parser::ADD, 2, 1); | |||||
ge::NodePtr net_output = AddNode(root_graph, name + "_net_output", parser::NETOUTPUT, 1, 1); | |||||
ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), add->GetInDataAnchor(0)); | |||||
ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), add->GetInDataAnchor(1)); | |||||
ge::GraphUtils::AddEdge(add->GetOutDataAnchor(0), net_output->GetInDataAnchor(0)); | |||||
} | |||||
void ChangeDataType(tensorflow::NodeDef* node_tf, int32_t data_type) | void ChangeDataType(tensorflow::NodeDef* node_tf, int32_t data_type) | ||||
{ | { | ||||
domi::tensorflow::AttrValue input_attr_value; | domi::tensorflow::AttrValue input_attr_value; | ||||
@@ -1211,6 +1222,69 @@ TEST_F(STestTensorflowParser, parser_ConvertToGeDataType) | |||||
ASSERT_EQ(dataType, ge::DataType::DT_UNDEFINED); | ASSERT_EQ(dataType, ge::DataType::DT_UNDEFINED); | ||||
} | } | ||||
TEST_F(STestTensorflowParser, tensorflow_parser_with_external_normal_graph) { | |||||
// 1. Create root graph | |||||
ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("root_graph"); | |||||
MakeGraph(root_graph, "root_graph"); | |||||
// 2. Create ONNX sub graph | |||||
// 2.1 Sub graph of onnx graph | |||||
ge::ComputeGraphPtr sub_sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub"); | |||||
// 2.2 ONNX graph | |||||
ComputeGraphPtr sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub"); | |||||
MakeGraph(sub_graph, "sub_graph"); | |||||
auto add = sub_graph->FindNode("sub_graph_add"); | |||||
ASSERT_NE(add, nullptr); | |||||
add->GetOpDesc()->AddSubgraphName("sub_sub_graph"); | |||||
add->GetOpDesc()->SetSubgraphInstanceName(0, sub_sub_graph->GetName()); | |||||
sub_graph->AddSubGraph(sub_sub_graph); | |||||
auto input1 = sub_graph->FindNode("sub_graph_input1"); | |||||
ASSERT_NE(input1, nullptr); | |||||
AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
auto input2 = sub_graph->FindNode("sub_graph_input2"); | |||||
ASSERT_NE(input2, nullptr); | |||||
AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
// 3. Serialize ONNX graph to string | |||||
// 3.1 normal | |||||
ge::Model model("model", ""); | |||||
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); | |||||
Buffer buffer; | |||||
graphStatus save_ret = model.Save(buffer, false); | |||||
ASSERT_EQ(save_ret, GRAPH_SUCCESS); | |||||
std::string external_graph(reinterpret_cast<const char *>(buffer.GetData()), | |||||
buffer.GetSize()); | |||||
// model will failed | |||||
input1->GetOpDesc()->DelAttr(ATTR_NAME_INDEX); | |||||
ge::Model model_will_fail("model_will_fail", ""); | |||||
model_will_fail.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); | |||||
Buffer buffer_fail; | |||||
save_ret = model_will_fail.Save(buffer_fail, false); | |||||
ASSERT_EQ(save_ret, GRAPH_SUCCESS); | |||||
std::string external_graph_fail( | |||||
reinterpret_cast<const char *>(buffer_fail.GetData()), | |||||
buffer_fail.GetSize()); | |||||
// 4. Set string to function node | |||||
auto root_add = root_graph->FindNode("root_graph_add"); | |||||
ASSERT_NE(root_add, nullptr); | |||||
AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph); | |||||
auto root_input1 = root_graph->FindNode("root_graph_input1"); | |||||
ASSERT_NE(root_input1, nullptr); | |||||
AttrUtils::SetInt(root_input1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
auto root_input2 = root_graph->FindNode("root_graph_input2"); | |||||
ASSERT_NE(root_input2, nullptr); | |||||
AttrUtils::SetInt(root_input2->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
// 5. Run test (normal) | |||||
auto ret = TensorFlowModelParser::AddExternalGraph(root_graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 2); | |||||
EXPECT_EQ(sub_graph->GetAllSubgraphs().size(), 1); | |||||
EXPECT_NE(root_graph->GetSubgraph(sub_graph->GetName()), nullptr); | |||||
EXPECT_EQ(root_graph->GetSubgraph(sub_graph->GetName())->GetAllSubgraphs().size(), 0); | |||||
} | |||||
TEST_F(STestTensorflowParser, tensorflow_ParserProto_failed) | TEST_F(STestTensorflowParser, tensorflow_ParserProto_failed) | ||||
{ | { | ||||
std::string caseDir = __FILE__; | std::string caseDir = __FILE__; | ||||