Browse Source

!622 sync ge_dev to master 20220816

Merge pull request !622 from lipeiyang/ge_dev
pull/625/MERGE
lipeiyang Gitee 2 years ago
parent
commit
dac35179a8
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 128 additions and 0 deletions
  1. +44
    -0
      parser/common/auto_mapping_subgraph_io_index_func.cc
  2. +4
    -0
      parser/common/auto_mapping_subgraph_io_index_func.h
  3. +6
    -0
      parser/tensorflow/tensorflow_parser.cc
  4. +74
    -0
      tests/st/testcase/test_tensorflow_parser.cc

+ 44
- 0
parser/common/auto_mapping_subgraph_io_index_func.cc View File

@@ -21,6 +21,7 @@
#include "graph/op_desc.h" #include "graph/op_desc.h"
#include "graph/utils/attr_utils.h" #include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_op_types.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h" #include "graph/utils/node_utils.h"
#include "register/register_fmk_types.h" #include "register/register_fmk_types.h"
@@ -143,6 +144,49 @@ Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(


return SUCCESS; return SUCCESS;
} }

domi::Status AutoMappingSubgraphDataFormat(const NodePtr &parent_node, const ge::Graph &graph) {
GE_CHECK_NOTNULL(parent_node);
const auto &parent_op_desc = parent_node->GetOpDesc();
GE_CHECK_NOTNULL(parent_op_desc);
const auto &compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
const auto data_nodes = FindNodesByType(compute_graph, DATA);
for (size_t i = 0U; i < data_nodes.size(); ++i) {
const auto &data_op_desc = data_nodes[i]->GetOpDesc();
GE_CHECK_NOTNULL(data_op_desc);
int32_t index = -1;
// when this function has been called, PARENT_INDEX has not been set
if (!ge::AttrUtils::GetInt(data_op_desc, ge::ATTR_NAME_INDEX, index)) {
REPORT_INNER_ERROR("E19999", "Get attr:index failed, op_name:%s", data_nodes[i]->GetName().c_str());
GELOGE(FAILED, "[Get][Attr] Get attr:index failed, op_name:%s", data_nodes[i]->GetName().c_str());
return FAILED;
}
GE_CHK_BOOL_RET_STATUS(static_cast<size_t>(index) < parent_op_desc->GetAllInputsSize(), PARAM_INVALID,
"[Check][Index] failed, index=%d should less than %zu.", index,
parent_op_desc->GetAllInputsSize());
// set data format by node input desc
const auto input_format = parent_op_desc->GetInputDesc(index).GetFormat();
const auto input_original_format = parent_op_desc->GetInputDesc(index).GetOriginFormat();
const auto input_desc = data_op_desc->MutableInputDesc(0U);
const auto output_desc = data_op_desc->MutableOutputDesc(0U);
GE_CHECK_NOTNULL(input_desc);
GE_CHECK_NOTNULL(output_desc);
input_desc->SetFormat(input_format);
input_desc->SetOriginFormat(input_original_format);
output_desc->SetFormat(input_format);
output_desc->SetOriginFormat(input_original_format);
GELOGD("Set index %d of data[%zu], node:%s, format:%d->%d, original "
"format:%d->%d, from parent node:%s, node_type:%s",
index, i, data_nodes[i]->GetName().c_str(),
static_cast<int32_t>(output_desc->GetFormat()),
static_cast<int32_t>(input_format),
static_cast<int32_t>(output_desc->GetOriginFormat()),
static_cast<int32_t>(input_original_format),
parent_node->GetName().c_str(), parent_node->GetType().c_str());
}
return SUCCESS;
}
} // namespace ge } // namespace ge


namespace domi { namespace domi {


+ 4
- 0
parser/common/auto_mapping_subgraph_io_index_func.h View File

@@ -20,11 +20,15 @@
#include <functional> #include <functional>
#include "external/graph/graph.h" #include "external/graph/graph.h"
#include "external/register/register_error_codes.h" #include "external/register/register_error_codes.h"
#include "graph/node.h"


namespace ge { namespace ge {
domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( domi::Status AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(
const ge::Graph &graph, const ge::Graph &graph,
const std::function<domi::Status(int data_index, int &parent_input_index)> &input, const std::function<domi::Status(int data_index, int &parent_input_index)> &input,
const std::function<domi::Status(int netoutput_index, int &parent_output_index)> &output); const std::function<domi::Status(int netoutput_index, int &parent_output_index)> &output);
// only data node may set default NHWC/NCHW format by parser when call NpuOnnxGraphOp
// this function should be called before AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo
domi::Status AutoMappingSubgraphDataFormat(const NodePtr &parent_node, const ge::Graph &graph);
} // namespace ge } // namespace ge
#endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_ #endif // PARSER_COMMON_AUTO_MAPPING_SUBGRAPH_IO_INDEX_FUNC_H_

+ 6
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -318,6 +318,12 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) {
} }


Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) { Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) {
// Inner function, data format need be set by parant node
GE_CHK_STATUS_RET(AutoMappingSubgraphDataFormat(node, graph),
"[Call][AutoMappingSubgraphDataFormat] failed, node:%s, "
"root graph:%s, graph:%s",
node->GetName().c_str(), root_graph->GetName().c_str(),
ParserUtils::GetGraphName(graph).c_str());
// Inner function, input params have been checked by caller // Inner function, input params have been checked by caller
Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(
graph, graph,


+ 74
- 0
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -733,6 +733,17 @@ namespace {
ge::GraphUtils::AddEdge(node_h->GetOutControlAnchor(), node_j->GetInControlAnchor()); ge::GraphUtils::AddEdge(node_h->GetOutControlAnchor(), node_j->GetInControlAnchor());
} }


void MakeGraph(const ComputeGraphPtr &root_graph, const string &name) {
root_graph->SetName(name);
ge::NodePtr data1 = AddNode(root_graph, name + "_input1", parser::DATA, 1, 1);
ge::NodePtr data2 = AddNode(root_graph, name + "_input2", parser::DATA, 1, 1);
ge::NodePtr add = AddNode(root_graph, name + "_add", parser::ADD, 2, 1);
ge::NodePtr net_output = AddNode(root_graph, name + "_net_output", parser::NETOUTPUT, 1, 1);
ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), add->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), add->GetInDataAnchor(1));
ge::GraphUtils::AddEdge(add->GetOutDataAnchor(0), net_output->GetInDataAnchor(0));
}

