|
|
@@ -1,125 +1,125 @@ |
|
|
|
/**
|
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
|
|
|
#include <set>
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
#include "framework/omg/omg_inner_types.h"
|
|
|
|
#include "graph/common/local_context.h"
|
|
|
|
#include "graph/passes/subgraph_const_migration_pass.h"
|
|
|
|
#include "inc/pass_manager.h"
|
|
|
|
#include "register/op_registry.h"
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
class UtestSubgraphConstMigrationPass : public testing::Test {
|
|
|
|
protected:
|
|
|
|
void SetUp() {}
|
|
|
|
void TearDown() {}
|
|
|
|
|
|
|
|
public:
|
|
|
|
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
|
|
|
|
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
|
|
|
|
auto op_desc = std::make_shared<OpDesc>(name, type);
|
|
|
|
for (auto i = 0; i < in_num; ++i) {
|
|
|
|
op_desc->AddInputDesc(test_desc);
|
|
|
|
}
|
|
|
|
for (auto i = 0; i < out_num; ++i) {
|
|
|
|
op_desc->AddOutputDesc(test_desc);
|
|
|
|
}
|
|
|
|
if (type == "Const") {
|
|
|
|
uint64_t const_value = 101;
|
|
|
|
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
|
|
|
|
AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight);
|
|
|
|
}
|
|
|
|
return graph->AddNode(op_desc);
|
|
|
|
}
|
|
|
|
|
|
|
|
void make_original_graph(const ComputeGraphPtr &graph) {
|
|
|
|
auto data = MakeNode(graph, 1, 1, "data", "Data");
|
|
|
|
{
|
|
|
|
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
|
|
|
|
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
|
|
|
|
}
|
|
|
|
auto const1 = MakeNode(graph, 0, 1, "const1", "Const");
|
|
|
|
{
|
|
|
|
auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
|
|
|
|
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
|
|
|
|
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2);
|
|
|
|
GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor());
|
|
|
|
}
|
|
|
|
|
|
|
|
auto const2 = MakeNode(graph, 0, 1, "const2", "Const");
|
|
|
|
{
|
|
|
|
auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
|
|
|
|
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
|
|
|
|
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3);
|
|
|
|
GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor());
|
|
|
|
}
|
|
|
|
|
|
|
|
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
|
|
|
|
GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
|
|
|
|
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
|
|
|
|
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
|
|
|
|
}
|
|
|
|
|
|
|
|
void make_multibatch_graph(const ComputeGraphPtr &graph) {
|
|
|
|
auto index = MakeNode(graph, 1, 1, "index", "Data");
|
|
|
|
auto data = MakeNode(graph, 1, 1, "data", "Data");
|
|
|
|
auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
|
|
|
|
auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
|
|
|
|
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
|
|
|
|
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
|
|
|
|
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
|
|
|
|
|
|
|
|
auto case1 = MakeNode(graph, 4, 1, "case", "Case");
|
|
|
|
GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0));
|
|
|
|
GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1));
|
|
|
|
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2));
|
|
|
|
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3));
|
|
|
|
auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput");
|
|
|
|
GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
|
|
|
|
|
|
|
|
AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2);
|
|
|
|
case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic);
|
|
|
|
ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch");
|
|
|
|
make_original_graph(branch);
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
std::string name("_ascend_mbatch_batch_" + std::to_string(i));
|
|
|
|
std::vector<NodePtr> input_nodes;
|
|
|
|
std::vector<NodePtr> output_nodes;
|
|
|
|
ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes);
|
|
|
|
|
|
|
|
subgraph->SetName(name);
|
|
|
|
subgraph->SetParentNode(case1);
|
|
|
|
subgraph->SetParentGraph(graph);
|
|
|
|
graph->AddSubgraph(subgraph->GetName(), subgraph);
|
|
|
|
|
|
|
|
case1->GetOpDesc()->AddSubgraphName(name);
|
|
|
|
case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
TEST_F(UtestSubgraphConstMigrationPass, graph_nullptr) {
|
|
|
|
PassManager pass_manager;
|
|
|
|
pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass);
|
|
|
|
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
|
|
|
|
make_multibatch_graph(graph);
|
|
|
|
pass_manager.Run(graph);
|
|
|
|
}
|
|
|
|
/** |
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <gtest/gtest.h> |
|
|
|
|
|
|
|
#include <set> |
|
|
|
#include <string> |
|
|
|
|
|
|
|
#include "framework/omg/omg_inner_types.h" |
|
|
|
#include "graph/common/local_context.h" |
|
|
|
#include "graph/passes/subgraph_const_migration_pass.h" |
|
|
|
#include "inc/pass_manager.h" |
|
|
|
#include "register/op_registry.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
class UtestSubgraphConstMigrationPass : public testing::Test { |
|
|
|
protected: |
|
|
|
void SetUp() {} |
|
|
|
void TearDown() {} |
|
|
|
|
|
|
|
public: |
|
|
|
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { |
|
|
|
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); |
|
|
|
auto op_desc = std::make_shared<OpDesc>(name, type); |
|
|
|
for (auto i = 0; i < in_num; ++i) { |
|
|
|
op_desc->AddInputDesc(test_desc); |
|
|
|
} |
|
|
|
for (auto i = 0; i < out_num; ++i) { |
|
|
|
op_desc->AddOutputDesc(test_desc); |
|
|
|
} |
|
|
|
if (type == "Const") { |
|
|
|
uint64_t const_value = 101; |
|
|
|
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); |
|
|
|
AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); |
|
|
|
} |
|
|
|
return graph->AddNode(op_desc); |
|
|
|
} |
|
|
|
|
|
|
|
void make_original_graph(const ComputeGraphPtr &graph) { |
|
|
|
auto data = MakeNode(graph, 1, 1, "data", "Data"); |
|
|
|
{ |
|
|
|
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); |
|
|
|
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); |
|
|
|
} |
|
|
|
auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); |
|
|
|
{ |
|
|
|
auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); |
|
|
|
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); |
|
|
|
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); |
|
|
|
GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); |
|
|
|
} |
|
|
|
|
|
|
|
auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); |
|
|
|
{ |
|
|
|
auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); |
|
|
|
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); |
|
|
|
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); |
|
|
|
GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); |
|
|
|
} |
|
|
|
|
|
|
|
auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); |
|
|
|
GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); |
|
|
|
GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); |
|
|
|
GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); |
|
|
|
} |
|
|
|
|
|
|
|
void make_multibatch_graph(const ComputeGraphPtr &graph) { |
|
|
|
auto index = MakeNode(graph, 1, 1, "index", "Data"); |
|
|
|
auto data = MakeNode(graph, 1, 1, "data", "Data"); |
|
|
|
auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); |
|
|
|
auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); |
|
|
|
AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); |
|
|
|
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); |
|
|
|
AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); |
|
|
|
|
|
|
|
auto case1 = MakeNode(graph, 4, 1, "case", "Case"); |
|
|
|
GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); |
|
|
|
GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); |
|
|
|
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); |
|
|
|
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); |
|
|
|
auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); |
|
|
|
GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); |
|
|
|
|
|
|
|
AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); |
|
|
|
case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); |
|
|
|
ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch"); |
|
|
|
make_original_graph(branch); |
|
|
|
for (int i = 0; i < 2; ++i) { |
|
|
|
std::string name("_ascend_mbatch_batch_" + std::to_string(i)); |
|
|
|
std::vector<NodePtr> input_nodes; |
|
|
|
std::vector<NodePtr> output_nodes; |
|
|
|
ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); |
|
|
|
|
|
|
|
subgraph->SetName(name); |
|
|
|
subgraph->SetParentNode(case1); |
|
|
|
subgraph->SetParentGraph(graph); |
|
|
|
graph->AddSubgraph(subgraph->GetName(), subgraph); |
|
|
|
|
|
|
|
case1->GetOpDesc()->AddSubgraphName(name); |
|
|
|
case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
TEST_F(UtestSubgraphConstMigrationPass, subgraph_const_migration) { |
|
|
|
PassManager pass_manager; |
|
|
|
pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); |
|
|
|
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph"); |
|
|
|
make_multibatch_graph(graph); |
|
|
|
EXPECT_EQ(pass_manager.Run(graph), SUCCESS); |
|
|
|
} |
|
|
|
} // namespace ge |