@@ -73,7 +73,7 @@ template <typename T> | |||
void FinalizeAllocatorMap(std::map<rtMemType_t, T *> &allocate_map) { | |||
for (auto &allocator : allocate_map) { | |||
if (allocator.second != nullptr) { | |||
allocator.second->Finalize(); | |||
allocator.second->Finalize(ge::GetContext().DeviceId()); | |||
delete allocator.second; | |||
allocator.second = nullptr; | |||
} | |||
@@ -26,9 +26,9 @@ | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "init/gelib.h" | |||
namespace {
// NOTE(review): not referenced in the visible part of this file; the name suggests a count/threshold used when
// no trans op is present — confirm against its users.
constexpr int kNoTransOp = 1;
// Index of the first tensor description; used to fetch input 0 of a node's OpDesc.
constexpr uint32_t kIndexZero = 0;
}  // namespace
namespace ge { | |||
@@ -895,6 +895,63 @@ graphStatus SameTransdataBreadthFusionPass::AddCastNode(const ComputeGraphPtr &g | |||
return GRAPH_SUCCESS; | |||
} | |||
/// @brief Ask the registered ops-kernel info stores whether any of them supports op_desc.
///
/// Fetches the OpInfo entries registered for op_desc's type, then queries the kernel info store named by each
/// entry's opKernelLib. The search stops at the first store that accepts op_desc.
///
/// @param [in] op_desc       operator description to check; a null pointer is rejected by GE_CHECK_NOTNULL.
/// @param [out] is_supported set to true when a store accepts op_desc. It is never written otherwise, so the
///                           caller must initialise it to false before the call.
/// @return GRAPH_SUCCESS when the query ran to completion, whether or not the op is supported;
///         GRAPH_FAILED when GELib is not initialised or no op info is registered for the op type.
graphStatus SameTransdataBreadthFusionPass::CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported) {
  auto instance = GELib::GetInstance();
  if ((instance == nullptr) || (!instance->InitFlag())) {
    REPORT_INNER_ERROR("E19999", "GELib is not initialized!");
    GELOGE(GRAPH_FAILED, "GELib is not initialized!");
    return GRAPH_FAILED;
  }
  GE_CHECK_NOTNULL(op_desc);

  OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj();
  // Bind by const reference: the list is only read here, so there is no need to copy it.
  const auto &op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  if (op_infos.empty()) {
    GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str());
    return GRAPH_FAILED;
  }

  // The store map does not depend on the loop variable: fetch it once instead of copying it per iteration.
  const auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  std::string unsupported_reason;
  for (const auto &it : op_infos) {
    const auto &kernel_name = it.opKernelLib;
    const auto kernel_info_store = kernel_map.find(kernel_name);
    if ((kernel_info_store == kernel_map.end()) || (kernel_info_store->second == nullptr)) {
      continue;
    }
    if (kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) {
      GELOGI("OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(),
             op_desc->GetName().c_str());
      is_supported = true;
      return GRAPH_SUCCESS;
    }
  }

  // No store accepted the op: is_supported keeps the caller's value and the last reported reason is logged.
  GELOGI("op:%s CheckAccuracySupported failed!reason:%s", op_desc->GetName().c_str(), unsupported_reason.c_str());
  return GRAPH_SUCCESS;
}
/// @brief Check whether a TransData fed with node's first input tensor description would be supported.
///
/// Guards against the scene A->Cast->TransData where A's data type is not supported by TransData:
/// for every input data node of type TransData, a temporary TransData op description carrying node's input 0
/// tensor description is built and handed to CheckAccuracySupported.
/// NOTE(review): only *input* data nodes of type TransData trigger the check, whereas the scene described above
/// has TransData downstream of the cast — confirm that GetInDataNodes (not GetOutDataNodes) is intended.
///
/// @param [in] node          node whose input 0 description is probed (the caller passes the cast node);
///                           a null pointer is rejected by GE_CHECK_NOTNULL.
/// @param [out] is_supported only ever set to true (by CheckAccuracySupported); it keeps the caller's value when
///                           node has no TransData input or the probe is not supported, so the caller must
///                           initialise it to false.
/// @return GRAPH_SUCCESS when the check ran to completion; GRAPH_FAILED when CheckAccuracySupported fails.
graphStatus SameTransdataBreadthFusionPass::CheckTransDataSupported(const NodePtr &node, bool &is_supported) {
  // node is dereferenced right below; reject null like every other pointer used in this function.
  GE_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  auto input_desc = op_desc->GetInputDescPtr(kIndexZero);
  GE_CHECK_NOTNULL(input_desc);

  auto in_nodes = node->GetInDataNodes();
  for (const auto &in_node : in_nodes) {
    if (in_node->GetType() != TRANSDATA) {
      continue;
    }
    // Build a stand-alone TransData description that receives node's own input 0 description, i.e. the
    // input TransData would get once the cast is removed.
    auto transdata_op_desc = std::make_shared<ge::OpDesc>(TRANSDATA, TRANSDATA);
    GE_CHECK_NOTNULL(transdata_op_desc);
    transdata_op_desc->AddInputDesc(*input_desc);
    if (CheckAccuracySupported(transdata_op_desc, is_supported) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "[Check][AccuracySupported] failed.");
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdataNode( | |||
OutDataAnchorPtr &out_anchor, | |||
std::vector<std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out, | |||
@@ -925,6 +982,18 @@ graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdat | |||
continue; | |||
} | |||
} | |||
// avoid transdata receiving unsupported datatype input after deleting cast node. | |||
// peer_in_node is cast op. | |||
bool is_supported = false; | |||
if (CheckTransDataSupported(peer_in_node, is_supported) != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "[Check][Param] CheckTransDataSupported failed!"); | |||
return GRAPH_FAILED; | |||
} | |||
if (!is_supported) { | |||
GELOGD("CheckAccuracySupported return unsupported for transdata constructed from node [%s]'s output, skip it.", | |||
peer_in_node->GetName().c_str()); | |||
return GRAPH_SUCCESS; | |||
} | |||
for (auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) { | |||
ret = GetSubGraphsBetweenNormalAndTransdataNode(peer_out_anchor, sub_graphs_out, nodes_list); | |||
if (ret != GRAPH_SUCCESS) { | |||
@@ -107,6 +107,10 @@ class SameTransdataBreadthFusionPass : public GraphPass { | |||
static bool IsHandleOp(const NodePtr &node); | |||
// Probes whether a TransData op built from node's input 0 tensor description is supported.
// is_supported is only ever set to true by the callee chain; initialise it to false before calling.
static graphStatus CheckTransDataSupported(const NodePtr &node, bool &is_supported); | |||
// Queries the ops-kernel info stores registered for op_desc's type; sets is_supported to true when one of
// them accepts op_desc and leaves it untouched otherwise.
static graphStatus CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported); | |||
vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors_; | |||
vector<vector<NodePtr>> before_transdata_nodes_; | |||
vector<pair<int, InDataAnchorPtr>> all_transdata_nodes_; | |||
@@ -422,6 +422,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/mark_node_unknown_shape_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
) | |||
set(KERNEL_SRC_FILES | |||
@@ -603,6 +604,7 @@ set(PASS_TEST_FILES | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/hccl_continuous_pass_unittest.cc" | |||
"graph/passes/hccl_memcpy_pass_unittest.cc" | |||
"graph/passes/same_transdata_breadth_fusion_pass_unittest.cc" | |||
) | |||
set(KERNEL_TEST_FILES | |||
@@ -0,0 +1,121 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
#include <gtest/gtest.h> | |||
#include <string> | |||
using namespace ge; | |||
class UtestGraphPassesSameTransdataBreadthFusionPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
class NodeBuilder { | |||
public: | |||
NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); } | |||
NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW, | |||
ge::DataType data_type = DT_FLOAT) { | |||
op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
return *this; | |||
} | |||
NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW, | |||
ge::DataType data_type = DT_FLOAT) { | |||
op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
return *this; | |||
} | |||
ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); } | |||
private: | |||
ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW, | |||
ge::DataType data_type = DT_FLOAT) { | |||
GeShape ge_shape{std::vector<int64_t>(shape)}; | |||
ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>(); | |||
tensor_desc->SetShape(ge_shape); | |||
tensor_desc->SetFormat(format); | |||
tensor_desc->SetDataType(data_type); | |||
return tensor_desc; | |||
} | |||
ge::OpDescPtr op_desc_; | |||
}; | |||
// Two identical Cast->TransData->Sinh branches hang off a single DT_BOOL Data node. The test builds that
// graph, verifies the fixture itself was assembled correctly, and expects the pass to report GRAPH_SUCCESS.
TEST_F(UtestGraphPassesSameTransdataBreadthFusionPass, test_unsupported_transdata_succ) {
  // Data4D(NCHW, DT_BOOL)-+->cast1(DT_BOOL->FP16)->transdata1(NCHW->NC1HWC0)->sinh1
  //                       |
  //                       +->cast2(DT_BOOL->FP16)->transdata2(NCHW->NC1HWC0)->sinh2
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

  // Data4D
  ge::NodePtr node_data = NodeBuilder("Data4D", DATA).AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL).Build(graph);
  ASSERT_NE(node_data, nullptr);

  // cast1
  ge::NodePtr node_cast_1 = NodeBuilder("node_cast_1", CAST)
                                .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL)
                                .AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
                                .Build(graph);
  ASSERT_NE(node_cast_1, nullptr);
  auto src_name = node_data->GetName();
  node_cast_1->GetOpDesc()->SetSrcName({src_name});
  node_cast_1->GetOpDesc()->SetInputName({src_name});
  // NOTE(review): the source-type attribute is DT_FLOAT although the declared input is DT_BOOL, and cast2 does
  // not set it at all — confirm this asymmetry is intended.
  AttrUtils::SetInt(node_cast_1->GetOpDesc(), CAST_ATTR_SRCT, DT_FLOAT);

  // transdata1
  ge::NodePtr node_transdata_1 = NodeBuilder("node_transdata_1", TRANSDATA)
                                     .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
                                     .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
                                     .Build(graph);
  ASSERT_NE(node_transdata_1, nullptr);

  // sinh1
  ge::NodePtr node_sinh_1 = NodeBuilder("node_sinh_1", SINH)
                                .AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
                                .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
                                .Build(graph);
  ASSERT_NE(node_sinh_1, nullptr);

  // cast2
  ge::NodePtr node_cast_2 = NodeBuilder("node_cast_2", CAST)
                                .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL)
                                .AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
                                .Build(graph);
  ASSERT_NE(node_cast_2, nullptr);
  node_cast_2->GetOpDesc()->SetSrcName({src_name});
  node_cast_2->GetOpDesc()->SetInputName({src_name});

  // transdata2
  ge::NodePtr node_transdata_2 = NodeBuilder("node_transdata_2", TRANSDATA)
                                     .AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
                                     .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
                                     .Build(graph);
  ASSERT_NE(node_transdata_2, nullptr);

  // sinh2
  ge::NodePtr node_sinh_2 = NodeBuilder("node_sinh_2", SINH)
                                .AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
                                .AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
                                .Build(graph);
  ASSERT_NE(node_sinh_2, nullptr);

  // add edge: a silently failed AddEdge would leave the pass running on a different graph than intended,
  // so every insertion is checked.
  EXPECT_EQ(ge::GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_1->GetInDataAnchor(0)));
  EXPECT_EQ(ge::GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_cast_1->GetOutDataAnchor(0), node_transdata_1->GetInDataAnchor(0)));
  EXPECT_EQ(ge::GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_transdata_1->GetOutDataAnchor(0), node_sinh_1->GetInDataAnchor(0)));
  EXPECT_EQ(ge::GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_2->GetInDataAnchor(0)));
  EXPECT_EQ(ge::GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_cast_2->GetOutDataAnchor(0), node_transdata_2->GetInDataAnchor(0)));
  EXPECT_EQ(ge::GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_transdata_2->GetOutDataAnchor(0), node_sinh_2->GetInDataAnchor(0)));

  ge::SameTransdataBreadthFusionPass pass;
  ge::graphStatus status = pass.Run(graph);
  EXPECT_EQ(ge::GRAPH_SUCCESS, status);
}