void ChangeDataType(tensorflow::NodeDef* node_tf, int32_t data_type) void ChangeDataType(tensorflow::NodeDef* node_tf, int32_t data_type)
{ {
domi::tensorflow::AttrValue input_attr_value; domi::tensorflow::AttrValue input_attr_value;
@@ -1211,6 +1222,69 @@ TEST_F(STestTensorflowParser, parser_ConvertToGeDataType)
ASSERT_EQ(dataType, ge::DataType::DT_UNDEFINED); ASSERT_EQ(dataType, ge::DataType::DT_UNDEFINED);
} }


TEST_F(STestTensorflowParser, tensorflow_parser_with_external_normal_graph) {
// 1. Create root graph
ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("root_graph");
MakeGraph(root_graph, "root_graph");

// 2. Create ONNX sub graph
// 2.1 Sub graph of onnx graph
ge::ComputeGraphPtr sub_sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub");
// 2.2 ONNX graph
ComputeGraphPtr sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub");
MakeGraph(sub_graph, "sub_graph");
auto add = sub_graph->FindNode("sub_graph_add");
ASSERT_NE(add, nullptr);
add->GetOpDesc()->AddSubgraphName("sub_sub_graph");
add->GetOpDesc()->SetSubgraphInstanceName(0, sub_sub_graph->GetName());
sub_graph->AddSubGraph(sub_sub_graph);
auto input1 = sub_graph->FindNode("sub_graph_input1");
ASSERT_NE(input1, nullptr);
AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_INDEX, 0);
auto input2 = sub_graph->FindNode("sub_graph_input2");
ASSERT_NE(input2, nullptr);
AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_INDEX, 1);

// 3. Serialize ONNX graph to string
// 3.1 normal
ge::Model model("model", "");
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph));
Buffer buffer;
graphStatus save_ret = model.Save(buffer, false);
ASSERT_EQ(save_ret, GRAPH_SUCCESS);
std::string external_graph(reinterpret_cast<const char *>(buffer.GetData()),
buffer.GetSize());
// model will failed
input1->GetOpDesc()->DelAttr(ATTR_NAME_INDEX);
ge::Model model_will_fail("model_will_fail", "");
model_will_fail.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph));
Buffer buffer_fail;
save_ret = model_will_fail.Save(buffer_fail, false);
ASSERT_EQ(save_ret, GRAPH_SUCCESS);
std::string external_graph_fail(
reinterpret_cast<const char *>(buffer_fail.GetData()),
buffer_fail.GetSize());

// 4. Set string to function node
auto root_add = root_graph->FindNode("root_graph_add");
ASSERT_NE(root_add, nullptr);
AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph);
auto root_input1 = root_graph->FindNode("root_graph_input1");
ASSERT_NE(root_input1, nullptr);
AttrUtils::SetInt(root_input1->GetOpDesc(), ATTR_NAME_INDEX, 0);
auto root_input2 = root_graph->FindNode("root_graph_input2");
ASSERT_NE(root_input2, nullptr);
AttrUtils::SetInt(root_input2->GetOpDesc(), ATTR_NAME_INDEX, 1);

// 5. Run test (normal)
auto ret = TensorFlowModelParser::AddExternalGraph(root_graph);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 2);
EXPECT_EQ(sub_graph->GetAllSubgraphs().size(), 1);
EXPECT_NE(root_graph->GetSubgraph(sub_graph->GetName()), nullptr);
EXPECT_EQ(root_graph->GetSubgraph(sub_graph->GetName())->GetAllSubgraphs().size(), 0);
}

TEST_F(STestTensorflowParser, tensorflow_ParserProto_failed) TEST_F(STestTensorflowParser, tensorflow_ParserProto_failed)
{ {
std::string caseDir = __FILE__; std::string caseDir = __FILE__;


Loading…
Cancel
Save