diff --git a/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc b/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc index 7d3a754d..00157395 100644 --- a/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc @@ -1,125 +1,125 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -#include "framework/omg/omg_inner_types.h" -#include "graph/common/local_context.h" -#include "graph/passes/subgraph_const_migration_pass.h" -#include "inc/pass_manager.h" -#include "register/op_registry.h" - -namespace ge { -class UtestSubgraphConstMigrationPass : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} - - public: - NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { - GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); - auto op_desc = std::make_shared(name, type); - for (auto i = 0; i < in_num; ++i) { - op_desc->AddInputDesc(test_desc); - } - for (auto i = 0; i < out_num; ++i) { - op_desc->AddOutputDesc(test_desc); - } - if (type == "Const") { - uint64_t const_value = 101; - auto weight = make_shared(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); - AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); - } - return graph->AddNode(op_desc); - } - - void make_original_graph(const ComputeGraphPtr &graph) { - auto data = MakeNode(graph, 1, 1, "data", "Data"); - { - AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - } - auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); - { - auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); - GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); - } - - auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); - { - auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); - GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); - } - - auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); - GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); - } - - void make_multibatch_graph(const ComputeGraphPtr &graph) { - auto index = MakeNode(graph, 1, 1, "index", "Data"); - auto data = MakeNode(graph, 1, 1, "data", "Data"); - auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); - auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); - AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); - - auto case1 = MakeNode(graph, 4, 1, "case", "Case"); - GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); - GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); - GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); - GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); - auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); - GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); - - AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); - case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); - ComputeGraphPtr branch = std::make_shared("test_branch"); - make_original_graph(branch); - for (int i = 0; i < 2; ++i) { - std::string name("_ascend_mbatch_batch_" + std::to_string(i)); - std::vector input_nodes; - std::vector output_nodes; - ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); - - subgraph->SetName(name); - subgraph->SetParentNode(case1); - subgraph->SetParentGraph(graph); - graph->AddSubgraph(subgraph->GetName(), subgraph); - - case1->GetOpDesc()->AddSubgraphName(name); - case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); - } - } -}; - -TEST_F(UtestSubgraphConstMigrationPass, graph_nullptr) { - PassManager pass_manager; - pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); - ComputeGraphPtr graph = std::make_shared("test_graph"); - make_multibatch_graph(graph); - pass_manager.Run(graph); -} +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include "framework/omg/omg_inner_types.h" +#include "graph/common/local_context.h" +#include "graph/passes/subgraph_const_migration_pass.h" +#include "inc/pass_manager.h" +#include "register/op_registry.h" + +namespace ge { +class UtestSubgraphConstMigrationPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} + + public: + NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); + auto op_desc = std::make_shared(name, type); + for (auto i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(test_desc); + } + for (auto i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(test_desc); + } + if (type == "Const") { + uint64_t const_value = 101; + auto weight = make_shared(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); + AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); + } + return graph->AddNode(op_desc); + } + + void make_original_graph(const ComputeGraphPtr &graph) { + auto data = MakeNode(graph, 1, 1, "data", "Data"); + { + AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); + AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); + } + auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); + { + auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); + GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); + } + + auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); + { + auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); + AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); + AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); + GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); + } + + auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); + GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); + GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); + } + + void make_multibatch_graph(const ComputeGraphPtr &graph) { + auto index = MakeNode(graph, 1, 1, "index", "Data"); + auto data = MakeNode(graph, 1, 1, "data", "Data"); + auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); + auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); + AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); + AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); + + auto case1 = MakeNode(graph, 4, 1, "case", "Case"); + GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); + auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); + GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); + + AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); + case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); + ComputeGraphPtr branch = std::make_shared("test_branch"); + make_original_graph(branch); + for (int i = 0; i < 2; ++i) { + std::string name("_ascend_mbatch_batch_" + std::to_string(i)); + std::vector input_nodes; + std::vector output_nodes; + ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); + + subgraph->SetName(name); + subgraph->SetParentNode(case1); + subgraph->SetParentGraph(graph); + graph->AddSubgraph(subgraph->GetName(), subgraph); + + case1->GetOpDesc()->AddSubgraphName(name); + case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); + } + } +}; + +TEST_F(UtestSubgraphConstMigrationPass, subgraph_const_migration) { + PassManager pass_manager; + pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); + ComputeGraphPtr graph = std::make_shared("test_graph"); + make_multibatch_graph(graph); + EXPECT_EQ(pass_manager.Run(graph), SUCCESS); +} } // namespace ge \ No newline at end of file