/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/pass_utils.h"

// NOTE(review): the two angle-bracket include targets were empty in the reviewed text.
// They are restored as gtest (testing::Test / TEST_F are used below) and <vector>
// (std::vector is used below) - confirm against the upstream file.
#include <gtest/gtest.h>
#include <vector>

#include "common/types.h"
#include "graph/types.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph_builder_utils.h"
#include "inc/kernel.h"
#include "inc/kernel_factory.h"

using namespace ge;

// Fixture for the PassUtils tests; it carries no per-test state.
class UtestGraphPassesPassUtils : public testing::Test {
 protected:
  void SetUp() override {}
  void TearDown() override {}
};

// Small fluent helper that assembles an op description (name, type, input/output tensor
// descriptions) and then adds it to a graph as a node.
//
// NOTE(review): the template arguments in this file were missing from the reviewed text.
// They are restored from the declared pointer types (OpDescPtr, GeTensorDescPtr,
// ComputeGraphPtr, GeTensorPtr, NodePtr); the shape element type is assumed to be
// int64_t - confirm against GeShape's constructor.
class NodeBuilder {
 public:
  // Creates the underlying op description with the given node name and op type.
  NodeBuilder(const std::string &name, const std::string &type) : op_desc_(std::make_shared<OpDesc>(name, type)) {}

  // Appends one input tensor description; returns *this so calls can be chained.
  NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
                            ge::DataType data_type = DT_FLOAT) {
    op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
    return *this;
  }

  // Appends one output tensor description; returns *this so calls can be chained.
  NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
                             ge::DataType data_type = DT_FLOAT) {
    op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
    return *this;
  }

  // Adds the assembled op description to `graph` and returns the resulting node.
  ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }

 private:
  // Builds a tensor description carrying the requested shape, format and data type.
  ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
                                       ge::DataType data_type = DT_FLOAT) {
    GeShape ge_shape{std::vector<int64_t>(shape)};
    ge::GeTensorDescPtr tensor_desc = std::make_shared<GeTensorDesc>();
    tensor_desc->SetShape(ge_shape);
    tensor_desc->SetFormat(format);
    tensor_desc->SetDataType(data_type);
    return tensor_desc;
  }

  ge::OpDescPtr op_desc_;
};

// Builds a six-node graph (data, const, relu1, sinh, relu2, sinh2), attaches a weight to
// the constant node and expects SetOutNodeWeight to succeed for every node that has at
// least one downstream data node.
TEST_F(UtestGraphPassesPassUtils, set_out_node_weight) {
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // data
  ge::NodePtr node_data = NodeBuilder("data", DATA).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  // const
  ge::NodePtr node_const =
      NodeBuilder("const", CONSTANT).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  // relu
  ge::NodePtr node_relu = NodeBuilder("node_relu1", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  // sinh
  ge::NodePtr node_sinh = NodeBuilder("node_sinh", SINH)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  // relu
  ge::NodePtr node_relu2 = NodeBuilder("node_relu2", RELU)
                               .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .Build(graph);
  // sinh
  ge::NodePtr node_sinh2 = NodeBuilder("node_sinh2", SINH)
                               .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .Build(graph);

  // add edge
  ge::GraphUtils::AddEdge(node_data->GetOutControlAnchor(), node_const->GetInControlAnchor());
  ge::GraphUtils::AddEdge(node_const->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_sinh->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_relu2->GetInControlAnchor());
  ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_sinh2->GetInDataAnchor(0));

  for (auto node : graph->GetDirectNode()) {
    if (node->GetType() == CONSTANT) {
      // Give the constant node a single-element INT32 weight.
      int32_t weight[] = {1};
      GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
      GeTensorPtr tensor =
          std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
      std::vector<GeTensorPtr> tensor_vec = {tensor};
      OpDescUtils::SetWeights(node, tensor_vec);
    }
    if (!node->GetOutDataNodes().empty()) {
      // The anchor handed over is output 0 of the first downstream data node.
      auto out_data_anchor = node->GetOutDataNodes().at(0)->GetOutDataAnchor(0);
      Status status = PassUtils::SetOutNodeWeight(out_data_anchor, node);
      EXPECT_EQ(SUCCESS, status);
    }
  }
}

// only some failure cases for coverage check
TEST_F(UtestGraphPassesPassUtils, is_constant_null) {
  ge::NodePtr node = nullptr;
  bool ret = PassUtils::IsConstant(node);
  EXPECT_FALSE(ret);
}

// GetInDataNode is expected to return nullptr both for a null node and for input index 1
// on a node that was built with a single input description.
TEST_F(UtestGraphPassesPassUtils, get_in_data_node_fail) {
  ge::NodePtr node = nullptr;
  NodePtr in_data_node = PassUtils::GetInDataNode(node, 0);
  EXPECT_EQ(nullptr, in_data_node);

  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // relu
  ge::NodePtr node_relu = NodeBuilder("relu", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  NodePtr data_node = PassUtils::GetInDataNode(node_relu, 1);
  EXPECT_EQ(nullptr, data_node);
}

// GetUniqueInDataAnchorIndex is expected to return -1 both for a null node and for a
// node that has no edges connected to it.
TEST_F(UtestGraphPassesPassUtils, get_unique_in_data_anchor_index_failed) {
  int invalid_index = -1;
  ge::NodePtr node = nullptr;
  int status = PassUtils::GetUniqueInDataAnchorIndex(node);
  EXPECT_EQ(invalid_index, status);

  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // relu
  ge::NodePtr node_relu = NodeBuilder("relu", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  int ret = PassUtils::GetUniqueInDataAnchorIndex(node_relu);
  EXPECT_EQ(invalid_index, ret);
}

// On a node without any connected edges, UnlinkNodeWithControlCopy is expected to report
// SUCCESS for index 1 and FAILED for index 0.
TEST_F(UtestGraphPassesPassUtils, unlink_node_with_ctrl_copy_fail) {
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // relu
  ge::NodePtr node_relu = NodeBuilder("relu", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  Status status = PassUtils::UnlinkNodeWithControlCopy(node_relu, 1);
  EXPECT_EQ(ge::SUCCESS, status);
  Status ret = PassUtils::UnlinkNodeWithControlCopy(node_relu, 0);
  EXPECT_EQ(ge::FAILED, ret);
}

// RemoveInactiveBranchToMerge is expected to return a non-zero status for a null anchor.
TEST_F(UtestGraphPassesPassUtils, null_input) {
  // NOTE(review): element type assumed to be NodePtr - confirm against the
  // RemoveInactiveBranchToMerge signature.
  std::vector<NodePtr> deleted_nodes;
  std::vector<NodePtr> end_nodes;
  EXPECT_NE(PassUtils::RemoveInactiveBranchToMerge(nullptr, deleted_nodes, end_nodes), 0);
}