/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #define protected public #define private public #include "graph/passes/dimension_adjust_pass.h" #include "common/debug/log.h" #include "common/debug/memory_dumper.h" #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "common/types.h" #include "graph/types.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "inc/kernel_factory.h" #undef protected #undef private using namespace std; using namespace testing; namespace ge { class UtestGraphPassesDimensionAdjustPass : public testing::Test { protected: void SetUp() {} void TearDown() {} }; TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { ge::ComputeGraphPtr graph = std::make_shared("default"); ge::OpDescPtr data_op_desc = make_shared("data", CONSTANTOP); int64_t dims_size = 1; vector data_vec = {1, 2, 3}; for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); vector data_value_vec(dims_size, 1); GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); GeTensorPtr data_tensor = make_shared(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); OpDescUtils::SetWeights(data_op_desc, data_tensor); data_op_desc->AddOutputDesc(data_tensor_desc); NodePtr data_node = graph->AddNode(data_op_desc); data_node->Init(); // add dim node ge::OpDescPtr dim_op_desc = make_shared("dim", CONSTANTOP); vector dim_value_vec = {0}; GeTensorDesc dim_tensor_desc(ge::GeShape(), FORMAT_NCHW, DT_INT32); GeTensorPtr dim_tensor = make_shared(dim_tensor_desc, (uint8_t *)dim_value_vec.data(), dim_value_vec.size() * sizeof(int32_t)); OpDescUtils::SetWeights(dim_op_desc, dim_tensor); dim_op_desc->AddOutputDesc(dim_tensor_desc); NodePtr dim_node = graph->AddNode(dim_op_desc); dim_node->Init(); // add expanddims node OpDescPtr expanddims_op_desc = std::make_shared("Expanddims", EXPANDDIMS); vector expanddims_vec = {1, 1, 2, 3}; GeTensorDesc expanddims_tensor_desc(ge::GeShape(expanddims_vec), FORMAT_NCHW, DT_INT32); GeTensorPtr expanddims_tensor = make_shared(expanddims_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); OpDescUtils::SetWeights(expanddims_op_desc, expanddims_tensor); expanddims_op_desc->AddInputDesc(data_tensor_desc); expanddims_op_desc->AddInputDesc(dim_tensor_desc); expanddims_op_desc->AddOutputDesc(expanddims_tensor_desc); NodePtr op_node = graph->AddNode(expanddims_op_desc); op_node->Init(); // add output node OpDescPtr netoutput_op_desc = std::make_shared("NetOutput", "NetOutput"); netoutput_op_desc->AddInputDesc(expanddims_tensor_desc); NodePtr netoutput_node = graph->AddNode(netoutput_op_desc); netoutput_node->Init(); // add edge GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), op_node->GetInDataAnchor(0)); GraphUtils::AddEdge(dim_node->GetOutDataAnchor(0), op_node->GetInDataAnchor(1)); GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); std::shared_ptr pass = make_shared(); ge::Status ret = pass->Run(op_node); EXPECT_EQ(SUCCESS, ret); } TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) { std::shared_ptr pass = make_shared(); ge::NodePtr node = nullptr; ge::Status ret = pass->Run(node); EXPECT_EQ(PARAM_INVALID, ret); } TEST_F(UtestGraphPassesDimensionAdjustPass, node_op_desc_is_nullptr) { NodePtr op_node = make_shared(nullptr, nullptr); std::shared_ptr pass = make_shared(); ge::Status ret = pass->Run(op_node); EXPECT_EQ(PARAM_INVALID, ret); } TEST_F(UtestGraphPassesDimensionAdjustPass, node_get_original_type_failed) { ge::ComputeGraphPtr graph = std::make_shared("default"); OpDescPtr expanddim_op_desc = std::make_shared("Expanddims", FRAMEWORKOP); NodePtr op_node = make_shared(expanddim_op_desc, graph); std::shared_ptr pass = make_shared(); ge::Status ret = pass->Run(op_node); } TEST_F(UtestGraphPassesDimensionAdjustPass, node_not_register_op) { ge::ComputeGraphPtr graph = std::make_shared("default"); OpDescPtr expanddim_op_desc = std::make_shared("Expanddims", FRAMEWORKOP); AttrUtils::SetStr(expanddim_op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "expanddims_fake"); NodePtr op_node = make_shared(expanddim_op_desc, graph); std::shared_ptr pass = make_shared(); ge::Status ret = pass->Run(op_node); EXPECT_EQ(SUCCESS, ret); } TEST_F(UtestGraphPassesDimensionAdjustPass, node_compute_failed) { ge::ComputeGraphPtr graph = std::make_shared("default"); OpDescPtr expanddim_op_desc = std::make_shared("Expanddims", EXPANDDIMS); NodePtr op_node = make_shared(expanddim_op_desc, graph); std::shared_ptr pass = make_shared(); ge::Status ret = pass->Run(op_node); EXPECT_EQ(SUCCESS, ret); } } // namespace ge