diff --git a/ge/graph/manager/graph_mem_manager.cc b/ge/graph/manager/graph_mem_manager.cc index 21eaf302..185dce17 100644 --- a/ge/graph/manager/graph_mem_manager.cc +++ b/ge/graph/manager/graph_mem_manager.cc @@ -73,7 +73,7 @@ template void FinalizeAllocatorMap(std::map &allocate_map) { for (auto &allocator : allocate_map) { if (allocator.second != nullptr) { /* Change in this hunk: Finalize() is now passed ge::GetContext().DeviceId(); the allocator is still deleted and its map entry set to nullptr afterwards. */ - allocator.second->Finalize(); + allocator.second->Finalize(ge::GetContext().DeviceId()); delete allocator.second; allocator.second = nullptr; } diff --git a/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/ge/graph/passes/same_transdata_breadth_fusion_pass.cc index afd78a4d..fad87238 100644 --- a/ge/graph/passes/same_transdata_breadth_fusion_pass.cc +++ b/ge/graph/passes/same_transdata_breadth_fusion_pass.cc @@ -26,9 +26,9 @@ #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "init/gelib.h" - namespace { const int kNoTransOp = 1; /* Index passed to GetInputDescPtr() in CheckTransDataSupported to read the first input desc. */ +const uint32_t kIndexZero = 0; } // namespace namespace ge { @@ -895,6 +895,63 @@ graphStatus SameTransdataBreadthFusionPass::AddCastNode(const ComputeGraphPtr &g return GRAPH_SUCCESS; } /* CheckAccuracySupported: looks up the kernel infos registered for the type of op_desc and asks the matching ops-kernel info stores whether they accept the op. Sets is_supported to true and returns GRAPH_SUCCESS as soon as one non-null store accepts it. If none accepts, is_supported is left exactly as the caller passed it and GRAPH_SUCCESS is still returned. Returns GRAPH_FAILED when GELib is absent or not initialised, or when no kernel info exists for the op type. A null op_desc is rejected through GE_CHECK_NOTNULL. */ +graphStatus SameTransdataBreadthFusionPass::CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported) { + // This function only ever writes true to is_supported; the caller is expected to pass it in as false.
+ auto instance = GELib::GetInstance(); + if ((instance == nullptr) || (!instance->InitFlag())) { + REPORT_INNER_ERROR("E19999", "GELib is not initialized!"); + GELOGE(GRAPH_FAILED, "GELib is not initialized!"); + return GRAPH_FAILED; + } + GE_CHECK_NOTNULL(op_desc); + + OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj(); + vector op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); + /* An op type with no registered kernel info is reported as GRAPH_FAILED (logged at info level), not as an unsupported result. */ if (op_infos.empty()) { + GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str()); + return GRAPH_FAILED; + } + + std::string unsupported_reason; + for (const auto &it : op_infos) { + /* NOTE(review): the store map is fetched on every iteration and held via plain auto; if GetAllOpsKernelInfoStores() returns a reference this copies the whole map each time. Consider hoisting it above the loop and binding by const reference. TODO confirm the return type. */ auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); + auto &kernel_name = it.opKernelLib; + auto kernel_info_store = kernel_map.find(kernel_name); + if (kernel_info_store != kernel_map.end()) { + if (kernel_info_store->second != nullptr && + kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) { + GELOGI("OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(), + op_desc->GetName().c_str()); + is_supported = true; + return GRAPH_SUCCESS; + } + } + } + /* No store accepted the op: log the last reason and succeed without touching is_supported. */ GELOGI("op:%s CheckAccuracySupported failed!reason:%s", op_desc->GetName().c_str(), unsupported_reason.c_str()); + return GRAPH_SUCCESS; +} +/* CheckTransDataSupported: reads the first input desc of node, then for every input data node of type TRANSDATA builds a fresh TRANSDATA op desc carrying only that input desc and passes it to CheckAccuracySupported. Returns GRAPH_FAILED if that check fails, GRAPH_SUCCESS otherwise. When no input data node is a TRANSDATA, is_supported is never written. node itself is dereferenced without a null check. NOTE(review): the described case has TransData downstream of the cast, yet the loop walks GetInDataNodes(), the upstream side. Confirm whether the output data nodes were intended. */ // Guards the case A->Cast->TransData where A's DataType is not supported by TransData +graphStatus SameTransdataBreadthFusionPass::CheckTransDataSupported(const NodePtr &node, bool &is_supported) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto input_desc = op_desc->GetInputDescPtr(kIndexZero); + GE_CHECK_NOTNULL(input_desc); + auto in_nodes = node->GetInDataNodes(); + for (const auto &in_node : in_nodes) { + if (in_node->GetType() != TRANSDATA) { + continue; + } + auto transdata_op_desc = std::make_shared(TRANSDATA, TRANSDATA); + GE_CHECK_NOTNULL(transdata_op_desc); + transdata_op_desc->AddInputDesc(*input_desc); + if (CheckAccuracySupported(transdata_op_desc, 
is_supported) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[Check][AccuracySupported] failed."); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdataNode( OutDataAnchorPtr &out_anchor, std::vector>> &sub_graphs_out, @@ -925,6 +982,18 @@ graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdat continue; } } + /* Guard added by this change: run CheckTransDataSupported on peer_in_node before recursing into its output anchors. A failed check aborts with GRAPH_FAILED; an unsupported result returns GRAPH_SUCCESS without recursing. */ // Avoid TransData receiving an unsupported input datatype once the cast node is deleted. + // peer_in_node is the cast op. + bool is_supported = false; + if (CheckTransDataSupported(peer_in_node, is_supported) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[Check][Param] CheckTransDataSupported failed!"); + return GRAPH_FAILED; + } + /* NOTE(review): this returns from the whole function instead of moving on to the next peer. Confirm that skipping any remaining peers is intended. */ if (!is_supported) { + GELOGD("CheckAccuracySupported return unsupported for transdata constructed from node [%s]'s output, skip it.", + peer_in_node->GetName().c_str()); + return GRAPH_SUCCESS; + } for (auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) { ret = GetSubGraphsBetweenNormalAndTransdataNode(peer_out_anchor, sub_graphs_out, nodes_list); if (ret != GRAPH_SUCCESS) { diff --git a/ge/graph/passes/same_transdata_breadth_fusion_pass.h b/ge/graph/passes/same_transdata_breadth_fusion_pass.h index 92e559a0..a7cc07ae 100755 --- a/ge/graph/passes/same_transdata_breadth_fusion_pass.h +++ b/ge/graph/passes/same_transdata_breadth_fusion_pass.h @@ -107,6 +107,10 @@ class SameTransdataBreadthFusionPass : public GraphPass { static bool IsHandleOp(const NodePtr &node); + /* Writes is_supported only when an input data node of node is a TRANSDATA; the result then comes from CheckAccuracySupported. */ static graphStatus CheckTransDataSupported(const NodePtr &node, bool &is_supported); + + /* Sets is_supported to true when an ops-kernel info store accepts op_desc; never sets it to false. */ static graphStatus CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported); + vector>> sub_graph_anchors_; vector> before_transdata_nodes_; vector> all_transdata_nodes_; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 49c9161d..7ee4cc15 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -422,6 +422,7 @@ 
set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/mark_node_unknown_shape_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" ) set(KERNEL_SRC_FILES @@ -603,6 +604,7 @@ set(PASS_TEST_FILES "graph/passes/memcpy_addr_async_unittest.cc" "graph/passes/hccl_continuous_pass_unittest.cc" "graph/passes/hccl_memcpy_pass_unittest.cc" + "graph/passes/same_transdata_breadth_fusion_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/passes/same_transdata_breadth_fusion_pass_unittest.cc b/tests/ut/ge/graph/passes/same_transdata_breadth_fusion_pass_unittest.cc new file mode 100644 index 00000000..9f6982fb --- /dev/null +++ b/tests/ut/ge/graph/passes/same_transdata_breadth_fusion_pass_unittest.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* The pass implementation file is included directly, so its definitions are compiled into this test translation unit. NOTE(review): the same .cc is also appended to GRAPH_PASS_COMMON_SRC_FILES in the CMake hunk of this patch; if both end up in one target the symbols would be defined twice. Confirm the link setup. */ #include "graph/passes/same_transdata_breadth_fusion_pass.cc" +#include +#include + +using namespace ge; + +/* Test fixture with empty SetUp and TearDown. */ class UtestGraphPassesSameTransdataBreadthFusionPass : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + +/* Test helper: creates an op desc from a name and type, lets callers append input and output tensor descs (shape, format, data type; format defaults to FORMAT_NCHW and data type to DT_FLOAT), and adds the finished op to a graph through Build(). */ class NodeBuilder { + public: + NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared(name, type); } + + NodeBuilder &AddInputDesc(std::initializer_list shape, ge::Format format = FORMAT_NCHW, + ge::DataType data_type = DT_FLOAT) { + op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + NodeBuilder &AddOutputDesc(std::initializer_list shape, ge::Format format = FORMAT_NCHW, + ge::DataType data_type = DT_FLOAT) { + op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); } + + private: + /* Builds a tensor desc with the given shape, format and data type. */ ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list shape, ge::Format format = FORMAT_NCHW, + ge::DataType data_type = DT_FLOAT) { + GeShape ge_shape{std::vector(shape)}; + ge::GeTensorDescPtr tensor_desc = std::make_shared(); + tensor_desc->SetShape(ge_shape); + tensor_desc->SetFormat(format); + tensor_desc->SetDataType(data_type); + return tensor_desc; + } + + ge::OpDescPtr op_desc_; +}; + +/* Builds one Data node with a DT_BOOL output that feeds two parallel Cast, TransData, Sinh chains, runs the pass on that graph, and expects GRAPH_SUCCESS. */ TEST_F(UtestGraphPassesSameTransdataBreadthFusionPass, test_unsupported_transdata_succ) { + // Node4D(NCHW)->cast1(DT_BOOL->FP16)->transdata1(NCHW->NC1HWC0)->sinh1 + // / + // --->cast2(DT_BOOL->FP16)->transdata2(NCHW->NC1HWC0)->sinh2 + ge::ComputeGraphPtr graph = std::make_shared("test"); + + // Node4D + ge::NodePtr node_data = NodeBuilder("Data4D", DATA).AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL).Build(graph); + // cast1 + ge::NodePtr node_cast_1 = NodeBuilder("node_cast_1", CAST) + .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL) + .AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) + .Build(graph); + auto 
src_name = node_data->GetName(); + node_cast_1->GetOpDesc()->SetSrcName({src_name}); + node_cast_1->GetOpDesc()->SetInputName({src_name}); + /* NOTE(review): CAST_ATTR_SRCT is set to DT_FLOAT although the input desc of cast1 is DT_BOOL, and cast2 gets no such attribute. Confirm this is intentional. */ AttrUtils::SetInt(node_cast_1->GetOpDesc(), CAST_ATTR_SRCT, DT_FLOAT); + + // transdata1 + ge::NodePtr node_transdata_1 = NodeBuilder("node_transdata_1", TRANSDATA) + .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) + .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) + .Build(graph); + // sinh1 + ge::NodePtr node_sinh_1 = NodeBuilder("node_sinh_1", SINH) + .AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) + .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) + .Build(graph); + // cast2 + ge::NodePtr node_cast_2 = NodeBuilder("node_cast_2", CAST) + .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL) + .AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) + .Build(graph); + node_cast_2->GetOpDesc()->SetSrcName({src_name}); + node_cast_2->GetOpDesc()->SetInputName({src_name}); + // transdata2 + ge::NodePtr node_transdata_2 = NodeBuilder("node_transdata_2", TRANSDATA) + .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16) + .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) + .Build(graph); + + // sinh2 + ge::NodePtr node_sinh_2 = NodeBuilder("node_sinh_2", SINH) + .AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) + .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16) + .Build(graph); + + // add edge + ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_1->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_cast_1->GetOutDataAnchor(0), node_transdata_1->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_transdata_1->GetOutDataAnchor(0), node_sinh_1->GetInDataAnchor(0)); + + ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_2->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_cast_2->GetOutDataAnchor(0), node_transdata_2->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_transdata_2->GetOutDataAnchor(0), 
node_sinh_2->GetInDataAnchor(0)); + + ge::SameTransdataBreadthFusionPass pass; + ge::graphStatus status = pass.Run(graph); + /* Only the returned status is asserted; the graph is not inspected after the pass has run. */ EXPECT_EQ(ge::GRAPH_SUCCESS, status); +} \ No newline at end of file