diff --git a/ge/graph/passes/cast_remove_pass.cc b/ge/graph/passes/cast_remove_pass.cc index fe2866f0..7e2bb7bb 100644 --- a/ge/graph/passes/cast_remove_pass.cc +++ b/ge/graph/passes/cast_remove_pass.cc @@ -21,7 +21,6 @@ #include "graph/common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" -#include "init/gelib.h" namespace ge { Status CastRemovePass::Run(NodePtr &node) { @@ -62,14 +61,10 @@ Status CastRemovePass::Run(NodePtr &node) { if (!HasSameDataType(op_desc, end_op_desc, type)) { return SUCCESS; } - auto instance_ptr = ge::GELib::GetInstance(); - if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "gelib is not initilized!"); + if (RemoveCast(type, nodes_to_fuse) != SUCCESS) { return FAILED; } - - OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); - return DoFuse(ops_kernel_manager, type, nodes_to_fuse); + return SUCCESS; } bool CastRemovePass::CheckPrecisionLoss(const std::vector &nodes_to_fuse) { @@ -100,14 +95,26 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op // op1->TransData->Cast->TransposeD->Cast->TransData->op2 // change to be // op1->TransData->TransposeD->TransData->op2 -Status CastRemovePass::DoFuse(const OpsKernelManager &ops_kernel_manager, - const DataType &type, - std::vector &nodes_to_fuse) { - std::vector to_be_deleted_cast_index; - for (size_t i = 0; i < nodes_to_fuse.size(); i++) { - NodePtr node = nodes_to_fuse[i]; +Status CastRemovePass::RemoveCast(DataType &type, std::vector &nodes_to_fuse) { + string cast_name; + for (NodePtr &node : nodes_to_fuse) { + if (node->GetType() == CAST) { + GELOGI("CastRemovePass, remove Cast %s.", node->GetName().c_str()); + cast_name = node->GetName(); + if (IsolateAndDeleteNode(node, {0}) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed", + node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str()); + return FAILED; + } + } + } + + if (cast_name.empty()) { + return SUCCESS; + } + for (auto &node : nodes_to_fuse) { if (node->GetType() == CAST) { - to_be_deleted_cast_index.emplace_back(i); continue; } OpDescPtr op_desc = node->GetOpDesc(); @@ -116,66 +123,25 @@ Status CastRemovePass::DoFuse(const OpsKernelManager &ops_kernel_manager, GELOGE(FAILED, "OpDesc must not be null."); return FAILED; } - auto in_desc = op_desc->MutableInputDesc(0); - auto out_desc = op_desc->MutableOutputDesc(0); - auto in_desc_org_dtype = in_desc->GetDataType(); - auto out_desc_org_dtype = out_desc->GetDataType(); - in_desc->SetDataType(type); - out_desc->SetDataType(type); - bool is_supported = false; - string un_supported_reasons; - for (const auto &ops_kernel_store_info : ops_kernel_manager.GetAllOpsKernelInfoStores()) { - map op_infos; - ops_kernel_store_info.second->GetAllOpsKernelInfo(op_infos); - if (op_infos.find(op_desc->GetType()) == op_infos.end()) { - continue; - } - string un_supported_reason; - is_supported = ops_kernel_store_info.second->CheckAccuracySupported(op_desc, un_supported_reason); - if (is_supported) { - break; - } - un_supported_reasons += "{op_store " + ops_kernel_store_info.first + ":" + un_supported_reason + "} "; - } - if (!is_supported) { - // if no operator_info_store supported, do nothing - in_desc->SetDataType(in_desc_org_dtype); - out_desc->SetDataType(out_desc_org_dtype); - to_be_deleted_cast_index.clear(); - GELOGI("Fused Op[%s] check supported fail! Reasons is as follows: %s", - op_desc->GetName().c_str(), - un_supported_reasons.c_str()); - return SUCCESS; - } + // change node name for recompile cache, will be abandoned in April + string new_node_name = cast_name + op_desc->GetName(); + op_desc->SetName(new_node_name); // add attr to changed TransData, then will be rebuild if (!AttrUtils::SetBool(op_desc, ATTR_NEED_COMPILE, true)) { REPORT_CALL_ERROR("E19999", "Set Attr:%s of op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(), - op_desc->GetName().c_str(), - op_desc->GetType().c_str()); + op_desc->GetName().c_str(), op_desc->GetType().c_str()); GELOGE(FAILED, "Set ATTR_NEED_COMPILE Attr fail."); return FAILED; } + auto in_desc = op_desc->MutableInputDesc(0); + auto out_desc = op_desc->MutableOutputDesc(0); + in_desc->SetDataType(type); + out_desc->SetDataType(type); GELOGI("CastRemovePass, change %s %s datatype to be %s.", node->GetType().c_str(), node->GetName().c_str(), TypeUtils::DataTypeToSerialString(type).c_str()); } - return DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse); -} - -Status CastRemovePass::DoRemoveCast(const std::vector &to_be_deleted_cast_index, - std::vector &nodes_to_fuse) { - for (auto &cast_idx : to_be_deleted_cast_index) { - GELOGI("CastRemovePass, remove Cast %s.", nodes_to_fuse[cast_idx]->GetName().c_str()); - if (IsolateAndDeleteNode(nodes_to_fuse[cast_idx], {0}) != SUCCESS) { - REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed when CastRemovePass %s", - nodes_to_fuse[cast_idx]->GetName().c_str(), - nodes_to_fuse[cast_idx]->GetType().c_str(), - __FUNCTION__); - GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", nodes_to_fuse[cast_idx]->GetName().c_str()); - return FAILED; - } - } return SUCCESS; } diff --git a/ge/graph/passes/cast_remove_pass.h b/ge/graph/passes/cast_remove_pass.h index 0122fa69..0ee52998 100644 --- a/ge/graph/passes/cast_remove_pass.h +++ b/ge/graph/passes/cast_remove_pass.h @@ -19,7 +19,6 @@ #include #include "graph/passes/base_pass.h" -#include "opskernel_manager/ops_kernel_manager.h" namespace ge { class CastRemovePass : public BaseNodePass { @@ -29,9 +28,8 @@ class CastRemovePass : public BaseNodePass { private: bool CheckPrecisionLoss(const std::vector &nodes_to_fuse); bool HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op_desc, DataType &type) const; + Status RemoveCast(DataType &type, std::vector &nodes_to_fuse); NodePtr GetTheEndNode(NodePtr begin_node, std::vector &nodes_to_fuse); - Status DoRemoveCast(const std::vector &to_be_deleted_cast_index, std::vector &nodes_to_fuse); - Status DoFuse(const OpsKernelManager &ops_kernel_manager, const DataType &type, std::vector &nodes_to_fuse); }; } // namespace ge #endif // GE_GRAPH_PASSES_CAST_REMOVE_PASS_H_ diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 12b329d7..108725ad 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -709,7 +709,6 @@ set(PASS_TEST_FILES "graph/passes/buffer_pool_memory_pass_unittest.cc" "graph/passes/mark_node_unknown_shape_pass_unittest.cc" "graph/passes/reshape_recovery_pass_unittest.cc" - "graph/passes/cast_remove_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/passes/cast_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/cast_remove_pass_unittest.cc deleted file mode 100644 index 28bb2de2..00000000 --- a/tests/ut/ge/graph/passes/cast_remove_pass_unittest.cc +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#define protected public -#define private public -#include "graph/passes/cast_remove_pass.h" -#undef protected -#undef private - -#include "anchor.h" -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/op/attr_value_util.h" -#include "common/types.h" -#include "framework/common/ge_inner_error_codes.h" -#include "graph/attr_value.h" -#include "graph/debug/ge_attr_define.h" -#include "inc/pass_manager.h" -#include "graph_builder_utils.h" -#include -#include -#include -#include "opskernel_manager/ops_kernel_manager.h" -#include "omg/omg_inner_types.h" - - -using namespace testing; -using namespace ge; -using namespace std; - -class UtestGraphPassesCastRemovePass : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -// case1:no net_out_put_node -TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) { - std::vector nodes_to_fuse; - - auto builder = ut::GraphBuilder("g1"); - auto data = builder.AddNode("data", DATA, 1, 1); - auto cast1 = builder.AddNode("cast1", CAST, 1, 1); - cast1->GetOpDesc()->MutableOutputDesc(0)->SetDataType(DT_FLOAT16); - auto trans = builder.AddNode("trans", TRANSPOSE, 1, 1, FORMAT_NCHW, DT_FLOAT16); - auto cast2 = builder.AddNode("cast2", CAST, 1, 1); - cast2->GetOpDesc()->MutableInputDesc(0)->SetDataType(DT_FLOAT16); - auto net = builder.AddNode("netout", NETOUTPUT, 1, 1); - - builder.AddDataEdge(data, 0, cast1, 0); - builder.AddDataEdge(cast1, 0, trans, 0); - builder.AddDataEdge(trans, 0, cast2, 0); - builder.AddDataEdge(cast2, 0, net, 0); - ComputeGraphPtr compute_graph = builder.GetGraph(); - - map options; - - CastRemovePass cast_remove_pass; - DataType type = DT_FLOAT; - nodes_to_fuse.emplace_back(cast1); - nodes_to_fuse.emplace_back(trans); - nodes_to_fuse.emplace_back(cast2); - OpsKernelManager ops_kernel_manager; - cast_remove_pass.DoFuse(ops_kernel_manager, type, nodes_to_fuse); - EXPECT_EQ(compute_graph->GetAllNodesSize(),5); - std::vector to_be_deleted_cast_index; - to_be_deleted_cast_index.emplace_back(0); - to_be_deleted_cast_index.emplace_back(2); - (void)cast_remove_pass.DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse); - EXPECT_EQ(compute_graph->GetAllNodesSize(),3); -}