From eee4d7b4922fb1231502245a0d5ec2c6816f7c32 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 10 Jun 2021 16:50:41 +0800 Subject: [PATCH 01/24] fix safe --- ge/common/dump/exception_dumper.cc | 1 + ge/graph/build/model_builder.cc | 2 +- ge/graph/load/model_manager/task_info/kernel_task_info.cc | 2 ++ ge/graph/preprocess/insert_op/util_insert_aipp_op.cc | 1 + ge/hybrid/model/hybrid_model_builder.cc | 1 + 5 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ge/common/dump/exception_dumper.cc b/ge/common/dump/exception_dumper.cc index c8ec3d35..c41da551 100644 --- a/ge/common/dump/exception_dumper.cc +++ b/ge/common/dump/exception_dumper.cc @@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector &ex uint64_t proto_size = dump_data.ByteSizeLong(); std::unique_ptr proto_msg(new (std::nothrow) char[proto_size]); + GE_CHECK_NOTNULL(proto_msg); bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); if (!ret || proto_size == 0) { REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 431e4882..5dddd22a 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -707,7 +707,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { GE_CHECK_NOTNULL(kernel_buffer.GetData()); std::vector data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); - tbe_kernel = std::make_shared(kernel_name, std::move(data)); + tbe_kernel = MakeShared(kernel_name, std::move(data)); GE_CHECK_NOTNULL(tbe_kernel); GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_task_info.cc index 919a56cd..8cae6090 100755 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_task_info.cc @@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne GE_CHECK_NOTNULL(op_desc); args_addr = std::unique_ptr(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_addr); errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); if (sec_ret != EOK) { REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); @@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k // copy args to new host memory args_addr = std::unique_ptr(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_addr); GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); if (sec_ret != EOK) { diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index d76b79b9..9bfc71ba 100755 --- a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map aipp_params(new (std::nothrow) domi::AippOpParams()); + GE_CHECK_NOTNULL(aipp_params); ge::GeAttrValue::NAMED_ATTRS aipp_attr; GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index e865d922..6a966622 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1048,6 +1048,7 @@ Status HybridModelBuilder::InitConstantOps() { } else { var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); } + GE_CHECK_NOTNULL(var_tensor); } else { GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); From 085413913478de5f5f8c6b43125eaaf1d2819bcb Mon Sep 17 00:00:00 2001 From: wqtshg Date: Sun, 20 Jun 2021 10:27:49 +0800 Subject: [PATCH 02/24] update submodule --- metadef | 2 +- parser | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/metadef b/metadef index e189fc7f..8fdd3e81 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit e189fc7f4da9f7714f009d70da4db627de17955d +Subproject commit 8fdd3e8195f602e4b8919a513be8c0ec6a415a11 diff --git a/parser b/parser index c074dfa5..b79ef8ad 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit c074dfa5960d67f2910122d46d4d264dd6554aad +Subproject commit b79ef8ad19c8ab4335a97b2c3668d2776b62ce0a From be6b3a176f52054f95cca592a39ca9e68a3306e2 Mon Sep 17 00:00:00 2001 From: zhupuxu Date: Wed, 9 Jun 2021 14:13:41 +0800 Subject: [PATCH 03/24] step info Signed-off-by: zhupuxu --- ge/common/profiling/ge_profiling.cc | 39 ++++++++++++++++++++-- inc/framework/common/profiling/ge_profiling.h | 5 +++ tests/depends/profiler/src/profiler_stub.cc | 8 +++++ tests/ut/ge/CMakeLists.txt | 1 + .../ge/profiling/ge_profiling_manager_unittest.cc | 19 ++++++++++- 5 files changed, 69 insertions(+), 3 deletions(-) diff --git a/ge/common/profiling/ge_profiling.cc b/ge/common/profiling/ge_profiling.cc index d0343326..48d12609 100644 --- a/ge/common/profiling/ge_profiling.cc +++ b/ge/common/profiling/ge_profiling.cc @@ -22,6 +22,7 @@ #include "graph/load/graph_loader.h" #include "init/gelib.h" #include "framework/common/ge_inner_error_codes.h" +#include "model/ge_model.h" namespace { const uint32_t kDeviceListIndex = 3; @@ -42,6 +43,10 @@ const std::map kProfCommandTypeMap = { {kProfCommandhandleFinalize, kProfilingFinalize}, {kProfCommandhandleModelSubscribe, kProfModelSubscribe}, {kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}}; + +const uint64_t kModelId = ge::INVALID_MODEL_ID; +const uint16_t kStepStart = 0; +const uint16_t kStepEnd = 1; } // namespace bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector &prof_config_params) { @@ -216,6 +221,36 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le return ge::SUCCESS; } -GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) { - return ge::SUCCESS; +ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) { + static bool is_first_run = true; + int32_t device_id = 0; + rtError_t rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "[Get][LogicDeviceId]Failed, ret 0x%X", rt_ret); + REPORT_CALL_ERROR("E19999", "Get logic device id failed, ret 0x%X", rt_ret); + return ge::FAILED; + } + if (is_first_run && tag_id == kStepStart) { + GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id, + kModelId, + tag_id, + stream, + device_id)); + is_first_run = false; + return ge::SUCCESS; + } + if (!is_first_run && tag_id == kStepEnd) { + GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id, + kModelId, + tag_id, + stream, + device_id)); + is_first_run = true; + return ge::SUCCESS; + } + GELOGE(ge::FAILED, "Param tag_id:%u invalid when is_first_run is %d", tag_id, is_first_run); + REPORT_INPUT_ERROR("E10001", std::vector({"value", "parameter", "reason"}), + std::vector({std::to_string(tag_id), "tag_id", + "tag id must be 0 when first run, must be 1 when second run"})); + return ge::FAILED; } diff --git a/inc/framework/common/profiling/ge_profiling.h b/inc/framework/common/profiling/ge_profiling.h index a8de56a8..7a238b2f 100644 --- a/inc/framework/common/profiling/ge_profiling.h +++ b/inc/framework/common/profiling/ge_profiling.h @@ -43,6 +43,11 @@ GE_FUNC_VISIBILITY ge::Status RegProfCtrlCallback(MsprofCtrlCallback func); GE_FUNC_VISIBILITY ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func); GE_FUNC_VISIBILITY ge::Status RegProfReporterCallback(MsprofReporterCallback func); GE_FUNC_VISIBILITY ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len); + +/// +/// @brief Output the profiling data of single operator in Pytorch, and does not support multithreading +/// @return Status result +/// GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream); #endif // INC_FRAMEWORK_COMMON_GE_PROFILING_H_ diff --git a/tests/depends/profiler/src/profiler_stub.cc b/tests/depends/profiler/src/profiler_stub.cc index 1ed49fd8..0b8eaa88 100644 --- a/tests/depends/profiler/src/profiler_stub.cc +++ b/tests/depends/profiler/src/profiler_stub.cc @@ -16,6 +16,7 @@ #include "toolchain/prof_engine.h" #include "toolchain/prof_mgr_core.h" +#include "runtime/base.h" void * ProfMgrStartUp(const ProfMgrCfg *cfg) { @@ -32,3 +33,10 @@ int Msprof::Engine::RegisterEngine(const std::string& module, const Msprof::Engi return 0; } +rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) { + return 0; +} + +rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) { + return 0; +} diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 1743a0f0..c5790d59 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -158,6 +158,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_builder_manager.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/model_manager.cc" "${GE_CODE_DIR}/ge/common/profiling/profiling_manager.cc" + "${GE_CODE_DIR}/ge/common/profiling/ge_profiling.cc" "${GE_CODE_DIR}/ge/graph/manager/host_mem_manager.cc" "${GE_CODE_DIR}/ge/graph/manager/memory_api.cc" "${GE_CODE_DIR}/ge/session/inner_session.cc" diff --git a/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc b/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc index 9c615317..aae3f535 100644 --- a/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc +++ b/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc @@ -25,6 +25,7 @@ #define private public #include "common/profiling/profiling_manager.h" #include "graph/ge_local_context.h" +#include "inc/framework/common/profiling/ge_profiling.h" #undef protected #undef private @@ -115,4 +116,20 @@ TEST_F(UtestGeProfilinganager, get_fp_bp_point_empty) { ProfilingManager::Instance().GetFpBpPoint(fp_point, bp_point); EXPECT_EQ(fp_point, ""); EXPECT_EQ(bp_point, ""); -} \ No newline at end of file +} + +TEST_F(UtestGeProfilinganager, set_step_info_success) { + uint64_t index_id = 0; + auto stream = (rtStream_t)0x1; + Status ret = ProfSetStepInfo(index_id, 0, stream); + EXPECT_EQ(ret, ge::SUCCESS); + ret = ProfSetStepInfo(index_id, 1, stream); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST_F(UtestGeProfilinganager, set_step_info_failed) { + uint64_t index_id = 0; + auto stream = (rtStream_t)0x1; + Status ret = ProfSetStepInfo(index_id, 1, stream); + EXPECT_EQ(ret, ge::FAILED); +} From e51ffe2f54995c457ce3e4eb797becd027634ea8 Mon Sep 17 00:00:00 2001 From: wuweikang Date: Mon, 28 Jun 2021 10:26:17 +0800 Subject: [PATCH 04/24] fix mem leak --- ge/graph/load/model_manager/model_manager.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index 6a563d2f..cae828d6 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -1397,6 +1397,14 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { vector allocated_mem; rtError_t status; rtStream_t stream = nullptr; + std::function callback = [&]() { + for (auto mem : allocated_mem) { + GE_CHK_RT(rtFree(mem)); + } + GE_CHK_RT(rtStreamDestroy(stream)); + }; + GE_MAKE_GUARD(release, callback); + vector v_cust_so; void *args = nullptr; @@ -1471,13 +1479,6 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status); return RT_ERROR_TO_GE_STATUS(status); } - std::function callback = [&]() { - for (auto mem : allocated_mem) { - GE_CHK_RT(rtFree(mem)); - } - GE_CHK_RT(rtStreamDestroy(stream)); - }; - GE_MAKE_GUARD(release, callback); GELOGI("Cpu kernel launch task success."); return SUCCESS; } From bb2c55fac84ac35558f7958f580372581e617d12 Mon Sep 17 00:00:00 2001 From: wuweikang Date: Thu, 13 May 2021 16:14:41 +0800 Subject: [PATCH 05/24] add copy graph --- ge/graph/manager/graph_manager.cc | 2 +- ge/hybrid/model/hybrid_model.h | 1 + ge/hybrid/model/hybrid_model_builder.cc | 45 +++++++++++++++++----- ge/hybrid/model/hybrid_model_builder.h | 1 + .../hybrid/executor/subgraph_executor_unittest.cc | 3 ++ .../hybrid/model/hybrid_model_builder_unittest.cc | 26 ++++++++++--- 6 files changed, 63 insertions(+), 15 deletions(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 0a4633ad..0b27fdf3 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -3139,10 +3139,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { } // Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency if (count > 1 && graph_node->GetBuildFlag()) { - graph_node->Lock(); GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); // In online inference concurrency senario, graph_node is allowed to be locked for 'count' times graph_node->SetSemSize(count); + graph_node->Lock(); graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index 9821242a..77246e20 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -147,6 +147,7 @@ class HybridModel { GeRootModelPtr ge_root_model_; std::map input_nodes_; ComputeGraphPtr root_graph_; + ComputeGraphPtr orig_root_graph_; std::map device_variable_nodes_; //lint !e148 std::map host_variable_nodes_; //lint !e148 std::map> variable_tensors_; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index b9536dba..e865d922 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -147,6 +147,7 @@ Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); hybrid_model_.model_name_ = ge_root_model_->GetModelName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); + GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); @@ -171,11 +172,12 @@ Status HybridModelBuilder::Build() { Status HybridModelBuilder::BuildForSingleOp() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); + hybrid_model_.root_graph_ = ge_root_model_->GetRootGraph(); hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); - const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; - GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model), + const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()]; + GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model), "[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); @@ -190,6 +192,27 @@ Status HybridModelBuilder::ValidateParams() { return SUCCESS; } +Status HybridModelBuilder::CopyGraph() { + GELOGD("Copy compute graph begin."); + auto root_graph = ge_root_model_->GetRootGraph(); + + std::string new_graph_name = ge_root_model_->GetRootGraph()->GetName(); + ComputeGraphPtr new_root_graph = MakeShared(new_graph_name); + GE_CHECK_NOTNULL(new_root_graph); + int32_t depth = 0; + std::map node_old_2_new; + std::map op_desc_old_2_new; + graphStatus ret = GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Copy compute graph failed."); + return GRAPH_FAILED; + } + hybrid_model_.root_graph_ = new_root_graph; + + GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str()); + return SUCCESS; +} + Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), @@ -814,12 +837,13 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, } Status HybridModelBuilder::LoadGraph() { - auto root_graph = ge_root_model_->GetRootGraph(); + auto root_graph = hybrid_model_.root_graph_; if (!GetContext().GetHostExecFlag()) { std::shared_ptr merged_graph; GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); + hybrid_model_.orig_root_graph_ = root_graph; GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName()); root_graph = std::move(merged_graph); @@ -877,6 +901,7 @@ Status HybridModelBuilder::LoadGraph() { } for (auto &it : hybrid_model_.known_shape_sub_models_) { auto node_item = MutableNodeItem(it.first); + GE_CHECK_NOTNULL(node_item); AscendString graph_name; GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); @@ -1125,7 +1150,9 @@ Status HybridModelBuilder::InitWeights() { sub_weight_buffer->GetSize()); auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); if (subgraph != ge_root_model_->GetRootGraph()) { - subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); + subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first); + } else { + subgraph = hybrid_model_.root_graph_; } GE_CHECK_NOTNULL(subgraph); hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); @@ -1304,7 +1331,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const } Status HybridModelBuilder::IndexTaskDefs() { - const auto root_graph = ge_root_model_->GetRootGraph(); + const auto &root_graph = hybrid_model_.root_graph_; const auto &root_graph_name = root_graph->GetName(); if (SetOutputNameAttr(*root_graph) != SUCCESS) { GELOGW("Set output name attr failed."); @@ -1338,7 +1365,7 @@ Status HybridModelBuilder::IndexTaskDefs() { Status HybridModelBuilder::IndexSpecialNodes() { GELOGD("Start to index special nodes"); - const auto &root_graph = ge_root_model_->GetRootGraph(); + const auto &root_graph = hybrid_model_.root_graph_; for (auto &node : root_graph->GetAllNodes()) { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); @@ -1493,7 +1520,7 @@ Status HybridModelBuilder::InitRuntimeParams() { runtime_param_.session_id = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); runtime_param_.logic_var_base = ret ? static_cast(value) : 0; - runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); + runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID(); value = 0; for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { (void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); @@ -1630,7 +1657,7 @@ Status HybridModelBuilder::TransAllVarData() { } Status HybridModelBuilder::CopyVarData() { - GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), + GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_, runtime_param_.session_id, hybrid_model_.device_id_), "[Invoke][CopyVarData] failed."); @@ -1713,7 +1740,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem } Status HybridModelBuilder::RecoverGraphUnknownFlag() { - const auto &root_graph = ge_root_model_->GetRootGraph(); + const auto &root_graph = hybrid_model_.root_graph_; for (auto &sub_graph : root_graph->GetAllSubgraphs()) { GE_CHECK_NOTNULL(sub_graph); for (const auto &node : sub_graph->GetDirectNode()) { diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 9c1eb187..05830e82 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -56,6 +56,7 @@ class HybridModelBuilder { Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); Status ValidateParams(); Status LoadGraph(); + Status CopyGraph(); Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); Status LoadTask(NodeItem &node_item); diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index 2dc3b639..827705ae 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -249,6 +249,9 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { graph_context.callback_manager = std::unique_ptr(new CallbackManager()); ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); + auto root_graph = hybrid_model.root_graph_; + switch_t = root_graph->FindNode("switch_t"); + switch_f = root_graph->FindNode("switch_f"); const auto node_it_t = hybrid_model.node_items_.find(switch_t); const auto node_it_f = hybrid_model.node_items_.find(switch_f); ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index 5567aca2..10f7c0fe 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -214,11 +214,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { ASSERT_EQ(it->second->frame_index_, index); ASSERT_EQ(it->second->parent_frame_, -1); }; - TestFrameGroup(enter1, control_group_index); - TestFrameGroup(active1, control_group_index); - TestFrameGroup(active2, control_group_index); - TestFrameGroup(active3, control_group_index); - TestFrameGroup(output1, -1); + auto root_graph = hybrid_model.root_graph_; + auto enter1_node = root_graph->FindNode("enter"); + auto active1_node = root_graph->FindNode("active1"); + auto active2_node = root_graph->FindNode("active2"); + auto active3_node = root_graph->FindNode("active3"); + auto output1_node = root_graph->FindNode("net_output"); + TestFrameGroup(enter1_node, control_group_index); + TestFrameGroup(active1_node, control_group_index); + TestFrameGroup(active2_node, control_group_index); + TestFrameGroup(active3_node, control_group_index); + TestFrameGroup(output1_node, -1); engine_mapping.clear(); task_executor.clear(); @@ -373,4 +379,14 @@ TEST_F(UtestHybridModelBuilder, TestInitHcclExecutorOnDemand) { NodeExecutorManager::GetInstance().builders_.erase(NodeExecutorManager::ExecutorType::HCCL); ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS); } + +TEST_F(UtestHybridModelBuilder, copy_graph_success) { +ComputeGraphPtr graph = std::make_shared("test"); +GeRootModelPtr ge_root_model = make_shared(graph); +HybridModel hybrid_model(ge_root_model); +HybridModelBuilder hybrid_model_builder(hybrid_model); + +Status st = hybrid_model_builder.CopyGraph(); +EXPECT_EQ(st, SUCCESS); +} } // namespace ge From 1e0a3c0bca8cc07fab32cdacc1957e61307bc347 Mon Sep 17 00:00:00 2001 From: lianghao Date: Mon, 28 Jun 2021 16:29:03 +0800 Subject: [PATCH 06/24] FillKernel c78 --- ge/host_kernels/fill_kernel.cc | 8 +++++ .../passes/folding_kernel/fill_kernel_unittest.cc | 36 ++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/ge/host_kernels/fill_kernel.cc b/ge/host_kernels/fill_kernel.cc index 0022791c..9f62c8be 100644 --- a/ge/host_kernels/fill_kernel.cc +++ b/ge/host_kernels/fill_kernel.cc @@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetName().c_str()); GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex)); GE_CHECK_NOTNULL(input.at(kFillDataInputIndex)); @@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetOutputDescPtr(0); + GE_CHECK_NOTNULL(output_desc); + if (output_desc->GetShape().IsUnknownShape()) { + GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str()); + return NOT_CHANGED; + } + GeTensorPtr output_ptr; output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(0)); if (output_ptr == nullptr) { diff --git a/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc b/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc index f58d6d9b..c0cce260 100644 --- a/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc +++ b/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc @@ -64,6 +64,7 @@ class UtestGraphPassesFoldingKernelFillKernel : public testing::Test { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -124,6 +125,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillBoolShape2And3) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -230,6 +232,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsHaveNegativeNumber) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -284,6 +287,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsTypeNotSupport) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -310,6 +314,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsOverflow) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -336,6 +341,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -343,3 +349,33 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { EXPECT_EQ(PARAM_INVALID, status); } + +TEST_F(UtestGraphPassesFoldingKernelFillKernel, OutputdescUnknown) { + ge::OpDescPtr op_dims = std::make_shared(); + vector dims_vec = {2}; + vector dims_value_vec = {2, 3}; + GeTensorDesc dims_tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32); + GeTensorPtr dim_tensor = std::make_shared(dims_tensor_desc, (uint8_t *) dims_value_vec.data(), + dims_value_vec.size() * sizeof(int32_t)); + OpDescUtils::SetWeights(op_dims, dim_tensor); + + ge::OpDescPtr op_value = std::make_shared(); + vector data_vec = {1}; + GeTensorDesc value_tensor_desc(GeShape(), FORMAT_NCHW, DT_BOOL); + GeTensorPtr value_tensor = + std::make_shared(value_tensor_desc, (uint8_t *) data_vec.data(), data_vec.size() * sizeof(bool)); + OpDescUtils::SetWeights(op_value, value_tensor); + + op_desc_ptr->AddInputDesc(dims_tensor_desc); + op_desc_ptr->AddInputDesc(value_tensor_desc); + + vector out_vec = {-1, -1}; + GeTensorDesc out_tensor_desc(GeShape(out_vec), FORMAT_NCHW, DT_INT32); + op_desc_ptr->AddOutputDesc(out_tensor_desc); + + std::vector input = {dim_tensor, value_tensor}; + std::vector outputs; + Status status = kernel->Compute(op_desc_ptr, input, outputs); + + EXPECT_EQ(NOT_CHANGED, status); +} \ No newline at end of file From 062757756c1f78d312bc587f4e620325bb9f2473 Mon Sep 17 00:00:00 2001 From: wangxiaotian22 Date: Wed, 30 Jun 2021 09:10:52 +0800 Subject: [PATCH 07/24] fix sc --- ge/graph/preprocess/multi_batch_copy_graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 1634c8ce..fd3a4e91 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector &start_ auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); if (!IsAllDimsPositive(dims)) { REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", - node->GetName().c_str(), formats::ShapeToString(dims).c_str()); + node->GetName().c_str(), formats::ShapeToString(dims).c_str()); GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", node->GetName().c_str(), formats::ShapeToString(dims).c_str()); return INTERNAL_ERROR; From e5457d5949780495a6f23ef300d5e83602702bbb Mon Sep 17 00:00:00 2001 From: wangxiaotian22 Date: Wed, 30 Jun 2021 15:54:19 +0800 Subject: [PATCH 08/24] fix sc + --- ge/ge_runtime/task/label_goto_task.cc | 2 +- ge/single_op/task/op_task.h | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index 4302bff3..7cb6d556 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() { return false; } - rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); + rt_ret = rtLabelListCpy(reinterpret_cast(label_list.data()), label_list.size(), label_info_, label_info_size); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); return false; diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index ce569ce0..1b59ebbb 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -33,6 +33,10 @@ #include "register/op_tiling.h" namespace ge { +namespace { +const int kAddressNum = 2; +} // namespace + class StreamResource; struct SingleOpModelParam; class OpTask { @@ -256,7 +260,7 @@ class MemcpyAsyncTask : public OpTask { friend class SingleOpModel; friend class RtsKernelTaskBuilder; - uintptr_t addresses_[2]; + uintptr_t addresses_[kAddressNum]; size_t dst_max_; size_t count_; rtMemcpyKind_t kind_; From 8c55572e127381271fed94eab34cfbfc30cf0c2a Mon Sep 17 00:00:00 2001 From: lianghao Date: Thu, 1 Jul 2021 20:46:51 +0800 Subject: [PATCH 09/24] FindLastBpFromBpNode c78 --- ge/graph/build/task_generator.cc | 44 ++++++++++------------ ge/graph/build/task_generator.h | 2 +- tests/ut/ge/graph/build/task_generator_unittest.cc | 4 +- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index f9456aab..356e500f 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -793,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP GELOGI("Start AutoFindBpOpIndex"); NodePtr bp_node = nullptr; uint32_t current_idx = 0; - uint32_t netoutput_idx = 0; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -811,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { if (bp_node == nullptr) { bp_node = node; - netoutput_idx = current_idx - 1; } } if (graph->GetNeedIteration()) { @@ -836,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP if (bp_node == nullptr) { GELOGW("not find bp_node."); return SUCCESS; - } else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) { - profiling_point.bp_index = netoutput_idx; - GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx); - } else { - profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node); } - return SUCCESS; + return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index); } -uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const { - uint32_t last_bp = 0; +Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node, + uint32_t &bp_index) const { + bp_index = 0; + auto target_desc = target_node->GetOpDesc(); + GE_CHECK_NOTNULL(target_desc); OpDescPtr bp_op_desc = nullptr; - for (auto &in_anchor : bp_node->GetAllInDataAnchors()) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { - continue; - } - auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(out_node_desc); - if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) { - bp_op_desc = out_node_desc; + for (auto &in_node : target_node->GetInAllNodes()) { + GE_CHECK_NOTNULL(in_node); + auto in_node_desc = in_node->GetOpDesc(); + GE_CHECK_NOTNULL(in_node_desc); + if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) && + (in_node_desc->GetStreamId() == target_desc->GetStreamId())){ + bp_op_desc = in_node_desc; } - GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); } if (bp_op_desc == nullptr) { - return last_bp; + GELOGI("Did not find bp node."); + return SUCCESS; } uint32_t current_idx = 0; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { @@ -871,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const GE_CHECK_NOTNULL(op_desc); current_idx++; if (op_desc->GetName() == bp_op_desc->GetName()) { - last_bp = current_idx; - GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); + bp_index = current_idx; + GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index); break; } } - return last_bp; + GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(), + bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId()); + return SUCCESS; } Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, diff --git a/ge/graph/build/task_generator.h b/ge/graph/build/task_generator.h index 40cef3ba..0720a1b1 100755 --- a/ge/graph/build/task_generator.h +++ b/ge/graph/build/task_generator.h @@ -116,7 +116,7 @@ class TaskGenerator { Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, vector &all_reduce_nodes) const; - uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const; + Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const; Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, ProfilingPoint &profiling_point) const; diff --git a/tests/ut/ge/graph/build/task_generator_unittest.cc b/tests/ut/ge/graph/build/task_generator_unittest.cc index f869f1e0..1e865050 100644 --- a/tests/ut/ge/graph/build/task_generator_unittest.cc +++ b/tests/ut/ge/graph/build/task_generator_unittest.cc @@ -116,7 +116,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) { TaskGenerator task_generator(nullptr, 0); auto net_output = graph->FindNode("Node_Output"); // netoutput has no data input, return default value 0 - EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0); + uint32_t bp_index = 0; + EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0); + EXPECT_EQ(bp_index, 2); } TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { From 5da987eb3af97459c694c21c2e7c17ac53829917 Mon Sep 17 00:00:00 2001 From: wangzhengjun Date: Wed, 30 Jun 2021 17:08:09 +0800 Subject: [PATCH 10/24] set size for dynamic input --- ge/hybrid/executor/hybrid_model_async_executor.cc | 8 ++++--- .../hybrid_model_async_executor_unittest.cc | 28 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index 930412e3..f05d7c13 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -297,13 +297,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData ¤t_data, Hy } } tensor_desc->SetShape(shape); - args.input_desc[input_index] = tensor_desc; - GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str()); + GELOGD("Update shape[%s] of input[%zu] to [%s]", + shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str()); GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," "index = %zu, shape = [%s], model_id = %u.", input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); - GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size); + GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size); + TensorUtils::SetSize(*tensor_desc, tensor_size); + args.input_desc[input_index] = tensor_desc; } GE_CHECK_GE(tensor_size, 0); diff --git a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc index 98bb78f2..a6a0bc6a 100644 --- a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc @@ -103,4 +103,32 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { context.callback_manager->callback_queue_.Push(eof_entry); ASSERT_EQ(executor.Execute(args), SUCCESS); } + +TEST_F(UtestHybridModelAsyncExecutor, test_PrepareInputs) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = make_shared(graph); + ge_root_model->SetModelName("test_name"); + GeModelPtr ge_sub_model = make_shared(); + HybridModel hybrid_model(ge_root_model); + HybridModelAsyncExecutor executor(&hybrid_model); + GeTensorDescPtr tensor_desc = make_shared(GeShape({-1, 16, 16, 3})); + tensor_desc->SetShapeRange({{1, 256}, {16, 16}, {16, 16}, {3, 3}}); + executor.input_tensor_desc_.insert({0, tensor_desc}); + executor.device_id_ = 0; + executor.input_sizes_.insert({0, -1}); + executor.is_input_dynamic_.push_back(true); + + unique_ptr data_buf(new (std::nothrow)uint8_t[3072]); + InputData input_data; + input_data.blobs.push_back(DataBuffer(data_buf.get(), 3072, false)); + input_data.shapes.push_back({1, 16, 16, 3}); + HybridModelExecutor::ExecuteArgs args; + + auto ret = executor.PrepareInputs(input_data, args); + ASSERT_EQ(ret, SUCCESS); + ASSERT_EQ(args.input_desc[0]->GetShape().ToString(), GeShape({1, 16, 16, 3}).ToString()); + int64_t tensor_size = 0; + TensorUtils::GetSize(*(args.input_desc[0]), tensor_size); + ASSERT_EQ(tensor_size, 3104); +} } // namespace ge \ No newline at end of file From c440980918063a04c35a38aff6a66f6e781d322c Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Tue, 22 Jun 2021 16:20:29 +0800 Subject: [PATCH 11/24] Fix BuildPartitionFrame failed --- ge/graph/partition/dynamic_shape_partition.cc | 25 +++++++++++++------------ ge/graph/partition/dynamic_shape_partition.h | 4 +++- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc index 1db47498..8fc19ff2 100755 --- a/ge/graph/partition/dynamic_shape_partition.cc +++ b/ge/graph/partition/dynamic_shape_partition.cc @@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() { auto cluster = MakeShared(rank++, type, node, this); REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); node_2_cluster_[node] = cluster; - if (cluster->IsUnknownShape()) { - ordered_cluster_.push_back(cluster); - } int64_t group_index = -1; if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { @@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() { return SUCCESS; } -Status DynamicShapePartitioner::TopologicalSortClusters() { +Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) { ordered_cluster_.clear(); // BFS topological sort clusters for known shape cluster std::queue ready_clusters; @@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { auto cluster = ready_clusters.front(); ready_clusters.pop(); cluster->UpdateRank(rank++); - if (cluster->IsKnownShape() || cluster->IsInputNode()) { + if (ordered_filter == nullptr || ordered_filter(cluster)) { ordered_cluster_.push_back(cluster); } for (const auto &out_cluster : cluster->Outputs()) { @@ -378,7 +375,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { continue; } - bool is_unknown_cluster = cluster->IsUnknownShape(); for (++rit; rit != control_cluster.rend(); ++rit) { const auto &cluster_from = *rit; if (all_merged_clusters.count(cluster_from) > 0) { @@ -395,11 +391,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { } } } - - if (!is_unknown_cluster && cluster->IsUnknownShape()) { - GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str()); - ordered_cluster_.push_back(cluster); - } } } @@ -475,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() { } Status DynamicShapePartitioner::MergeClusters() { + const auto filter_known = [](const ClusterPtr &cluster) { + return cluster->IsKnownShape() || cluster->IsInputNode(); + }; + const auto filter_unknown = [](const ClusterPtr &cluster) { + return cluster->IsUnknownShape(); + }; + MergeClustersControlFlow(); + REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown), + "[TopologicalSort][Clusters] after merge control flow clusters failed."); MergeClustersUnknownShape(); - REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed."); + REQUIRE_SUCCESS(TopologicalSortClusters(filter_known), + "[TopologicalSort][Clusters] after merge unknown shape clusters failed."); MergeClustersKnownShape(); MergeClustersInputData(); return SUCCESS; diff --git a/ge/graph/partition/dynamic_shape_partition.h b/ge/graph/partition/dynamic_shape_partition.h index bd3b128f..7e82d131 100644 --- a/ge/graph/partition/dynamic_shape_partition.h +++ b/ge/graph/partition/dynamic_shape_partition.h @@ -111,6 +111,8 @@ class DynamicShapePartitioner { Status Partition(); + using OrderedFilter = std::function &cluster)>; + private: Status PartitionImpl(); // Collect nodes that satisfy the unknowshape rules: @@ -138,7 +140,7 @@ class DynamicShapePartitioner { // Merge clusters step3 void MergeClustersInputData(); // Topological sort clusters after merge unknown shape clusters. - Status TopologicalSortClusters(); + Status TopologicalSortClusters(const OrderedFilter &ordered_filter); // Deduplicate merged clusters void PruneUniqueClusters(); // Establish the input-output anchors for each partition of the cluster and record links to other clusters From 7b1331770a413b3b997cb00ed910e0bead2811b2 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Wed, 23 Jun 2021 20:31:18 +0800 Subject: [PATCH 12/24] Fix multi control from one node --- ge/ge_local_engine/engine/host_cpu_engine.cc | 11 ++++------- ge/hybrid/model/hybrid_model_builder.cc | 4 ---- ge/hybrid/model/node_item.cc | 17 +++++++---------- ge/hybrid/model/node_item.h | 2 +- ge/hybrid/node_executor/hccl/hccl_node_executor.cc | 1 + ge/hybrid/node_executor/rts/rts_node_executor.cc | 7 +------ 6 files changed, 14 insertions(+), 28 deletions(-) diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc index cd68ae15..d9b67736 100755 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "host_cpu_engine.h" -#include "graph/common/omg_util.h" +#include "ge_local_engine/engine/host_cpu_engine.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_adapter.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/type_utils.h" #include "register/op_kernel_registry.h" #include "register/host_cpu_context.h" #include "common/ge/ge_util.h" #include "common/ge/plugin_manager.h" -#include "graph/utils/type_utils.h" #include "common/fp16_t.h" #include "common/math/math_util.h" @@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) { } Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr &op_kernel) { - std::string op_type; - auto status = GetOriginalType(node, op_type); - GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status); - + const std::string op_type = NodeUtils::GetNodeType(node); auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); if (kernel == nullptr) { GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index e865d922..8ccc2ac6 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -288,10 +288,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n return SUCCESS; } - if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. - node->GetOpDesc()->SetType(IDENTITY); - } - std::unique_ptr new_node; GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index cef06fc6..250562ce 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -14,10 +14,8 @@ * limitations under the License. */ -#include "node_item.h" -#include -#include "common/debug/log.h" -#include "graph/common/omg_util.h" +#include "hybrid/model/node_item.h" + #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "hybrid/executor/worker/shape_inference_engine.h" @@ -98,8 +96,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type)); + const std::string node_type = NodeUtils::GetNodeType(node); if (node_type == DATA) { GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); } else if (node_type == kNodeTypeRetVal) { @@ -409,8 +406,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { if (switch_index < switch_groups_.size()) { - std::vector &switch_group = switch_groups_[switch_index]; - switch_group.emplace_back(node_item); + auto &switch_group = switch_groups_[switch_index]; + switch_group.emplace(node_item); } else { ctrl_send_.insert(node_item); } @@ -433,8 +430,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { } // this is StreamMerge node, node_item is StreamActive node. - std::vector &switch_group = switch_groups_[merge_index]; - switch_group.emplace_back(node_item); + auto &switch_group = switch_groups_[merge_index]; + switch_group.emplace(node_item); node_item->ctrl_send_.emplace(this); GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index ec66f094..12775b00 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -155,7 +155,7 @@ struct NodeItem { std::map data_recv_; // Recv data notify from std::set ctrl_send_; // Send ctrl notify to std::set ctrl_recv_; // Recv ctrl notify from - std::vector> switch_groups_; // Send ctrl notify to + std::vector> switch_groups_; // Send ctrl notify to std::shared_ptr kernel_task; std::unique_ptr fused_subgraph; diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc index d942695e..0f98674c 100644 --- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc +++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc @@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function do GE_CHK_RT_RET(rtEventDestroy(evt)); } GELOGI("rdma callback success."); + return SUCCESS; }; HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index 3ad791b6..d52f56b9 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -17,13 +17,9 @@ #include "hybrid/node_executor/rts/rts_node_executor.h" #include "hybrid/node_executor/rts/rts_task_factory.h" -#include "common/debug/log.h" #include "common/ge/ge_util.h" -#include "common/types.h" -#include "graph/common/omg_util.h" #include "graph/utils/tensor_utils.h" #include "hybrid/model/hybrid_model.h" -#include "runtime/rt.h" namespace ge { namespace hybrid { @@ -133,8 +129,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { GE_CHECK_NOTNULL(node); GELOGD("[%s] Load for local task.", node->GetName().c_str()); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); + const std::string node_type = NodeUtils::GetNodeType(node); RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); if (rts_task == nullptr) { GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); From ded54e73afe6ba9f0baa381cba2b91b947c065e2 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Wed, 23 Jun 2021 20:53:47 +0800 Subject: [PATCH 13/24] Fix Guard for variable release --- ge/graph/load/model_manager/model_manager.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index cae828d6..99793252 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -1394,17 +1394,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { return SUCCESS; } - vector allocated_mem; - rtError_t status; rtStream_t stream = nullptr; + vector allocated_mem; std::function callback = [&]() { for (auto mem : allocated_mem) { GE_CHK_RT(rtFree(mem)); } - GE_CHK_RT(rtStreamDestroy(stream)); + if (stream != nullptr) { + GE_CHK_RT(rtStreamDestroy(stream)); + } }; GE_MAKE_GUARD(release, callback); + rtError_t status; vector v_cust_so; void *args = nullptr; From 64d312ab12ae2f6f3b0b22a74fe340784f931f33 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Thu, 24 Jun 2021 09:32:14 +0800 Subject: [PATCH 14/24] UT for LaunchKernelCustAicpuSo --- ge/graph/load/model_manager/model_manager.cc | 4 +++- tests/ut/ge/graph/load/model_manager_unittest.cc | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index 99793252..6fd2e273 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -1378,7 +1378,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); std::lock_guard lock(cust_aicpu_mutex_); - if (cust_aicpu_so_.size() == 0) return SUCCESS; + if (cust_aicpu_so_.empty()) { + return SUCCESS; + } // get current context rtContext_t rt_cur_ctx = nullptr; auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); diff --git a/tests/ut/ge/graph/load/model_manager_unittest.cc b/tests/ut/ge/graph/load/model_manager_unittest.cc index a3545b33..d9e4eabd 100644 --- a/tests/ut/ge/graph/load/model_manager_unittest.cc +++ b/tests/ut/ge/graph/load/model_manager_unittest.cc @@ -438,4 +438,22 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) { auto ret = mm.DataInputTensor(model_id,inputs); EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. } + +TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) { + ModelManager mm; + + // cust_aicpu_so_ is empty. + EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS); + + // deleteCustOp after Launch will deleted. + uintptr_t resource_id = 1; // for rtCtxGetCurrent stub + std::vector kernel_bin(256); + auto &cust_resource_001 = mm.cust_aicpu_so_[resource_id]; + auto tbe_kernel = std::shared_ptr(new OpKernelBin("deleteCustOp", std::move(kernel_bin))); + auto &cust_opkernel_001 = cust_resource_001["deleteCustOp"] = tbe_kernel; + + EXPECT_FALSE(mm.cust_aicpu_so_.empty()); + EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS); + EXPECT_TRUE(mm.cust_aicpu_so_.empty()); +} } // namespace ge From 65cafdd034f3c169e60a5e739348e3a34490ba4b Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Thu, 24 Jun 2021 15:49:11 +0800 Subject: [PATCH 15/24] Replace MemcpyAsyncNodeTask --- ge/hybrid/node_executor/rts/rts_node_executor.cc | 1 + ge/hybrid/node_executor/rts/rts_node_task.cc | 29 ------------------------ ge/hybrid/node_executor/rts/rts_node_task.h | 5 ---- 3 files changed, 1 insertion(+), 34 deletions(-) diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index d52f56b9..e3058ee3 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -29,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask); REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); +REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask); Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { auto input_desc = context.MutableInputDesc(index); diff --git a/ge/hybrid/node_executor/rts/rts_node_task.cc b/ge/hybrid/node_executor/rts/rts_node_task.cc index 104196ee..aec6804d 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.cc +++ b/ge/hybrid/node_executor/rts/rts_node_task.cc @@ -43,7 +43,6 @@ namespace hybrid { REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); -REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask); REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); @@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio return SUCCESS; } -Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { - GELOGD("[%s] Start to execute.", task_context.GetNodeName()); - auto input_desc = task_context.MutableInputDesc(0); - GE_CHECK_NOTNULL(input_desc); - int64_t copy_size = 0; - GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size)); - // copy_size would not be negative since GetTensorSizeInBytes returned successfully. - if (copy_size > 0) { - const auto in_v = task_context.MutableInput(0); - const auto out_v = task_context.MutableOutput(0); - GE_CHECK_NOTNULL(in_v); - GE_CHECK_NOTNULL(out_v); - GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(), - in_v->GetSize(), out_v->GetSize(), copy_size); - GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size, - RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream())); - } else { - GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size); - } - - if (done_callback) { - GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); - } - - GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); - return SUCCESS; -} - Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); const auto in_x = task_context.GetInput(0); // x diff --git a/ge/hybrid/node_executor/rts/rts_node_task.h b/ge/hybrid/node_executor/rts/rts_node_task.h index d7d63eb5..e18f9a8f 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.h +++ b/ge/hybrid/node_executor/rts/rts_node_task.h @@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask { Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; }; -class MemcpyAsyncNodeTask : public RtsNodeTask { - public: - Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; -}; - class PassThroughNodeTask : public RtsNodeTask { public: Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; From 002583e4efb78816ea7493dd36ffe5b458107754 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Fri, 25 Jun 2021 20:53:08 +0800 Subject: [PATCH 16/24] Fix Set Control flow group for -1 --- .../passes/mark_force_unknown_for_cond_pass.cc | 97 ++++++++++++---------- ge/graph/passes/mark_force_unknown_for_cond_pass.h | 11 +++ ge/graph/passes/switch_to_stream_switch_pass.cc | 5 +- ge/hybrid/executor/node_state.cc | 35 +++++--- ge/hybrid/executor/node_state.h | 1 + ge/hybrid/model/node_item.cc | 6 +- ge/hybrid/model/node_item.h | 4 +- ge/hybrid/node_executor/node_executor.cc | 1 + ge/hybrid/node_executor/task_context.cc | 6 ++ ge/hybrid/node_executor/task_context.h | 1 + 10 files changed, 108 insertions(+), 59 deletions(-) diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index 74babadc..a4095c1b 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -16,8 +16,6 @@ #include "mark_force_unknown_for_cond_pass.h" -#include - #include "graph/utils/node_utils.h" #include "graph/common/omg_util.h" @@ -26,17 +24,7 @@ namespace { inline bool IsMergeInLoop(const NodePtr &node) { const static std::set kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; - std::string node_type; - (void)GetOriginalType(node, node_type); - return kLoopMergeInputs.count(node_type) > 0; -} - -inline bool IsSwitchInLoop(const NodePtr &node) { - const static std::set kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; - - std::string node_type; - (void)GetOriginalType(node, node_type); - return kLoopSwitchInputs.count(node_type) > 0; + return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; } } @@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { GELOGD("MarkForceUnknownForCondPass Enter"); std::map> switch_groups; for (const auto &node : graph->GetDirectNode()) { - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), - "[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); - if (kMergeOpTypes.count(node_type) == 0) { + if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { continue; } @@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { } /// +/// @brief Deal with Switch node for LoopCond +/// @param [in] Switch node +/// @param [in] dest span +/// @param [out] Search queue +/// @return true: Switch In while loop / false: Not in while Loop. +/// +bool MarkForceUnknownForCondPass::DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, + std::queue> search_queue) { + /// LoopCond --->\. + /// \. + /// Enter-----------+ \. + /// +--> Merge --> Switch --> Exit + /// NextIteration---+ + const auto is_loop_op = [](const NodePtr &n) { + return NodeUtils::GetNodeType(n) == LOOPCOND; + }; + const auto is_exit_op = [](const NodePtr &n) { + return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; + }; + + const auto src_nodes = node->GetInAllNodes(); + const auto dst_nodes = node->GetOutAllNodes(); + if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && + std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { + return false; + } + + for (const auto &m : src_nodes) { + if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { + for (const auto &n : m->GetInAllNodes()) { + if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { + continue; + } + + search_queue.push({n, dst_span}); + GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), + n->GetName().c_str(), dst_span); + } + } + } + + return true; +} + +/// /// @brief Mark force unknown shape for Switch node /// @param [in] merge node /// @param [out] switch group @@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { /// void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector &switch_group) { // Switch --> {Switch --> Merge} --> Merge + GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); std::unordered_set nodes_seen; std::queue> search_queue({{node, 0}}); while (!search_queue.empty()) { @@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: const auto dst_span = search_queue.front().second; search_queue.pop(); - // Switch --> Identity --> Constant - for (const auto &in_node : dst_node->GetInControlNodes()) { - if (nodes_seen.count(in_node) > 0) { - GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); - continue; - } - nodes_seen.insert(in_node); - - if (in_node->GetType() == IDENTITY) { - GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), - in_node->GetName().c_str(), dst_span); - search_queue.push({in_node, dst_span}); - } - } - - for (const auto &in_node : dst_node->GetInDataNodes()) { + for (const auto &in_node : dst_node->GetInAllNodes()) { if (nodes_seen.count(in_node) > 0) { GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); continue; } nodes_seen.insert(in_node); - std::string node_type; - (void)GetOriginalType(in_node, node_type); + const std::string node_type = NodeUtils::GetNodeType(in_node); GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), in_node->GetName().c_str(), dst_span); if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. + if (DealWithLoopSwitch(in_node, dst_span, search_queue)) { + continue; + } + if (dst_span > 0) { search_queue.push({in_node, dst_span - 1}); } else { - const auto &all_in_nodes = in_node->GetInDataNodes(); - if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { - GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), - in_node->GetName().c_str()); - } else { - switch_group.emplace_back(in_node); - } + switch_group.emplace_back(in_node); } } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. search_queue.push({in_node, dst_span + 1}); diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.h b/ge/graph/passes/mark_force_unknown_for_cond_pass.h index 528a8fdc..d2be9a9e 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.h +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.h @@ -19,6 +19,8 @@ #include "inc/graph_pass.h" +#include + namespace ge { class MarkForceUnknownForCondPass : public GraphPass { public: @@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass { private: /// + /// @brief Deal with Switch node for LoopCond + /// @param [in] Switch node + /// @param [in] dest span + /// @param [out] Search queue + /// @return true: Switch In while loop / false: Not in while Loop. + /// + bool DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue> search_queue); + + /// /// @brief Mark force unknown shape for Switch node /// @param [in] merge node /// @param [out] switch group diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index e4ab0111..1a47c14b 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); int64_t group_index = -1; - (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - SetControlFlowGroup(stream_switch, group_index); + if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { + SetControlFlowGroup(stream_switch, group_index); + } return stream_switch; } diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 468c84e6..4b0d0c44 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -326,17 +326,37 @@ std::shared_ptr NodeState::GetTaskContext() { } void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { - if (node_item_->root_data_.count(input_idx) > 0) { + const auto is_persist_tensor = [](const std::map> &items, int idx) { + const auto is_exist = [&idx](const std::pair> &items) { + return items.second.count(idx) > 0; + }; + return std::any_of(items.begin(), items.end(), is_exist); + }; + + if (is_persist_tensor(node_item_->root_data_, input_idx)) { GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); root_tensor_values_[input_idx] = tensor; - } - - if (node_item_->enter_data_.count(input_idx) > 0) { + } else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); root_tensor_values_[input_idx] = tensor; } } +void NodeState::UpdatePersistTensor() { + const auto update_tensor = [&](const std::map> &items) { + for (const auto &item : items) { + for (const auto idx : item.second) { + UpdatePersistTensor(idx); + } + } + }; + + update_tensor(node_item_->root_data_); + if (iteration_count_ > 0) { + update_tensor(node_item_->enter_data_); + } +} + void NodeState::UpdatePersistTensor(int input_idx) { const auto it = root_tensor_values_.find(input_idx); if (it == root_tensor_values_.end()) { @@ -363,16 +383,9 @@ void NodeState::ResetContext(uint64_t iteration) { data_scheduled_ = static_cast(node_item_->root_data_.size()); ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); - for (auto item : node_item_->root_data_) { - UpdatePersistTensor(item.first); - } - if (iteration > 0) { data_scheduled_ += static_cast(node_item_->enter_data_.size()); ctrl_scheduled_ += static_cast(node_item_->enter_ctrl_.size()); - for (auto item : node_item_->enter_data_) { - UpdatePersistTensor(item.first); - } } iteration_count_ = iteration; diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index b80b60b0..1ec8517e 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -132,6 +132,7 @@ struct NodeState { void RunNextIteration(); void SavePersistTensor(int input_idx, const TensorValue &tensor); + void UpdatePersistTensor(); Status NodeScheduled(const std::function &ready) const; diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 250562ce..8e87c6e2 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -395,11 +395,13 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { data_send_.emplace(node_item); node_item->data_recv_[this] = anchor_index; if (is_root_node_) { - node_item->root_data_[anchor_index] = this; + auto &data_anchors = node_item->root_data_[this]; + data_anchors.emplace(anchor_index); } // If Enter feed Not Merge, take as root Node. if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { - node_item->enter_data_[anchor_index] = this; + auto &data_anchors = node_item->enter_data_[this]; + data_anchors.emplace(anchor_index); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); } diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 12775b00..f6dcdcf6 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -148,9 +148,9 @@ struct NodeItem { int64_t frame_index_ = -1; int64_t parent_frame_ = -1; std::set root_ctrl_; // Recv ctrl from root node - std::map root_data_; // Recv data from root node + std::map> root_data_; // Recv data from root node std::set enter_ctrl_; // Recv ctrl from Enter node - std::map enter_data_; // Recv data from Enter node + std::map> enter_data_; // Recv data from Enter node std::set data_send_; // Send data notify to std::map data_recv_; // Recv data notify from std::set ctrl_send_; // Send ctrl notify to diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index 9e9354d9..eeb5ba20 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -39,6 +39,7 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); + GE_CHK_STATUS_RET_NOLOG(context.UpdatePersistTensor()); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); return SUCCESS; } diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index fe580c1e..7ff83ce0 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -468,6 +468,12 @@ Status TaskContext::PropagateOutputs() { return SUCCESS; } +Status TaskContext::UpdatePersistTensor() { + GE_CHECK_NOTNULL(node_state_); + node_state_->UpdatePersistTensor(); + return SUCCESS; +} + const void *TaskContext::GetVarBaseAddr() { return execution_context_->model->GetVarMemBase(); } diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index c96e194e..cff5d680 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -78,6 +78,7 @@ class TaskContext { Status AllocateOutputs(AllocationAttr *attr = nullptr); Status AllocateWorkspaces(); Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr); + Status UpdatePersistTensor(); bool IsTraceEnabled() const; From 796513222a45e7f1b4dddbaff22c0f65a6dace9a Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sat, 26 Jun 2021 15:00:53 +0800 Subject: [PATCH 17/24] UpdatePersistTensor from ExecutionEngine --- ge/hybrid/executor/node_state.cc | 4 ++++ ge/hybrid/executor/worker/execution_engine.cc | 1 + ge/hybrid/node_executor/node_executor.cc | 1 - ge/hybrid/node_executor/task_context.cc | 11 +---------- ge/hybrid/node_executor/task_context.h | 1 - 5 files changed, 6 insertions(+), 12 deletions(-) diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 4b0d0c44..7ab7b536 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -333,6 +333,10 @@ void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { return std::any_of(items.begin(), items.end(), is_exist); }; + if (root_tensor_values_.count(input_idx) > 0) { + return; + } + if (is_persist_tensor(node_item_->root_data_, input_idx)) { GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); root_tensor_values_[input_idx] = tensor; diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc index 8eecbc80..d4c73f58 100755 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -375,6 +375,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", node_state.GetName().c_str()); + node_state.UpdatePersistTensor(); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str()); diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index eeb5ba20..9e9354d9 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -39,7 +39,6 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); - GE_CHK_STATUS_RET_NOLOG(context.UpdatePersistTensor()); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); return SUCCESS; } diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index 7ff83ce0..bb4340b7 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -458,22 +458,12 @@ Status TaskContext::PropagateOutputs() { subgraph_context_->all_inputs_[input_offset].SetName( node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); } - - auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); - GE_CHECK_NOTNULL(dst_node_state); - dst_node_state->SavePersistTensor(dst_input_idx, *tensor); } } (void)guard; return SUCCESS; } -Status TaskContext::UpdatePersistTensor() { - GE_CHECK_NOTNULL(node_state_); - node_state_->UpdatePersistTensor(); - return SUCCESS; -} - const void *TaskContext::GetVarBaseAddr() { return execution_context_->model->GetVarMemBase(); } @@ -499,6 +489,7 @@ void TaskContext::ReleaseInputsAndOutputs() { void TaskContext::ReleaseInput(int index) { auto input_tensor = MutableInput(index); if (input_tensor != nullptr) { + node_state_->SavePersistTensor(index, *input_tensor); input_tensor->Destroy(); GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); } diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index cff5d680..c96e194e 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -78,7 +78,6 @@ class TaskContext { Status AllocateOutputs(AllocationAttr *attr = nullptr); Status AllocateWorkspaces(); Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr); - Status UpdatePersistTensor(); bool IsTraceEnabled() const; From dbc989a4c39b33d16fcbc5598f5b67ad5644ad5f Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sat, 26 Jun 2021 15:42:55 +0800 Subject: [PATCH 18/24] Clear UpdatePersistTensor Warning for first run --- ge/graph/passes/mark_force_unknown_for_cond_pass.cc | 6 +++--- ge/graph/passes/mark_force_unknown_for_cond_pass.h | 2 +- ge/hybrid/executor/node_state.cc | 4 ++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index a4095c1b..e024217f 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -56,8 +56,8 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { /// @param [out] Search queue /// @return true: Switch In while loop / false: Not in while Loop. /// -bool MarkForceUnknownForCondPass::DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, - std::queue> search_queue) { +bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, + std::queue> &search_queue) { /// LoopCond --->\. /// \. /// Enter-----------+ \. @@ -121,7 +121,7 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), in_node->GetName().c_str(), dst_span); if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. - if (DealWithLoopSwitch(in_node, dst_span, search_queue)) { + if (DealAsLoopSwitch(in_node, dst_span, search_queue)) { continue; } diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.h b/ge/graph/passes/mark_force_unknown_for_cond_pass.h index d2be9a9e..030b55ee 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.h +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.h @@ -34,7 +34,7 @@ class MarkForceUnknownForCondPass : public GraphPass { /// @param [out] Search queue /// @return true: Switch In while loop / false: Not in while Loop. /// - bool DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue> search_queue); + bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue> &search_queue); /// /// @brief Mark force unknown shape for Switch node diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 7ab7b536..ad38c792 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -355,6 +355,10 @@ void NodeState::UpdatePersistTensor() { } }; + if (root_tensor_values_.empty()) { + return; + } + update_tensor(node_item_->root_data_); if (iteration_count_ > 0) { update_tensor(node_item_->enter_data_); From b23075b62f4178fb445fe7fd11dee2f46af6b9f2 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sat, 26 Jun 2021 23:20:31 +0800 Subject: [PATCH 19/24] UT for control flow group --- tests/depends/mmpa/src/mmpa_stub.cc | 4 + tests/ut/ge/CMakeLists.txt | 3 +- .../build/logical_stream_allocator_unittest.cc | 10 +- .../ut/ge/graph/build/stream_allocator_unittest.cc | 2 +- tests/ut/ge/graph/passes/assert_pass_unittest.cc | 6 +- tests/ut/ge/graph/passes/base_pass_unittest.cc | 14 +-- .../ut/ge/graph/passes/cond_branch_v1_unittest.cc | 6 +- .../graph/passes/constant_folding_pass_unittest.cc | 38 +++--- .../passes/dimension_compute_pass_unittest.cc | 8 +- .../ssd_prior_box_kernel_unittest.cc | 2 +- ...e_data_nodes_with_common_input_pass_unittest.cc | 2 +- .../mark_force_unknown_for_cond_pass_unittest.cc | 129 +++++++++++++++------ tests/ut/ge/graph/passes/merge_pass_unittest.cc | 28 ++--- .../graph/passes/parallel_group_pass_unittest.cc | 12 +- .../graph/passes/reshape_recovery_pass_unittest.cc | 6 +- .../graph/passes/reshape_remove_pass_unittest.cc | 16 +-- .../passes/resource_pair_control_pass_unittest.cc | 2 +- .../passes/switch_logic_remove_pass_unittest.cc | 12 +- .../trans_op_breadth_fusion_pass_unittest.cc | 4 +- .../passes/trans_op_depth_fusion_pass_unittest.cc | 14 +-- ...ransop_nearby_allreduce_fusion_pass_unittest.cc | 4 +- .../ge/graph/passes/variable_op_pass_unittest.cc | 2 +- .../ge/graph/variable_accelerate_ctrl_unittest.cc | 10 +- 23 files changed, 195 insertions(+), 139 deletions(-) diff --git a/tests/depends/mmpa/src/mmpa_stub.cc b/tests/depends/mmpa/src/mmpa_stub.cc index aae8de9f..b0f1fb87 100644 --- a/tests/depends/mmpa/src/mmpa_stub.cc +++ b/tests/depends/mmpa/src/mmpa_stub.cc @@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) { + const char *env = getenv(name); + if (env != nullptr) { + strcpy(value, env); + } return 0; } diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index c5790d59..d55b8861 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -726,7 +726,6 @@ set(PASS_TEST_FILES "graph/passes/memcpy_addr_async_unittest.cc" "graph/passes/hccl_continuous_pass_unittest.cc" "graph/passes/hccl_memcpy_pass_unittest.cc" - ) set(KERNEL_TEST_FILES @@ -859,7 +858,6 @@ set(HYBRID_TEST_FILES "hybrid/executor/hybrid_model_async_executor_unittest.cc" "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" - ) set(OTHERS_TEST_FILES @@ -887,6 +885,7 @@ add_library(ge_ut_graph STATIC target_compile_definitions(ge_ut_graph PRIVATE google=ascend_private + FMK_SUPPORT_DUMP ) target_compile_options(ge_ut_graph PRIVATE diff --git a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc index 218bfd0d..352984fa 100644 --- a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc @@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { /// B --> C(AllReduce) --- D /// / /// stream id: 0 A - /// \ + /// \. /// E --> F(AllReduce) --- G /// stream id: 2 2 2 /// @@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { /// case of multi-output, then unuse stream /// sub1 -/// / | \ +/// / | \. /// sub2 sub3 sub4 TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { SubGraphInfoPtr data = CreateDataSubgraph(); @@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { /// if paralle id 1, then use stream /// sub1 -/// / | | \ +/// / | | \. /// sub2 sub3 sub4 sub5 TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { SubGraphInfoPtr data = CreateDataSubgraph(); @@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { /// if the param of engine independent is true, then set independent stream /// sub1 -/// / | | \ +/// / | | \. /// sub2 sub3 sub4 sub5 TEST_F(UtestLogicalStreamAllocator, test_independent) { SubGraphInfoPtr data = CreateDataSubgraph(); @@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { /// set stream based on stream label, and then based on independent /// sub1 -/// / | | \ +/// / | | \. /// sub2 sub3 sub4 sub5 TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { SubGraphInfoPtr data = CreateDataSubgraph(); diff --git a/tests/ut/ge/graph/build/stream_allocator_unittest.cc b/tests/ut/ge/graph/build/stream_allocator_unittest.cc index 019e75d1..4ae871af 100644 --- a/tests/ut/ge/graph/build/stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/stream_allocator_unittest.cc @@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { /// /// A - /// / \ + /// / \. /// B C /// | | /// D 400 diff --git a/tests/ut/ge/graph/passes/assert_pass_unittest.cc b/tests/ut/ge/graph/passes/assert_pass_unittest.cc index 4aa133d3..9247681c 100644 --- a/tests/ut/ge/graph/passes/assert_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/assert_pass_unittest.cc @@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { }; /// D E -/// | \ | \ +/// | \ | \. /// F C G /// : | : /// H A I @@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { EXPECT_EQ(graph->FindNode("D"), nullptr); } -/// E F -/// | \ | \ +/// E F +/// | \ | \. /// H C -> D G /// \ | : /// A I diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index 9bba5d77..c687e07f 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test { /// reshape1 /// | /// add1 -/// / \ +/// / \. /// | | /// data1 const1 ComputeGraphPtr BuildGraph1() { @@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() { } /// sum1 -/// / \ -/// / \ -/// / \ +/// / \. +/// / \. +/// / \. /// reshape1 addn1 /// | c | /// add1 <--- shape1 @@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector &loop, vector &cond) { /******************************************************************************* - * Exit Identify - * \ / \. - * \ / \. - * Switch Add - * / | | - * / | | - * / | | - * LoopCond | | - * \ | | - * \ | | - * \ | | - * Less | | - * \ | NextIteration - * \ | | - * \ | | - * Merge <---------| - * | - * | - * Enter + * | + * +--------------------- Merge ----------------------+ + * / | + * / | + * / | + * / | + * Exit Identify | + * \ / \. | + * \ / \. | + * Switch Add Add + * / | | | + * / | | | + * / | | | + * LoopCond | | | + * \ | | | + * \ | | | + * \ | | | + * Less | | | + * \ | NextIteration | + * \ | | | + * \ | | | + * Merge <---------| | + * | | + * | | + * Enter | + * \ | + * \ | + * Switch Switch + * | | + * +-----------------Equal----------------------+ + * | ******************************************************************************/ - auto data1 = CreateNode(*graph, "data", DATA, 1, 1); + auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); + auto data2 = CreateNode(*graph, "data2", DATA, 1, 1); + + auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1); + auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2); + auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2); + auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); - auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); - auto less1 = CreateNode(*graph, "less", LESS, 2, 1); + auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2); + auto less1 = CreateNode(*graph, "less1", LESS, 2, 1); auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); - auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); + auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2); auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); - auto add1 = CreateNode(*graph, "add", ADD, 2, 1); + auto add1 = CreateNode(*graph, "add1", ADD, 2, 1); auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); - auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); - auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1); + + auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1); + auto add2 = CreateNode(*graph, "add2", ADD, 2, 1); + auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2); auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); - GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); + GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); + cond.emplace_back(switch1); + cond.emplace_back(switch2); + + GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); - GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); + loop.emplace_back(merge1); - GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); - GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false + GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true + loop.emplace_back(switch3); GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); - GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); - GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); - merge = merge1; + GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true + GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); + + GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); + GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); + GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + + cond.emplace_back(merge2); + merge = merge2; } static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { @@ -197,12 +235,27 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; - CreateLoopGraph(graph, merge); - - AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); + vector loop; + vector cond; + CreateLoopGraph(graph, merge, loop, cond); MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond + setenv("DUMP_GE_GRAPH", "1", true); + GE_DUMP(graph, "control_group"); + unsetenv("DUMP_GE_GRAPH"); + + EXPECT_EQ(loop.size(), 2); + for (const auto &node : loop) { + EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)); + } + + EXPECT_EQ(cond.size(), 3); + for (const auto &node : cond) { + int64_t group_index = -1; + EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)); + EXPECT_EQ(group_index, merge->GetOpDesc()->GetId()); + } } TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { diff --git a/tests/ut/ge/graph/passes/merge_pass_unittest.cc b/tests/ut/ge/graph/passes/merge_pass_unittest.cc index 75fdb21b..f8f0afea 100644 --- a/tests/ut/ge/graph/passes/merge_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/merge_pass_unittest.cc @@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) { } /// Merge -/// | \ -/// | \ +/// | \. +/// | \. /// Op1 Op2 Merge2 /// \ | | /// \ | Op3 @@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da } /// Merge -/// | \ -/// | \ +/// | \. +/// | \. /// Op1 Op2 Merge2 -/// \ | | \ +/// \ | | \. /// \ | Op3 /// \ | : /// NetOutput @@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { /// Merge - /// | \ - /// | \ + /// | \. + /// | \. /// Op1 Op2 Merge2 /// \ | | /// \ | Op3 @@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { /// Op1 Op2 Merge2 /// \ | /// \ Op3 - /// \ + /// \. /// Merge3 ret = pass_.Run(merge_node2); @@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) { /// Op1 /// | /// Merge - /// / \ + /// / \. /// Op2 Op3 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto node1 = NewNode("Op1", RELU, 1, 1); @@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) { /// Const /// | /// Merge Pass Const - /// / \ ===> / \ + /// / \ ===> / \. /// Op1 Op2 Op1 Op2 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto const_node = NewNode("Const", CONSTANT, 1, 1); @@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) /// / | ===> / \(control anchor) /// Op1 | \ Op1 Constant /// Op2 Op3 | - /// / \ + /// / \. /// Op2 Op3 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto const_node = NewNode("Const", CONSTANT, 1, 1); @@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) /// / | ===> / \(control anchor) /// Op1 | \ Op1 Constant /// Op2 Op3 | - /// / \ + /// / \. /// Op2 Op3 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto const_node = NewNode("Const", CONSTANT, 1, 1); @@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { /// C /// | /// Merge - /// / \ + /// / \. /// Op1 Op2 auto switch_node = NewNode("Switch", SWITCH, 1, 2); auto identity_node = NewNode("Identity", SWITCH, 1, 1); @@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { /// . /// . /// C - /// / \ + /// / \. /// Op1 Op2 auto ret = pass_.Run(merge_node); EXPECT_EQ(ret, SUCCESS); diff --git a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc index d5b1db41..374fe837 100644 --- a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc @@ -66,11 +66,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { void BuildDefaultGraph() { /// input - /// \ + /// \. /// sqrt pred /// \ / /// cast - /// / \ + /// / \. /// switch_t switch_f /// | | /// F T @@ -118,13 +118,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { void BuildDefaultGraph1() { /// input - /// \ + /// \. /// sqrt pred /// \ / /// Switch /// | | /// ----F T---- - /// \ | / \ + /// \ | / \. /// \ Merge1 Merge2 /// \_________| input_node_ = NewNode("input", RELU, 0, 1); @@ -164,14 +164,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { void BuildDefaultGraph2() { /// input input1 - /// \ \ + /// \ \. /// sqrt pred sqrt1 pred1 /// \ / \ / /// Switch Switch1 /// | | _______| /// | | / /// ____F T____ - /// \ | / \ + /// \ | / \. /// \ Merge1 Merge2 /// \__________| input_node_ = NewNode("input", RELU, 0, 2); diff --git a/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc b/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc index 3be11452..f941645e 100644 --- a/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc @@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test { namespace { /// netoutput1 -/// | \ -///transdata1 \ -/// | \ +/// | \. +///transdata1 \. +/// | \. /// | transdata2 /// | / /// var1 const1 diff --git a/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc index 351e96d7..ca0cac86 100644 --- a/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc @@ -35,7 +35,7 @@ namespace { /// transdata1 /// | /// reshape1 -/// | \ +/// | \. /// var1 const1 ut::GraphBuilder Graph1Builder() { ut::GraphBuilder builder = ut::GraphBuilder("g1"); @@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() { } /// netoutput1 -/// | \ -///transdata1 \ -/// | \ +/// | \. +///transdata1 \. +/// | \. /// reshape1 reshape2 -/// | \ / \ +/// | \ / \. /// var1 const1 var2 ut::GraphBuilder Graph2Builder() { ut::GraphBuilder builder = ut::GraphBuilder("g2"); @@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() { } /// netoutput1 -/// | \ -///transdata1 \ -/// | \ +/// | \. +///transdata1 \. +/// | \. /// reshape1 transdata2 /// | \ / /// var1 const1 diff --git a/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc b/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc index 6d12a49d..8cdfd0c7 100644 --- a/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc @@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test { namespace { /// netoutput1 -/// | \ +/// | \. /// StackPush StackPop /// | | /// var1 const1 diff --git a/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc index dcad318c..22734047 100644 --- a/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc @@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() { /// netoutput1 /// | /// merge1 -/// / \ +/// / \. /// / add1 -/// / F| \ +/// / F| \. /// addn1 swtich2 var3 /// \F T/ | /// switch1 | @@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() { /// add1 /// / \T /// var3 swtich2 -/// T/ \ -/// switch1 \ -/// / \ \ +/// T/ \. +/// switch1 \. +/// / \ \. /// var1 var2 var4 ComputeGraphPtr BuildGraph3() { auto builder = ut::GraphBuilder("g3"); @@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() { /// netoutput1 /// | /// merge1 -/// / \ +/// / \. /// add1 addn1 /// / \T F/ /// var3 swtich2 diff --git a/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc b/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc index dbb163e1..d05bd695 100644 --- a/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc @@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) { } /// ----> netoutput1 -/// / | \ +/// / | \. /// transdata1 transdata2 transdata3 /// \ / | /// var1-------------- @@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() { } /// ---------> netoutput1 -/// / | \ +/// / | \. /// transdata1 transdata2(l1) transdata3(l1) /// \ / | /// var1------------------ diff --git a/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc b/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc index a9ea41ea..dbac3246 100644 --- a/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc @@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) /// -->transpose1 -->transpose3-->sinh2 /// | \ / /// | -->transpose2 - /// | \ + /// | \. /// / -->cast3-->cast4-->sinh3 /// / /// / -->transpose4-->transpose5-->sinh4 /// / / /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 - /// \ \ + /// \ \. /// \ -->sinh6 - /// \ + /// \. /// \ -->transpose6-->transpose7-->sinh9 /// \ / /// -->reshape-->cast6-->cast7-->sinh8 - /// \ + /// \. /// -->sinh7 /// after optimized graph @@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) /// / /-->transpose3-->sinh2 /// -->Cast1 /// / \-->sinh7 - /// / \ + /// / \. /// / -->sinh9 /// Node4D /// \ -->sinh4 /// \ / /// -->Cast5-->sinh5 - /// \ \ + /// \ \. /// \ -->sinh6 - /// \ + /// \. /// -->Cast7-->sinh8 ge::ComputeGraphPtr graph = std::make_shared("test"); diff --git a/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc b/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc index 1220b35e..9c6d8276 100644 --- a/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc @@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran /// TransData TransData ... MatMul ... /// \ | / / / /// HcomAllReduce - /// / | \ \ \ + /// / | \ \ \. /// TransData TransData ... RealDiv ... ComputeGraphPtr graph = std::make_shared("test"); NodePtr allreduce = @@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans /// TransData TransData ... MatMul ... /// \ | / / / /// HcomAllReduce - /// / | \ \ \ + /// / | \ \ \. /// TransData TransData ... RealDiv ... size_t symmetric_transdata_num = 20; size_t asymmetric_transdata_num = 20; diff --git a/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc b/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc index f1ea7a27..655867a7 100644 --- a/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc @@ -66,7 +66,7 @@ namespace { /// transdata2 /// | /// assign1 -/// / \ +/// / \. /// transdata1 | /// | | /// var1 const1 diff --git a/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc b/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc index 37b4bda7..bf350b6c 100644 --- a/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc +++ b/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc @@ -35,8 +35,8 @@ namespace { /// shapeNo1 /// | /// addnYes1 -/// / \ -/// / \ +/// / \. +/// / \. /// const1 const2 ComputeGraphPtr BuildGraph1() { @@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() { /// /// netoutput1 -/// / \ \ -/// add1 assign1 \ -/// / \ / \ \ +/// / \ \. +/// add1 assign1 \. +/// / \ / \ \. /// var1 var2 const1 var3 ComputeGraphPtr BuildGraph2() { From 61c203619c3b48417dd366fc022b235bef74a5b6 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sat, 26 Jun 2021 23:25:58 +0800 Subject: [PATCH 20/24] Remove UT dump env --- tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc index 50991822..557359b7 100644 --- a/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc @@ -241,9 +241,6 @@ TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond - setenv("DUMP_GE_GRAPH", "1", true); - GE_DUMP(graph, "control_group"); - unsetenv("DUMP_GE_GRAPH"); EXPECT_EQ(loop.size(), 2); for (const auto &node : loop) { From ac1f4eb1c233c4dd6fa7b6ce31ef5e9c0e217282 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Mon, 28 Jun 2021 19:23:31 +0800 Subject: [PATCH 21/24] DSP: Switch -> TransData -> Cast -> Exit --- ge/graph/passes/next_iteration_pass.cc | 33 +++++++++++++++++---------- ge/hybrid/executor/worker/execution_engine.cc | 2 +- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index fb8f8627..1c2d7218 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -24,7 +24,9 @@ using std::string; namespace ge { namespace { -const int64_t kLoopType = 1; +constexpr int64_t kLoopType = 1; +constexpr uint8_t kMaxTransOp = 3; +constexpr uint8_t kTransOpIoSize = 1; } Status NextIterationPass::Run(ComputeGraphPtr graph) { @@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i std::string node_type; for (const auto &switch_node : loop_group.switch_nodes) { SetControlFlowGroup(switch_node, group_index); - for (const auto &node : switch_node->GetOutDataNodes()) { - (void)GetOriginalType(node, node_type); - if (kExitOpTypes.count(node_type) > 0) { - SetControlFlowGroup(node, group_index); - } else { - // For: Switch -> Cast -> Exit - for (const auto &n : node->GetOutDataNodes()) { - (void)GetOriginalType(n, node_type); - if (kExitOpTypes.count(node_type) > 0) { - SetControlFlowGroup(n, group_index); - } + for (auto node : switch_node->GetOutDataNodes()) { + // Switch --> Exit + // Switch --> Cast --> Exit + // Switch --> TransData --> Cast --> Exit + for (uint8_t i = 0; i < kMaxTransOp; ++i) { + if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { + break; } + + if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { + SetControlFlowGroup(node, group_index); + break; + } + + const auto &all_nodes = node->GetOutAllNodes(); + if (all_nodes.size() != kTransOpIoSize) { + break; + } + node = all_nodes.at(0); } } } diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc index d4c73f58..ca864244 100755 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -373,9 +373,9 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, auto executor = node_item.node_executor; GE_CHECK_NOTNULL(executor); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); + node_state.UpdatePersistTensor(); GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", node_state.GetName().c_str()); - node_state.UpdatePersistTensor(); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str()); From 2d9e3da649a1c407bfcd78127b40dc3e0cbc62fb Mon Sep 17 00:00:00 2001 From: lianghao Date: Thu, 1 Jul 2021 18:02:40 +0800 Subject: [PATCH 22/24] IsEnterFeedNode --- ge/hybrid/model/node_item.cc | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 8e87c6e2..77bd8efd 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -24,6 +24,8 @@ namespace ge { namespace hybrid { namespace { +const uint8_t kMaxTransCount = 3; +const uint32_t kTransOpIoSize = 1; const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char *const kNodeTypeRetVal = "_RetVal"; const std::set kControlOpTypes{ @@ -39,6 +41,25 @@ const std::set kMergeOpTypes{ MERGE, REFMERGE, STREAMMERGE }; +bool IsEnterFeedNode(NodePtr node) { + // For: Enter -> node + // For: Enter -> Cast -> node + // For: Enter -> TransData -> Cast -> node + for (uint8_t i = 0; i < kMaxTransCount; ++i) { + if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { + GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str()); + return true; + } + + const auto all_nodes = node->GetInDataNodes(); + if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) { + return false; + } + node = all_nodes.at(0); + } + return false; +} + Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { uint32_t parent_index = 0; if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { @@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { data_anchors.emplace(anchor_index); } // If Enter feed Not Merge, take as root Node. - if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { + if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) { auto &data_anchors = node_item->enter_data_[this]; data_anchors.emplace(anchor_index); } @@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { node_item->root_ctrl_.emplace(this); } // If Enter feed control signal, take as root Node. - if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { + if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { node_item->enter_ctrl_.emplace(this); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); From 2785670745b7000e350121a668803a1f1686d5cb Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Fri, 2 Jul 2021 15:05:01 +0800 Subject: [PATCH 23/24] fix printf format --- ge/hybrid/model/node_item.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 77bd8efd..fb7add48 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -47,7 +47,7 @@ bool IsEnterFeedNode(NodePtr node) { // For: Enter -> TransData -> Cast -> node for (uint8_t i = 0; i < kMaxTransCount; ++i) { if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { - GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str()); + GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str()); return true; } From 3929578deee3fd93db6b8ec25fa3ea9d26b23d6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=8D=8E?= Date: Sat, 3 Jul 2021 11:19:34 +0800 Subject: [PATCH 24/24] fix parallel group --- ge/graph/passes/parallel_group_pass.cc | 58 +++++++++++----- ge/graph/passes/parallel_group_pass.h | 1 + .../graph/passes/parallel_group_pass_unittest.cc | 80 +++++++++++++++++++++- 3 files changed, 119 insertions(+), 20 deletions(-) diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc index 9c93f6cf..795002f1 100644 --- a/ge/graph/passes/parallel_group_pass.cc +++ b/ge/graph/passes/parallel_group_pass.cc @@ -15,7 +15,7 @@ */ #include "graph/passes/parallel_group_pass.h" - +#include #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" #include "framework/common/ge_inner_error_codes.h" @@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu for (const auto &switch_node : cur_itr->second.first) { int64_t pre_id = pre_node->GetOpDesc()->GetId(); int64_t switch_id = switch_node->GetOpDesc()->GetId(); - // avoid ring - if (pre_id > switch_id) { - auto merge_node = cur_itr->second.second; - if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) { - GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - return FAILED; - } - } else { - if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { - GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", - pre_node->GetName().c_str(), switch_node->GetName().c_str()); - return FAILED; - } + NodePtr first_node = pre_node; + NodePtr second_node = switch_node; + if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) { + // avoid ring, merge->pre_node + first_node = cur_itr->second.second; + second_node = pre_node; + } + if (AddCtrlEdge(first_node, second_node) != SUCCESS) { + GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + first_node->GetName().c_str(), second_node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", + first_node->GetName().c_str(), second_node->GetName().c_str()); + return FAILED; } } } else { @@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) { return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) && stream_switch_type == kLoopType); } + +bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) { + if (node_a == nullptr || node_b == nullptr) { + GELOGW("node_a or node_b is nullptr."); + return false; + } + int64_t end_id = node_b->GetOpDesc()->GetId(); + std::queue nodes; + nodes.push(node_a); + while (!nodes.empty()) { + NodePtr tmp_node = nodes.front(); + nodes.pop(); + if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr || + tmp_node->GetOpDesc()->GetId() > end_id) { + continue; + } + if (tmp_node == node_b) { + return true; + } + for (const auto &out_node : tmp_node->GetOutAllNodes()) { + nodes.push(out_node); + } + } + return false; +} } // namespace ge diff --git a/ge/graph/passes/parallel_group_pass.h b/ge/graph/passes/parallel_group_pass.h index 9b895598..31c87dba 100644 --- a/ge/graph/passes/parallel_group_pass.h +++ b/ge/graph/passes/parallel_group_pass.h @@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass { bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc); bool IsWhileStreamSwitch(OpDescPtr switch_op_desc); + bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b); }; } // namespace ge #endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H diff --git a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc index d5b1db41..588eac9a 100644 --- a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc @@ -19,7 +19,8 @@ #include #define private public - +#include "inc/graph/ge_local_context.h" +#include "inc/external/ge/ge_api_types.h" #include "common/ge_inner_error_codes.h" #include "inc/pass_manager.h" #include "utils/graph_utils.h" @@ -225,6 +226,70 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { output_true_node_->GetOpDesc()->SetIsInputConst({false}); } + void BuildDefaultGraph3() { + /// input + /// \ + /// sqrt pred + /// \ / + /// Switch + /// | | + /// F T ------ + /// / \_/_ \ + /// / / \ \ + /// Merge sqrt2 sqrt3 + /// / \ \ + /// sqrt1 \ relu + /// \ \ + /// \ sqrt4 + /// \ / + /// Merge1 + input_node_ = NewNode("input", RELU, 0, 1); + AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + pred_node_ = NewNode("pred", GREATER, 2, 1); + sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); + cast_node_ = NewNode("cast", CAST, 2, 2); + + switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); + switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); + AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); + output_false_node_ = NewNode("false_output", RELU, 1, 2); + AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + output_true_node_ = NewNode("true_output", RELU, 1, 2); + AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); + sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); + AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1); + AttrUtils::SetStr(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + sqrt_node3_ = NewNode("sqrt3", SQRT, 1, 1); + relu_node_ = NewNode("relu", RELU, 1, 1); + sqrt_node4_ = NewNode("sqrt4", SQRT, 1, 1); + AttrUtils::SetStr(sqrt_node4_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); + merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); + + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), sqrt_node2_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), sqrt_node3_->GetInDataAnchor(0)); + + GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node3_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node4_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node2_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node4_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(1)); + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + ComputeGraphPtr graph_; ComputeGraphPtr sub_graph_; GeTensorDescPtr default_tensor_desc_; @@ -235,6 +300,9 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { NodePtr cast_node1_; NodePtr sqrt_node_; NodePtr sqrt_node1_; + NodePtr sqrt_node2_; + NodePtr sqrt_node3_; + NodePtr sqrt_node4_; NodePtr input_node_; NodePtr input_node1_; NodePtr switch_node_t; @@ -278,6 +346,16 @@ TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); } +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph3) { + std::map options; + options.emplace(OPTION_GRAPH_RUN_MODE, "1"); + GetThreadLocalContext().SetGraphOption(options); + BuildDefaultGraph3(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + EXPECT_EQ(true, merge_node1_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor())); +} + TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { BuildDefaultGraph1(); NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);