diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index a09d5789..5456e151 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -688,6 +688,7 @@ set(PASS_TEST_FILES "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" + "graph/passes/replace_with_empty_const_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc b/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc new file mode 100644 index 00000000..078d8dbc --- /dev/null +++ b/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/replace_with_empty_const_pass.h" + +#include +#include +#include + +#include "graph_builder_utils.h" + +namespace ge { +class UtestReplaceWithEmptyConstPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +namespace { +/// data1 const1 +/// \ / +/// add1 +/// | +/// cast1(empty) +/// | +/// conv2d +ut::GraphBuilder Graph1Builder() { + ut::GraphBuilder builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto add1 = builder.AddNode("add1", "Add", 2, 1); + auto cast1 = builder.AddNode("cast1", "Cast", 1, 1); + auto conv2d = builder.AddNode("conv2d", "Conv2D", 1, 0); + + add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW)); + add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW)); + add1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW)); + cast1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW)); + GeTensorDesc empty_tensor(GeShape({1,0,8,8}),FORMAT_NCHW); + cast1->GetOpDesc()->UpdateOutputDesc(0,empty_tensor); + + builder.AddDataEdge(data1, 0, add1, 0); + builder.AddDataEdge(const1, 0, add1, 1); + builder.AddDataEdge(add1, 0, cast1, 0); + builder.AddDataEdge(cast1, 0, conv2d, 0); + return builder; +} +} // namespace + + +TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_const_success) { + auto builder = Graph1Builder(); + auto graph = builder.GetGraph(); + graph->SetSessionID(0); + ReplaceWithEmptyConstPass replace_with_empty_const_pass; + + EXPECT_EQ(graph->GetDirectNodesSize(),5); + // run pass on add1, graph still has 5 nodes + auto add1 = graph->FindNode("add1"); + Status ret = replace_with_empty_const_pass.Run(add1); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(graph->GetDirectNodesSize(),5); + + auto cast1 = graph->FindNode("cast1"); + ret = replace_with_empty_const_pass.Run(cast1) + EXPECT_EQ(cast1->GetOutAllNodes().size(),0); + auto conv2d = graph->FindNode("conv2d"); + EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const"); +} +} // namespace ge