Browse Source

fix format

tags/v1.3.0
wjm 4 years ago
parent
commit
0b64d918e5
1 changed files with 124 additions and 124 deletions
  1. +124
    -124
      tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc

+ 124
- 124
tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc View File

@@ -1,125 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <set>
#include <string>
#include "framework/omg/omg_inner_types.h"
#include "graph/common/local_context.h"
#include "graph/passes/subgraph_const_migration_pass.h"
#include "inc/pass_manager.h"
#include "register/op_registry.h"
namespace ge {
class UtestSubgraphConstMigrationPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
public:
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
auto op_desc = std::make_shared<OpDesc>(name, type);
for (auto i = 0; i < in_num; ++i) {
op_desc->AddInputDesc(test_desc);
}
for (auto i = 0; i < out_num; ++i) {
op_desc->AddOutputDesc(test_desc);
}
if (type == "Const") {
uint64_t const_value = 101;
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight);
}
return graph->AddNode(op_desc);
}
void make_original_graph(const ComputeGraphPtr &graph) {
auto data = MakeNode(graph, 1, 1, "data", "Data");
{
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
}
auto const1 = MakeNode(graph, 0, 1, "const1", "Const");
{
auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2);
GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor());
}
auto const2 = MakeNode(graph, 0, 1, "const2", "Const");
{
auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3);
GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor());
}
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
}
void make_multibatch_graph(const ComputeGraphPtr &graph) {
auto index = MakeNode(graph, 1, 1, "index", "Data");
auto data = MakeNode(graph, 1, 1, "data", "Data");
auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
auto case1 = MakeNode(graph, 4, 1, "case", "Case");
GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0));
GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3));
auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput");
GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2);
case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic);
ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch");
make_original_graph(branch);
for (int i = 0; i < 2; ++i) {
std::string name("_ascend_mbatch_batch_" + std::to_string(i));
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes);
subgraph->SetName(name);
subgraph->SetParentNode(case1);
subgraph->SetParentGraph(graph);
graph->AddSubgraph(subgraph->GetName(), subgraph);
case1->GetOpDesc()->AddSubgraphName(name);
case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName());
}
}
};
TEST_F(UtestSubgraphConstMigrationPass, graph_nullptr) {
PassManager pass_manager;
pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass);
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
make_multibatch_graph(graph);
pass_manager.Run(graph);
}
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <set>
#include <string>
#include "framework/omg/omg_inner_types.h"
#include "graph/common/local_context.h"
#include "graph/passes/subgraph_const_migration_pass.h"
#include "inc/pass_manager.h"
#include "register/op_registry.h"
namespace ge {
class UtestSubgraphConstMigrationPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
public:
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
auto op_desc = std::make_shared<OpDesc>(name, type);
for (auto i = 0; i < in_num; ++i) {
op_desc->AddInputDesc(test_desc);
}
for (auto i = 0; i < out_num; ++i) {
op_desc->AddOutputDesc(test_desc);
}
if (type == "Const") {
uint64_t const_value = 101;
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight);
}
return graph->AddNode(op_desc);
}
void make_original_graph(const ComputeGraphPtr &graph) {
auto data = MakeNode(graph, 1, 1, "data", "Data");
{
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
}
auto const1 = MakeNode(graph, 0, 1, "const1", "Const");
{
auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2);
GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor());
}
auto const2 = MakeNode(graph, 0, 1, "const2", "Const");
{
auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3);
GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor());
}
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
}
void make_multibatch_graph(const ComputeGraphPtr &graph) {
auto index = MakeNode(graph, 1, 1, "index", "Data");
auto data = MakeNode(graph, 1, 1, "data", "Data");
auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
auto case1 = MakeNode(graph, 4, 1, "case", "Case");
GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0));
GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3));
auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput");
GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2);
case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic);
ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch");
make_original_graph(branch);
for (int i = 0; i < 2; ++i) {
std::string name("_ascend_mbatch_batch_" + std::to_string(i));
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes);
subgraph->SetName(name);
subgraph->SetParentNode(case1);
subgraph->SetParentGraph(graph);
graph->AddSubgraph(subgraph->GetName(), subgraph);
case1->GetOpDesc()->AddSubgraphName(name);
case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName());
}
}
};
TEST_F(UtestSubgraphConstMigrationPass, subgraph_const_migration) {
PassManager pass_manager;
pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass);
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
make_multibatch_graph(graph);
EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save