Browse Source

modify SameTransdataBreadthFusionPass

pull/2079/head
wuweikang 3 years ago
parent
commit
269bd5450d
5 changed files with 198 additions and 2 deletions
  1. +1
    -1
      ge/graph/manager/graph_mem_manager.cc
  2. +70
    -1
      ge/graph/passes/same_transdata_breadth_fusion_pass.cc
  3. +4
    -0
      ge/graph/passes/same_transdata_breadth_fusion_pass.h
  4. +2
    -0
      tests/ut/ge/CMakeLists.txt
  5. +121
    -0
      tests/ut/ge/graph/passes/same_transdata_breadth_fusion_pass_unittest.cc

+ 1
- 1
ge/graph/manager/graph_mem_manager.cc View File

@@ -73,7 +73,7 @@ template <typename T>
void FinalizeAllocatorMap(std::map<rtMemType_t, T *> &allocate_map) {
for (auto &allocator : allocate_map) {
if (allocator.second != nullptr) {
allocator.second->Finalize();
allocator.second->Finalize(ge::GetContext().DeviceId());
delete allocator.second;
allocator.second = nullptr;
}


+ 70
- 1
ge/graph/passes/same_transdata_breadth_fusion_pass.cc View File

@@ -26,9 +26,9 @@
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "init/gelib.h"

namespace {
const int kNoTransOp = 1;
const uint32_t kIndexZero = 0;
} // namespace

namespace ge {
@@ -895,6 +895,63 @@ graphStatus SameTransdataBreadthFusionPass::AddCastNode(const ComputeGraphPtr &g
return GRAPH_SUCCESS;
}

graphStatus SameTransdataBreadthFusionPass::CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported) {
// is_supported is set to be false as default value.
auto instance = GELib::GetInstance();
if ((instance == nullptr) || (!instance->InitFlag())) {
REPORT_INNER_ERROR("E19999", "GELib is not initialized!");
GELOGE(GRAPH_FAILED, "GELib is not initialized!");
return GRAPH_FAILED;
}
GE_CHECK_NOTNULL(op_desc);

OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj();
vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
if (op_infos.empty()) {
GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str());
return GRAPH_FAILED;
}

std::string unsupported_reason;
for (const auto &it : op_infos) {
auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
auto &kernel_name = it.opKernelLib;
auto kernel_info_store = kernel_map.find(kernel_name);
if (kernel_info_store != kernel_map.end()) {
if (kernel_info_store->second != nullptr &&
kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) {
GELOGI("OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(),
op_desc->GetName().c_str());
is_supported = true;
return GRAPH_SUCCESS;
}
}
}
GELOGI("op:%s CheckAccuracySupported failed!reason:%s", op_desc->GetName().c_str(), unsupported_reason.c_str());
return GRAPH_SUCCESS;
}
// avoid scene: A->Cast->TransData while A's DataType is not supported by TransData
graphStatus SameTransdataBreadthFusionPass::CheckTransDataSupported(const NodePtr &node, bool &is_supported) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
auto input_desc = op_desc->GetInputDescPtr(kIndexZero);
GE_CHECK_NOTNULL(input_desc);
auto in_nodes = node->GetInDataNodes();
for (const auto &in_node : in_nodes) {
if (in_node->GetType() != TRANSDATA) {
continue;
}
auto transdata_op_desc = std::make_shared<ge::OpDesc>(TRANSDATA, TRANSDATA);
GE_CHECK_NOTNULL(transdata_op_desc);
transdata_op_desc->AddInputDesc(*input_desc);
if (CheckAccuracySupported(transdata_op_desc, is_supported) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[Check][AccuracySupported] failed.");
return GRAPH_FAILED;
}
}
return GRAPH_SUCCESS;
}

graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdataNode(
OutDataAnchorPtr &out_anchor,
std::vector<std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
@@ -925,6 +982,18 @@ graphStatus SameTransdataBreadthFusionPass::GetSubGraphsBetweenNormalAndTransdat
continue;
}
}
// avoid transdata receiving unsupported datatype input after deleting cast node.
// peer_in_node is cast op.
bool is_supported = false;
if (CheckTransDataSupported(peer_in_node, is_supported) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[Check][Param] CheckTransDataSupported failed!");
return GRAPH_FAILED;
}
if (!is_supported) {
GELOGD("CheckAccuracySupported return unsupported for transdata constructed from node [%s]'s output, skip it.",
peer_in_node->GetName().c_str());
return GRAPH_SUCCESS;
}
for (auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) {
ret = GetSubGraphsBetweenNormalAndTransdataNode(peer_out_anchor, sub_graphs_out, nodes_list);
if (ret != GRAPH_SUCCESS) {


+ 4
- 0
ge/graph/passes/same_transdata_breadth_fusion_pass.h View File

@@ -107,6 +107,10 @@ class SameTransdataBreadthFusionPass : public GraphPass {

static bool IsHandleOp(const NodePtr &node);

static graphStatus CheckTransDataSupported(const NodePtr &node, bool &is_supported);

static graphStatus CheckAccuracySupported(const OpDescPtr &op_desc, bool &is_supported);

vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors_;
vector<vector<NodePtr>> before_transdata_nodes_;
vector<pair<int, InDataAnchorPtr>> all_transdata_nodes_;


+ 2
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -422,6 +422,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/mark_node_unknown_shape_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc"
)

set(KERNEL_SRC_FILES
@@ -603,6 +604,7 @@ set(PASS_TEST_FILES
"graph/passes/memcpy_addr_async_unittest.cc"
"graph/passes/hccl_continuous_pass_unittest.cc"
"graph/passes/hccl_memcpy_pass_unittest.cc"
"graph/passes/same_transdata_breadth_fusion_pass_unittest.cc"
)

set(KERNEL_TEST_FILES


+ 121
- 0
tests/ut/ge/graph/passes/same_transdata_breadth_fusion_pass_unittest.cc View File

@@ -0,0 +1,121 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/same_transdata_breadth_fusion_pass.cc"
#include <gtest/gtest.h>
#include <string>

using namespace ge;

class UtestGraphPassesSameTransdataBreadthFusionPass : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

class NodeBuilder {
public:
NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }

NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
ge::DataType data_type = DT_FLOAT) {
op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}

NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
ge::DataType data_type = DT_FLOAT) {
op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}

ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }

private:
ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
ge::DataType data_type = DT_FLOAT) {
GeShape ge_shape{std::vector<int64_t>(shape)};
ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
tensor_desc->SetShape(ge_shape);
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);
return tensor_desc;
}

ge::OpDescPtr op_desc_;
};

TEST_F(UtestGraphPassesSameTransdataBreadthFusionPass, test_unsupported_transdata_succ) {
// Node4D(NCHW)->cast1(DT_BOOL->FP16)->transdata1(NCHW->NC1HWC0)->sinh1
// /
// --->cast2(DT_BOOL->FP16)->transdata2(NCHW->NC1HWC0)->sinh2
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

// Node4D
ge::NodePtr node_data = NodeBuilder("Data4D", DATA).AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL).Build(graph);
// cast1
ge::NodePtr node_cast_1 = NodeBuilder("node_cast_1", CAST)
.AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL)
.AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
.Build(graph);
auto src_name = node_data->GetName();
node_cast_1->GetOpDesc()->SetSrcName({src_name});
node_cast_1->GetOpDesc()->SetInputName({src_name});
AttrUtils::SetInt(node_cast_1->GetOpDesc(), CAST_ATTR_SRCT, DT_FLOAT);

// trandata1
ge::NodePtr node_transdata_1 = NodeBuilder("node_transdata_1", TRANSDATA)
.AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
.AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
.Build(graph);
// sinh1
ge::NodePtr node_sinh_1 = NodeBuilder("node_sinh_1", SINH)
.AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
.AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
.Build(graph);
// cast2
ge::NodePtr node_cast_2 = NodeBuilder("node_cast_2", CAST)
.AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_BOOL)
.AddOutputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
.Build(graph);
node_cast_2->GetOpDesc()->SetSrcName({src_name});
node_cast_2->GetOpDesc()->SetInputName({src_name});
// transdata2
ge::NodePtr node_transdata_2 = NodeBuilder("node_transdata_2", TRANSDATA)
.AddInputDesc({2, 16, 2, 2}, FORMAT_NCHW, DT_FLOAT16)
.AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
.Build(graph);

// sinh2
ge::NodePtr node_sinh_2 = NodeBuilder("node_sinh_2", SINH)
.AddInputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
.AddOutputDesc({2, 1, 2, 2, 16}, FORMAT_NC1HWC0, DT_FLOAT16)
.Build(graph);

// add edge
ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_1->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_cast_1->GetOutDataAnchor(0), node_transdata_1->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_transdata_1->GetOutDataAnchor(0), node_sinh_1->GetInDataAnchor(0));

ge::GraphUtils::AddEdge(node_data->GetOutDataAnchor(0), node_cast_2->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_cast_2->GetOutDataAnchor(0), node_transdata_2->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_transdata_2->GetOutDataAnchor(0), node_sinh_2->GetInDataAnchor(0));

ge::SameTransdataBreadthFusionPass pass;
ge::graphStatus status = pass.Run(graph);
EXPECT_EQ(ge::GRAPH_SUCCESS, status);
}

Loading…
Cancel
Save