/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/dimension_compute_pass.h"

// NOTE(review): the three system header names were missing from the flattened
// source (bare `#include` directives). They are restored here from the
// identifiers this file uses (testing::Test / TEST_F, std::vector) — confirm.
#include <gtest/gtest.h>
#include <string>
#include <vector>

#include "common/types.h"
#include "graph/passes/base_pass.h"
#include "graph_builder_utils.h"
#include "inc/kernel.h"
#include "inc/kernel_factory.h"

namespace ge {
namespace {
// Op type strings used by the graphs below. In this file only `ShapeYes` gets
// a kernel registered (see REGISTER_KERNEL further down).
// NOTE(review): the Yes/No suffix presumably marks whether a kernel is
// expected to exist for the type — confirm.
const char *AddNYes = "AddNYes";
const char *AddNNo = "AddNNo";
const char *HuberLossYes = "HuberLossYes";
const char *ShapeNo = "ShapeNo";
const char *ShapeYes = "ShapeYes";
const char *DataNo = "dataNo";
}  // namespace

/// Test kernel registered for the `ShapeYes` op type.
///
/// Compute() ignores the node and always appends one tensor to `v_output`:
/// shape {3}, data type DT_UINT8, payload {1, 2, 3}.
class UtestShapeYesKernel : public Kernel {
 public:
  /// @param node      unused.
  /// @param v_output  receives exactly one newly created tensor.
  /// @return always SUCCESS.
  // NOTE(review): the template arguments below (GeTensorPtr, GeTensor, uint8_t,
  // int64_t) were missing from the flattened source and are inferred from the
  // calls made on the values (MutableTensorDesc/SetData/GeShape, DT_UINT8) —
  // confirm against inc/kernel.h.
  Status Compute(const NodePtr &node, std::vector<GeTensorPtr> &v_output) override {
    auto output = std::make_shared<GeTensor>();
    std::vector<uint8_t> data{1, 2, 3};
    std::vector<int64_t> shape{3};
    output->MutableTensorDesc().SetShape(GeShape(shape));
    output->SetData(data);
    output->MutableTensorDesc().SetDataType(DT_UINT8);
    v_output.push_back(output);
    return SUCCESS;
  }
};
REGISTER_KERNEL(ShapeYes, UtestShapeYesKernel);

/// Fixture for the DimensionComputePass tests.
/// The fixture name says "DimensionAdjustPass" although every test runs
/// DimensionComputePass; the name is left untouched so existing test IDs and
/// filters keep working.
class UtestGraphPassesDimensionAdjustPass : public testing::Test {
 protected:
  UtestGraphPassesDimensionAdjustPass() = default;
};

namespace {
/// Five-node graph: two constants feeding an AddNNo node, followed by a
/// ShapeNo node.
///
///        netoutput
///            |
///     shape1 (ShapeNo)
///            |
///      addn1 (AddNNo)
///         /     \.
///        /       \.
///    const1     const2
ComputeGraphPtr BuildGraph8() {
  auto builder = ut::GraphBuilder("test");
  auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  auto addn1 = builder.AddNode("addn1", AddNNo, 2, 1);
  auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

  builder.AddDataEdge(const1, 0, addn1, 0);
  builder.AddDataEdge(const2, 0, addn1, 1);
  builder.AddDataEdge(addn1, 0, shape1, 0);
  builder.AddDataEdge(shape1, 0, netoutput1, 0);
  return builder.GetGraph();
}

/// Five-node graph: a constant and a `dataNo` node feeding an AddNYes node,
/// followed by a ShapeNo node.
///
///        netoutput
///            |
///     shape1 (ShapeNo)
///            |
///      addn1 (AddNYes)
///         /     \.
///        /       \.
///    const1     data1 (dataNo)
ComputeGraphPtr BuildGraph9() {
  auto builder = ut::GraphBuilder("test");
  auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

  builder.AddDataEdge(const1, 0, addn1, 0);
  builder.AddDataEdge(data1, 0, addn1, 1);
  builder.AddDataEdge(addn1, 0, shape1, 0);
  builder.AddDataEdge(shape1, 0, netoutput1, 0);
  return builder.GetGraph();
}

/// Three-node chain whose middle node has the registered `ShapeYes` type.
/// addnNo1 is declared with two inputs, but neither is connected.
///
///       netoutput1
///            |
///    shapeYes1 (ShapeYes)
///            |
///     addnNo1 (AddNNo)
ComputeGraphPtr BuildGraph1() {
  auto builder = ut::GraphBuilder("test");
  auto addn_no1 = builder.AddNode("addnNo1", AddNNo, 2, 1);
  auto shape_yes1 = builder.AddNode("shapeYes1", ShapeYes, 1, 1);
  auto netoutput1 = builder.AddNode("netoutput1", NETOUTPUT, 1, 0);

  builder.AddDataEdge(addn_no1, 0, shape_yes1, 0);
  builder.AddDataEdge(shape_yes1, 0, netoutput1, 0);
  return builder.GetGraph();
}
}  // namespace

/// The graph contains no `ShapeYes` node: the pass must succeed and keep all
/// five nodes.
TEST_F(UtestGraphPassesDimensionAdjustPass, not_changed_no_kernel) {
  auto graph = BuildGraph8();
  // The pass lives on the stack and is declared before the list that points at
  // it, so it outlives every user and no manual delete is needed.
  DimensionComputePass dimension_compute_pass;
  NamesToPass names_to_pass;
  names_to_pass.push_back({"Test", &dimension_compute_pass});

  GEPass pass(graph);
  EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  EXPECT_EQ(graph->GetAllNodes().size(), 5U);
}

/// Same expectation on the graph that has a `dataNo` input and an AddNYes
/// node: the pass must succeed and keep all five nodes.
TEST_F(UtestGraphPassesDimensionAdjustPass, not_changed_no_compute_kernel) {
  auto graph = BuildGraph9();
  DimensionComputePass dimension_compute_pass;
  NamesToPass names_to_pass;
  names_to_pass.push_back({"Test", &dimension_compute_pass});

  GEPass pass(graph);
  EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  EXPECT_EQ(graph->GetAllNodes().size(), 5U);
}

/// With a `ShapeYes` node in the graph the pass must succeed and leave two
/// nodes.
// NOTE(review): presumably shapeYes1 is replaced by a constant built from the
// kernel output and its now-unused producer is dropped — confirm against
// DimensionComputePass.
TEST_F(UtestGraphPassesDimensionAdjustPass, success) {
  auto graph = BuildGraph1();
  DimensionComputePass dimension_compute_pass;
  NamesToPass names_to_pass;
  names_to_pass.push_back({"Test", &dimension_compute_pass});

  GEPass pass(graph);
  EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  EXPECT_EQ(graph->GetAllNodes().size(), 2U);
}
}  // namespace ge