Merge pull request !622 from lipeiyang/ge_devpull/625/MERGE
@@ -21,6 +21,7 @@ | |||
#include "graph/op_desc.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/debug/ge_op_types.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "register/register_fmk_types.h" | |||
@@ -143,6 +144,49 @@ Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | |||
return SUCCESS; | |||
} | |||
domi::Status AutoMappingSubgraphDataFormat(const NodePtr &parent_node, const ge::Graph &graph) { | |||
GE_CHECK_NOTNULL(parent_node); | |||
const auto &parent_op_desc = parent_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(parent_op_desc); | |||
const auto &compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
GE_CHECK_NOTNULL(compute_graph); | |||
const auto data_nodes = FindNodesByType(compute_graph, DATA); | |||
for (size_t i = 0U; i < data_nodes.size(); ++i) { | |||
const auto &data_op_desc = data_nodes[i]->GetOpDesc(); | |||
GE_CHECK_NOTNULL(data_op_desc); | |||
int32_t index = -1; | |||
// when this function has been called, PARENT_INDEX has not been set | |||
if (!ge::AttrUtils::GetInt(data_op_desc, ge::ATTR_NAME_INDEX, index)) { | |||
REPORT_INNER_ERROR("E19999", "Get attr:index failed, op_name:%s", data_nodes[i]->GetName().c_str()); | |||
GELOGE(FAILED, "[Get][Attr] Get attr:index failed, op_name:%s", data_nodes[i]->GetName().c_str()); | |||
return FAILED; | |||
} | |||
GE_CHK_BOOL_RET_STATUS(static_cast<size_t>(index) < parent_op_desc->GetAllInputsSize(), PARAM_INVALID, | |||
"[Check][Index] failed, index=%d should less than %zu.", index, | |||
parent_op_desc->GetAllInputsSize()); | |||
// set data format by node input desc | |||
const auto input_format = parent_op_desc->GetInputDesc(index).GetFormat(); | |||
const auto input_original_format = parent_op_desc->GetInputDesc(index).GetOriginFormat(); | |||
const auto input_desc = data_op_desc->MutableInputDesc(0U); | |||
const auto output_desc = data_op_desc->MutableOutputDesc(0U); | |||
GE_CHECK_NOTNULL(input_desc); | |||
GE_CHECK_NOTNULL(output_desc); | |||
input_desc->SetFormat(input_format); | |||
input_desc->SetOriginFormat(input_original_format); | |||
output_desc->SetFormat(input_format); | |||
output_desc->SetOriginFormat(input_original_format); | |||
GELOGD("Set index %d of data[%zu], node:%s, format:%d->%d, original " | |||
"format:%d->%d, from parent node:%s, node_type:%s", | |||
index, i, data_nodes[i]->GetName().c_str(), | |||
static_cast<int32_t>(output_desc->GetFormat()), | |||
static_cast<int32_t>(input_format), | |||
static_cast<int32_t>(output_desc->GetOriginFormat()), | |||
static_cast<int32_t>(input_original_format), | |||
parent_node->GetName().c_str(), parent_node->GetType().c_str()); | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace ge | |||
namespace domi { | |||
@@ -20,11 +20,15 @@ | |||
#include <functional> | |||
#include "external/graph/graph.h" | |||
#include "external/register/register_error_codes.h" | |||
#include "graph/node.h" | |||
namespace ge { | |||
domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | |||
const ge::Graph &graph, | |||
const std::function<domi::Status(int data_index, int &parent_input_index)> &input, | |||
const std::function<domi::Status(int netoutput_index, int &parent_output_index)> &output); | |||
// only data node may set default NHWC/NCHW format by parser when call NpuOnnxGraphOp | |||
// this function should be called before AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo | |||
domi::Status AutoMappingSubgraphDataFormat(const NodePtr &parent_node, const ge::Graph &graph); | |||
} // namespace ge | |||
#endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ |
@@ -318,6 +318,12 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { | |||
} | |||
Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) { | |||
// Inner function, data format need be set by parant node | |||
GE_CHK_STATUS_RET(AutoMappingSubgraphDataFormat(node, graph), | |||
"[Call][AutoMappingSubgraphDataFormat] failed, node:%s, " | |||
"root graph:%s, graph:%s", | |||
node->GetName().c_str(), root_graph->GetName().c_str(), | |||
ParserUtils::GetGraphName(graph).c_str()); | |||
// Inner function, input params have been checked by caller | |||
Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | |||
graph, | |||
@@ -733,6 +733,17 @@ namespace { | |||
ge::GraphUtils::AddEdge(node_h->GetOutControlAnchor(), node_j->GetInControlAnchor()); | |||
} | |||
void MakeGraph(const ComputeGraphPtr &root_graph, const string &name) { | |||
root_graph->SetName(name); | |||
ge::NodePtr data1 = AddNode(root_graph, name + "_input1", parser::DATA, 1, 1); | |||
ge::NodePtr data2 = AddNode(root_graph, name + "_input2", parser::DATA, 1, 1); | |||
ge::NodePtr add = AddNode(root_graph, name + "_add", parser::ADD, 2, 1); | |||
ge::NodePtr net_output = AddNode(root_graph, name + "_net_output", parser::NETOUTPUT, 1, 1); | |||
ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), add->GetInDataAnchor(0)); | |||
ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), add->GetInDataAnchor(1)); | |||
ge::GraphUtils::AddEdge(add->GetOutDataAnchor(0), net_output->GetInDataAnchor(0)); | |||
} | |||
void ChangeDataType(tensorflow::NodeDef* node_tf, int32_t data_type) | |||
{ | |||
domi::tensorflow::AttrValue input_attr_value; | |||
@@ -1211,6 +1222,69 @@ TEST_F(STestTensorflowParser, parser_ConvertToGeDataType) | |||
ASSERT_EQ(dataType, ge::DataType::DT_UNDEFINED); | |||
} | |||
TEST_F(STestTensorflowParser, tensorflow_parser_with_external_normal_graph) { | |||
// 1. Create root graph | |||
ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("root_graph"); | |||
MakeGraph(root_graph, "root_graph"); | |||
// 2. Create ONNX sub graph | |||
// 2.1 Sub graph of onnx graph | |||
ge::ComputeGraphPtr sub_sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub"); | |||
// 2.2 ONNX graph | |||
ComputeGraphPtr sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub"); | |||
MakeGraph(sub_graph, "sub_graph"); | |||
auto add = sub_graph->FindNode("sub_graph_add"); | |||
ASSERT_NE(add, nullptr); | |||
add->GetOpDesc()->AddSubgraphName("sub_sub_graph"); | |||
add->GetOpDesc()->SetSubgraphInstanceName(0, sub_sub_graph->GetName()); | |||
sub_graph->AddSubGraph(sub_sub_graph); | |||
auto input1 = sub_graph->FindNode("sub_graph_input1"); | |||
ASSERT_NE(input1, nullptr); | |||
AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||
auto input2 = sub_graph->FindNode("sub_graph_input2"); | |||
ASSERT_NE(input2, nullptr); | |||
AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||
// 3. Serialize ONNX graph to string | |||
// 3.1 normal | |||
ge::Model model("model", ""); | |||
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); | |||
Buffer buffer; | |||
graphStatus save_ret = model.Save(buffer, false); | |||
ASSERT_EQ(save_ret, GRAPH_SUCCESS); | |||
std::string external_graph(reinterpret_cast<const char *>(buffer.GetData()), | |||
buffer.GetSize()); | |||
// model will failed | |||
input1->GetOpDesc()->DelAttr(ATTR_NAME_INDEX); | |||
ge::Model model_will_fail("model_will_fail", ""); | |||
model_will_fail.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); | |||
Buffer buffer_fail; | |||
save_ret = model_will_fail.Save(buffer_fail, false); | |||
ASSERT_EQ(save_ret, GRAPH_SUCCESS); | |||
std::string external_graph_fail( | |||
reinterpret_cast<const char *>(buffer_fail.GetData()), | |||
buffer_fail.GetSize()); | |||
// 4. Set string to function node | |||
auto root_add = root_graph->FindNode("root_graph_add"); | |||
ASSERT_NE(root_add, nullptr); | |||
AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph); | |||
auto root_input1 = root_graph->FindNode("root_graph_input1"); | |||
ASSERT_NE(root_input1, nullptr); | |||
AttrUtils::SetInt(root_input1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||
auto root_input2 = root_graph->FindNode("root_graph_input2"); | |||
ASSERT_NE(root_input2, nullptr); | |||
AttrUtils::SetInt(root_input2->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||
// 5. Run test (normal) | |||
auto ret = TensorFlowModelParser::AddExternalGraph(root_graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 2); | |||
EXPECT_EQ(sub_graph->GetAllSubgraphs().size(), 1); | |||
EXPECT_NE(root_graph->GetSubgraph(sub_graph->GetName()), nullptr); | |||
EXPECT_EQ(root_graph->GetSubgraph(sub_graph->GetName())->GetAllSubgraphs().size(), 0); | |||
} | |||
TEST_F(STestTensorflowParser, tensorflow_ParserProto_failed) | |||
{ | |||
std::string caseDir = __FILE__; | |||