From 3401ca857ca4511067f91c624624914d2c6e70e4 Mon Sep 17 00:00:00 2001 From: wjm Date: Sun, 28 Feb 2021 14:23:58 +0800 Subject: [PATCH 1/9] dump for unknownshape --- ge/hybrid/executor/hybrid_execution_context.h | 1 + ge/hybrid/executor/hybrid_model_executor.cc | 6 ++++++ ge/hybrid/executor/worker/execution_engine.cc | 7 +------ .../node_executor/compiledsubgraph/known_node_executor.cc | 6 +----- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 15 ++++++++++++++- 5 files changed, 23 insertions(+), 12 deletions(-) mode change 100755 => 100644 ge/hybrid/executor/hybrid_model_executor.cc mode change 100755 => 100644 ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc diff --git a/ge/hybrid/executor/hybrid_execution_context.h b/ge/hybrid/executor/hybrid_execution_context.h index 4dc010df..003e8010 100644 --- a/ge/hybrid/executor/hybrid_execution_context.h +++ b/ge/hybrid/executor/hybrid_execution_context.h @@ -71,6 +71,7 @@ struct GraphExecutionContext { std::atomic_bool is_eos_; long profiling_level = 0; long iteration = 0; + void *global_step = nullptr; private: Status status = SUCCESS; diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc old mode 100755 new mode 100644 index 80b8983a..4b589a03 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() { if (context_.rt_gen_context != nullptr) { (void) rtCtxDestroy(context_.rt_gen_context); } + if (context_.global_step != nullptr) { + (void) rtFree(context_.global_step); + } } Status HybridModelExecutor::Init() { @@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { auto root_graph_item = model_->GetRootGraphItem(); GE_CHECK_NOTNULL(root_graph_item); + GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, + sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); auto ret = ExecuteGraphInternal(executor, args); Cleanup(); @@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() { GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); + GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM)); context_.stream = stream_; context_.model = model_; diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc index 673c82dd..de3bdc37 100755 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -231,12 +231,6 @@ Status NodeDoneCallback::DumpDynamicNode() { uint32_t model_id = model->GetModelId(); dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); - void *global_step = nullptr; - TensorValue *varible_global_step = context_->GetVariable(NODE_NAME_GLOBAL_STEP); - if (varible_global_step != nullptr) { - global_step = const_cast(varible_global_step->GetData()); - } - void *loop_per_iter = nullptr; TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); if (varible_loop_per_iter != nullptr) { @@ -248,6 +242,7 @@ Status NodeDoneCallback::DumpDynamicNode() { if (varible_loop_cond != nullptr) { loop_cond = const_cast(varible_loop_cond->GetData()); } + void *global_step = context_->GetExecutionContext()->global_step; dump_op_.SetLoopAddr(global_step, 
loop_per_iter, loop_cond); GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc old mode 100755 new mode 100644 index cf5ac851..bb96c275 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc @@ -126,11 +126,7 @@ Status KnownNodeTask::Init(TaskContext &context) { auto dump_properties = context.GetDumpProperties(); if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { davinci_model_->SetDumpProperties(dump_properties); - void *global_step = nullptr; - TensorValue *varible_global_step = context.GetVariable(NODE_NAME_GLOBAL_STEP); - if (varible_global_step != nullptr) { - global_step = varible_global_step->MutableData(); - } + void *global_step = context.GetExecutionContext()->global_step; davinci_model_->SetKnownShapeGlobalStep(global_step); } int32_t device_id = 0; diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index d7116dbc..3b5d19e6 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -30,6 +30,7 @@ #include "framework/common/debug/log.h" #include "graph/ge_context.h" #include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/node_executor/aicore/aicore_task_builder.h" #include "graph/load/model_manager/tbe_handle_store.h" #include "graph/manager/graph_mem_allocator.h" @@ -242,4 +243,16 @@ TEST_F(UtestGeHybrid, init_weight_success) { ge_sub_model->SetWeight(weight_buffer); ret = hybrid_model_builder.InitWeights(); ASSERT_EQ(ret,PARAM_INVALID); -} \ No newline at end of file +} + + TEST_F(UtestGeHybrid, hybrid_model_executor) { + ComputeGraphPtr compute_graph = MakeShared("abc"); + GeRootModelPtr root_model = MakeShared(compute_graph); + HybridModel model(root_model); + HybridModel *model_ptr = &model; + + uint32_t device_id = 0; + rtStream_t stream; + HybridModelExecutor executor(model_ptr, device_id, stream); + executor.Init(); +} From 77d5468cf651b33347d53bdafb0fe267fd831fee Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 24 Mar 2021 16:55:41 +0800 Subject: [PATCH 2/9] Fix bug of single_op inferdepend. 
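
A minimal standalone sketch of the rule this patch introduces, for reference: a single op falls back to the hybrid executor when any op in the graph declares infer-depend inputs, or when the task list contains more than one TE kernel task. The names below (OpInfo, KernelKind, NeedHybridModelSketch) are simplified stand-ins, not GE APIs; the authoritative logic is the IfInferDepend/NeedHybridModel change in the diff that follows.

// Editor's illustration only, not applied by this patch.
#include <cstdint>
#include <iostream>
#include <vector>

enum class KernelKind { TE, AICPU, OTHER };  // simplified stand-in for ccKernelType

struct OpInfo {
  bool has_infer_depends;  // stand-in for !op_desc->GetOpInferDepends().empty()
};

bool NeedHybridModelSketch(const std::vector<OpInfo> &ops, const std::vector<KernelKind> &kernel_tasks) {
  bool infer_depend = false;
  for (const auto &op : ops) {
    if (op.has_infer_depends) {
      infer_depend = true;
      break;
    }
  }
  int32_t te_task_num = 0;
  for (const auto kind : kernel_tasks) {
    if (kind != KernelKind::TE) {
      continue;  // only TE kernels count, mirroring the ccKernelType::TE check in the patch
    }
    if (infer_depend) {
      return true;  // one TE task plus infer-depend inputs already requires the hybrid path
    }
    if (++te_task_num > 1) {
      return true;  // more than one TE task also requires the hybrid path
    }
  }
  return false;
}

int main() {
  std::cout << NeedHybridModelSketch({{true}}, {KernelKind::TE}) << "\n";   // 1
  std::cout << NeedHybridModelSketch({{false}}, {KernelKind::TE}) << "\n";  // 0
  return 0;
}
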
--- ge/single_op/single_op_model.cc | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 49dde9c4..a5550deb 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -44,19 +44,46 @@ namespace ge { namespace { const size_t kDataOutputNum = 1; -bool NeedHybridModel(GeModelPtr &ge_model) { +Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { + auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); + GE_CHECK_NOTNULL(comp_graph); + for (const auto &node : comp_graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const auto &depends = op_desc->GetOpInferDepends(); + if (!depends.empty()) { + flag = true; + return SUCCESS; + } + } + return SUCCESS; +} + +Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { + bool infer_depend_flag = false; + GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); auto tasks = ge_model->GetModelTaskDefPtr()->task(); int32_t kernel_task_num = 0; for (int i = 0; i < tasks.size(); ++i) { auto task_type = static_cast(tasks[i].type()); if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { - kernel_task_num++; - if (kernel_task_num > 1) { - return true; + const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() : + tasks[i].kernel_with_handle().context(); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == ccKernelType::TE) { + if (infer_depend_flag) { + flag = true; + return SUCCESS; + } + kernel_task_num++; + if (kernel_task_num > 1) { + flag = true; + return SUCCESS; + } } } } - return false; + return SUCCESS; } } // namespace @@ -503,7 +530,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); - if (NeedHybridModel(ge_model)) { + bool need_hybrid_model = false; + GE_CHK_STATUS_RET(NeedHybridModel(ge_model, need_hybrid_model), "[Check][NeedHybridModel] failed."); + if (need_hybrid_model) { GELOGD("Build single op HybridModel."); GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); auto root_model = model_helper_.GetGeRootModel(); From 13ecbe405a45aa3d6ad1007536c302cfcd79e6fc Mon Sep 17 00:00:00 2001 From: zhou_chao1993 Date: Thu, 25 Mar 2021 19:08:22 +0800 Subject: [PATCH 3/9] sync runtime head --- ge/common/dump/opdebug_register.cc | 5 ----- tests/depends/runtime/src/runtime_stub.cc | 4 ++++ third_party/fwkacllib/inc/runtime/stream.h | 22 ++++++++++++++++++++++ 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/ge/common/dump/opdebug_register.cc b/ge/common/dump/opdebug_register.cc index 340b89e5..f576f3eb 100644 --- a/ge/common/dump/opdebug_register.cc +++ b/ge/common/dump/opdebug_register.cc @@ -80,13 +80,11 @@ Status OpdebugRegister::RegisterDebugForStream(rtStream_t stream, uint32_t op_de uint32_t debug_stream_id = 0; uint32_t debug_task_id = 0; -#ifdef ONLY_COMPILE_OPEN_SRC auto rt_ret = rtDebugRegisterForStream(stream, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtDebugRegisterForStream error, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); } -#endif GELOGD("debug_task_id:%u, debug_stream_id:%u in stream overflow.", debug_task_id, debug_stream_id); data_dumper.SaveOpDebugId(debug_task_id, 
debug_stream_id, p2p_debug_addr_, true); return SUCCESS; @@ -94,7 +92,6 @@ Status OpdebugRegister::RegisterDebugForStream(rtStream_t stream, uint32_t op_de void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) { rtError_t rt_ret = RT_ERROR_NONE; -#ifdef ONLY_COMPILE_OPEN_SRC if (stream != nullptr) { GELOGD("start call rtDebugUnRegisterForStream in unknown shape over flow."); rt_ret = rtDebugUnRegisterForStream(stream); @@ -102,8 +99,6 @@ void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) { GELOGW("rtDebugUnRegisterForStream failed, ret: 0x%X", rt_ret); } } -#endif - if (op_debug_addr_ != nullptr) { rt_ret = rtFree(op_debug_addr_); if (rt_ret != RT_ERROR_NONE) { diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 440b98e7..40cd92dc 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -435,3 +435,7 @@ rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId) rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId) { return RT_ERROR_NONE; } + +rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { + return RT_ERROR_NONE; +} diff --git a/third_party/fwkacllib/inc/runtime/stream.h b/third_party/fwkacllib/inc/runtime/stream.h index 6b9f80ae..f9981514 100644 --- a/third_party/fwkacllib/inc/runtime/stream.h +++ b/third_party/fwkacllib/inc/runtime/stream.h @@ -189,6 +189,28 @@ RTS_API rtError_t rtStreamActive(rtStream_t activeStream, rtStream_t stream); */ RTS_API rtError_t rtStreamSwitchN(void *ptr, uint32_t size, void *valuePtr, rtStream_t *trueStreamPtr, uint32_t elementSize, rtStream_t stream, rtSwitchDataType_t dataType); + +/* + * @ingroup dvrt_stream + * @brief enable debug for dump overflow exception with stream + * @param [in] addr: ddr address of kernel exception dumpped + * @param [in] stream: stream handle + * @param [in] flag: debug flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void *addr, + uint32_t *streamId, uint32_t *taskId); + +/* + * @ingroup rt_model + * @brief disable debug for dump overflow exception with stream + * @param [in] stream: stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugUnRegisterForStream(rtStream_t stream); + #if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) } #endif From f0d897b0bb073bf6df271f8615a263d6be5ebd0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=A3=8A?= Date: Mon, 22 Mar 2021 15:32:28 +0800 Subject: [PATCH 4/9] fixed compiled issue for proto files --- build.sh | 2 +- ge/CMakeLists.txt | 8 +++++++- ge/common/CMakeLists.txt | 5 +++-- ge/executor/CMakeLists.txt | 5 +++-- ge/ge_local_engine/CMakeLists.txt | 14 ++++++++------ ge/host_cpu_engine/CMakeLists.txt | 5 +++-- metadef | 2 +- parser | 2 +- 8 files changed, 27 insertions(+), 16 deletions(-) diff --git a/build.sh b/build.sh index 3e2dcdec..7b1c0792 100644 --- a/build.sh +++ b/build.sh @@ -229,7 +229,7 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then rm -rf ${BASEPATH}/cov mkdir ${BASEPATH}/cov lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o cov/tmp.info - lcov -r cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' -o cov/coverage.info + lcov -r cov/tmp.info 
'*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' '/usr/include/*' '*/metadef/*' '*/parser/*' -o cov/coverage.info cd ${BASEPATH}/cov genhtml coverage.info fi diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index c29936bb..9b979200 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -31,6 +31,7 @@ set(PROTO_HEADER_LIST protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) +protobuf_generate(ge_client PROTO_CLIENT_HEADER_SRCS PROTO_CLIENT_HEADER_HDRS ${PROTO_HEADER_LIST}) if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) ############ libge_proto_common.a ############ @@ -56,7 +57,7 @@ target_link_libraries(ge_proto_common PRIVATE ############ libge_proto_client.a ############ add_library(ge_proto_client STATIC - ${PROTO_HEADER_HDRS} + ${PROTO_CLIENT_HEADER_HDRS} ${PROTO_CLIENT_SRCS} ) @@ -65,6 +66,11 @@ target_compile_definitions(ge_proto_client PRIVATE google=ascend_private ) +target_include_directories(ge_proto_client PRIVATE + ${CMAKE_BINARY_DIR}/proto/ge_client + ${CMAKE_BINARY_DIR}/proto/ge_client/proto +) + target_compile_options(ge_proto_client PRIVATE -O2 -fno-common diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index a6f8e57c..75cb8ad1 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -16,6 +16,7 @@ set(PROTO_LIST ) protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) set(SRC_LIST "context/ctx.cc" @@ -127,7 +128,7 @@ target_link_libraries(ge_common PRIVATE ) ############ libge_common.a ############ -add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_HDRS}) +add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_STATIC_HDRS}) target_compile_definitions(ge_common_static PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 HOST_VISIBILITY @@ -158,7 +159,7 @@ target_include_directories(ge_common_static PRIVATE ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge_static #### yellow zone #### ${GE_DEPEND_DIR}/inc ${GE_DEPEND_DIR}/inc/cce diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index 396c4617..363900d0 100644 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -8,6 +8,7 @@ set(PROTO_LIST ) protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) set(SRC_LIST "ge_executor.cc" @@ -162,7 +163,7 @@ set(SRC_LIST ) ######## libge_executor.a ######## -add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_HDRS}) +add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_STATIC_HDRS}) target_compile_options(ge_executor PRIVATE $<$,$>:-fvisibility=hidden -O2 -Werror -Wno-deprecated-declarations -fno-common> @@ -191,7 +192,7 @@ target_include_directories(ge_executor SYSTEM PRIVATE ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge_static #### yellow zone #### ${GE_CODE_DIR}/../inc ${GE_CODE_DIR}/../inc/cce diff --git a/ge/ge_local_engine/CMakeLists.txt b/ge/ge_local_engine/CMakeLists.txt index 00142cfe..ab767ccb 100755 --- a/ge/ge_local_engine/CMakeLists.txt +++ b/ge/ge_local_engine/CMakeLists.txt @@ -20,6 +20,8 @@ set(OPS_KERNEL_SRC_LIST ) protobuf_generate(ge 
PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +protobuf_generate(ge_ops_shared PROTO_OPS_SHARED_SRCS PROTO_OPS_SHARED_HDRS ${PROTO_LIST}) +protobuf_generate(ge_ops_static PROTO_OPS_STATIC_SRCS PROTO_OPS_STATIC_HDRS ${PROTO_LIST}) ############ libge_local_engine.so ############ add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) @@ -119,7 +121,7 @@ set_target_properties(atc_ge_local_engine PROPERTIES ) ############ libge_local_opskernel_builder.so ############ -add_library(ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) +add_library(ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_SHARED_HDRS}) target_compile_options(ge_local_opskernel_builder PRIVATE -Werror @@ -143,7 +145,7 @@ target_include_directories(ge_local_opskernel_builder PRIVATE ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge_ops_shared #### yellow zone #### ${GE_CODE_DIR}/../inc #### blue zone #### @@ -166,7 +168,7 @@ target_link_libraries(ge_local_opskernel_builder PRIVATE ) ############ atclib/libge_local_opskernel_builder.so ############ -add_library(atc_ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) +add_library(atc_ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_SHARED_HDRS}) target_compile_options(atc_ge_local_opskernel_builder PRIVATE -Werror @@ -190,7 +192,7 @@ target_include_directories(atc_ge_local_opskernel_builder PRIVATE ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge_ops_shared #### yellow zone #### ${GE_CODE_DIR}/../inc #### blue zone #### @@ -218,7 +220,7 @@ set_target_properties(atc_ge_local_opskernel_builder PROPERTIES ) ############ libge_local_opskernel_builder.a ############ -add_library(ge_local_opskernel_builder_static STATIC ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) +add_library(ge_local_opskernel_builder_static STATIC ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_STATIC_HDRS}) target_compile_options(ge_local_opskernel_builder_static PRIVATE -Werror @@ -243,7 +245,7 @@ target_include_directories(ge_local_opskernel_builder_static PRIVATE ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge_ops_static #### yellow zone #### ${GE_CODE_DIR}/../inc #### blue zone #### diff --git a/ge/host_cpu_engine/CMakeLists.txt b/ge/host_cpu_engine/CMakeLists.txt index 13cb7434..8d84ee28 100644 --- a/ge/host_cpu_engine/CMakeLists.txt +++ b/ge/host_cpu_engine/CMakeLists.txt @@ -3,6 +3,7 @@ set(PROTO_LIST ) protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) set(SRC_LIST "engine/host_cpu_engine.cc" @@ -61,7 +62,7 @@ target_link_libraries(host_cpu_engine PRIVATE ) ############ atcstub/libhost_cpu_engine.so ############ -add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) +add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_ATCSTUB_HDRS}) target_compile_options(atc_host_cpu_engine PRIVATE -Werror @@ -84,7 +85,7 @@ target_include_directories(atc_host_cpu_engine PRIVATE ${METADEF_DIR}/inc/external ${METADEF_DIR}/inc/external/graph ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge_atcstub #### yellow zone #### ${GE_CODE_DIR}/../inc #### blue zone #### diff --git a/metadef b/metadef index 140538ea..f0dd9337 160000 --- a/metadef +++ b/metadef @@ -1 +1 
@@ -Subproject commit 140538eadb161278f1c733e7850bfaba65cf665e +Subproject commit f0dd933702e5224399e757e1b6174e49eb4e71fa diff --git a/parser b/parser index b203d478..d851e1d4 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit b203d47837421b2c149f353fc0808f6a29fa584e +Subproject commit d851e1d467768b6cefd8f5f44745be1c5312121a From 37c928ed29ea9a3fef4f9f3093ec2e2ef62262ad Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Sat, 27 Mar 2021 10:52:59 +0800 Subject: [PATCH 5/9] bugfix for auto find fp --- ge/graph/build/task_generator.cc | 8 ++- tests/ut/ge/CMakeLists.txt | 2 + tests/ut/ge/graph/build/task_generator_unittest.cc | 68 ++++++++++++++++++++++ 3 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 tests/ut/ge/graph/build/task_generator_unittest.cc diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index 4eda4020..13fc2601 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -49,6 +49,7 @@ const char *const kIsLastNode = "is_last_node"; const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kProfilingMode = "PROFILING_MODE"; +const char *const kIteratorV2 = "IteratorV2"; const uint32_t kProfilingArStep = 2; const uint64_t kProfilingFpStartLogid = 1; const uint64_t kProfilingBpEndLogid = 2; @@ -57,6 +58,7 @@ const uint64_t kProfilingArEndLogid = 4; const uint64_t kProfilingIterEndLogid = 65535; const int64_t kHashFactor = 100000; const int64_t kInvalidGroupId = -1; +const std::set kFpNodeTypes = {ge::DATA, ge::GETNEXT, kIteratorV2}; } // namespace namespace ge { TaskGenerator::TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size) { @@ -621,8 +623,10 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP if (op_kernel_lib_name.empty()) { continue; } - - if (op_desc->GetType() == GETNEXT || op_desc->GetType() == DATA) { + auto type = op_desc->GetType(); + std::string original_type; + (void)AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type); + if (kFpNodeTypes.find(type) != kFpNodeTypes.end() || kFpNodeTypes.find(original_type) != kFpNodeTypes.end()) { auto out_anchor = node->GetOutDataAnchor(0); for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { GE_CHECK_NOTNULL(peer_in_anchor); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 80636a20..453269da 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -731,6 +731,7 @@ set(KERNEL_TEST_FILES "graph/passes/folding_kernel/gather_v2_kernel_unittest.cc" "graph/passes/folding_kernel/slice_kernel_unittest.cc" "graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc" + "graph/passes/atomic_addr_clean_pass_unittest.cc" ) set(MULTI_PARTS_TEST_FILES @@ -760,6 +761,7 @@ set(MULTI_PARTS_TEST_FILES "graph/variable_accelerate_ctrl_unittest.cc" "graph/build/logical_stream_allocator_unittest.cc" "graph/build/mem_assigner_unittest.cc" + "graph/build/task_generator_unittest.cc" "graph/preprocess/graph_preprocess_unittest.cc" "graph/manager/hcom_util_unittest.cc" "graph/manager/graph_caching_allocator_unittest.cc" diff --git a/tests/ut/ge/graph/build/task_generator_unittest.cc b/tests/ut/ge/graph/build/task_generator_unittest.cc new file mode 100644 index 00000000..95e75eb7 --- /dev/null +++ b/tests/ut/ge/graph/build/task_generator_unittest.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "omg/omg_inner_types.h" +#include "../passes/graph_builder_utils.h" + +#define protected public +#define private public +#include "graph/build/task_generator.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; +using namespace ge; + +class UtestTaskGeneratorTest : public testing::Test { + public: + ge::ComputeGraphPtr BuildGraphFpProfiling() { + ge::ut::GraphBuilder builder("graph"); + auto data = builder.AddNode("data", "phony", 1, 1); + auto addn1 = builder.AddNode("addn1", "AddN", 1, 1); + auto netoutput = builder.AddNode("netoutput", "NetOutput", 2, 0); + auto op_desc = data->GetOpDesc(); + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "IteratorV2"); + op_desc->SetOpKernelLibName("GE"); + builder.AddDataEdge(data, 0, addn1, 0); + builder.AddDataEdge(addn1, 0, netoutput, 0); + return builder.GetGraph(); + } + + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestTaskGeneratorTest, AutoFindFpOpIndex) { + auto graph = BuildGraphFpProfiling(); + TaskGenerator task_generator(nullptr, 0); + ProfilingPoint profiling_point; + profiling_point.fp_index = -1; + EXPECT_EQ(task_generator.AutoFindFpOpIndex(graph, profiling_point), SUCCESS); + // addn1 is fp + EXPECT_EQ(profiling_point.fp_index, 2); +} From 9d34427af96c05533dae263f28cf8fc456456b23 Mon Sep 17 00:00:00 2001 From: y00500818 Date: Sat, 27 Mar 2021 16:46:42 +0800 Subject: [PATCH 6/9] bugfix for atomic_addr_clean_pass --- ge/graph/passes/atomic_addr_clean_pass.cc | 12 ++-- tests/ut/ge/CMakeLists.txt | 1 + .../passes/atomic_addr_clean_pass_unittest.cc | 65 ++++++++++++++++++++++ 3 files changed, 72 insertions(+), 6 deletions(-) create mode 100644 tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc index 7c6ed8ce..16d3c129 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/ge/graph/passes/atomic_addr_clean_pass.cc @@ -126,11 +126,11 @@ bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int6 bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { OpDescPtr op_desc = node->GetOpDesc(); - std::map> node_workspace_offset; + std::map> atomic_workspace_index_size; bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); - node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); - if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { + atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); + if (!has_atomic_input && 
has_atomic_output && atomic_workspace_index_size.empty()) { std::vector atomic_output_index; (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); bool is_all_output_peer_also_atomic = true; @@ -332,11 +332,11 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { } // 2.Check atomic attr in node - std::map> node_workspace_offset; + std::map> atomic_workspace_index_size; bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); - node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); - if (!has_atomic_input && !has_atomic_output && node_workspace_offset.empty()) { + atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); + if (!has_atomic_input && !has_atomic_output && atomic_workspace_index_size.empty()) { return false; } diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 80636a20..f7fce11d 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -731,6 +731,7 @@ set(KERNEL_TEST_FILES "graph/passes/folding_kernel/gather_v2_kernel_unittest.cc" "graph/passes/folding_kernel/slice_kernel_unittest.cc" "graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc" + "graph/passes/atomic_addr_clean_pass_unittest.cc" ) set(MULTI_PARTS_TEST_FILES diff --git a/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc b/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc new file mode 100644 index 00000000..59636511 --- /dev/null +++ b/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "graph/passes/atomic_addr_clean_pass.h" +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/op_desc.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "inc/pass_manager.h" +using namespace testing; + +namespace ge { +class UtestGraphPassesAtomicAddrCleanPass : public Test { +public: + UtestGraphPassesAtomicAddrCleanPass() { + graph_ = std::make_shared("test"); + } + + NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) { + OpDescPtr op_desc = std::make_shared(name, type); + for (int i = 0; i < input_cnt; ++i) { + op_desc->AddInputDesc(GeTensorDesc()); + } + for (int i = 0; i < output_cnt; ++i) { + op_desc->AddOutputDesc(GeTensorDesc()); + } + NodePtr node = graph_->AddNode(op_desc); + return node; + } + + ComputeGraphPtr graph_; +}; + +// node1 -> node2 -> node3 +TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { + auto node1 = NewNode("node1", DATA, 0, 1); + auto node2 = NewNode("node2", RELU, 1, 1); + auto node3 = NewNode("node3", NETOUTPUT, 1, 0); + GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); + GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); + AtomicAddrCleanPass atomi_addr_clean_pass; + Status ret = atomi_addr_clean_pass.Run(graph_); + EXPECT_EQ(ret, SUCCESS); +} +} // namespace ge From 2cf49ced1c080e8792302e539cb5603dbe708ee6 Mon Sep 17 00:00:00 2001 From: lianghao Date: Sat, 27 Mar 2021 14:41:23 +0800 Subject: [PATCH 7/9] online_inference c77 --- ge/graph/passes/attach_stream_label_pass.cc | 1 - ge/graph/passes/pass_utils.cc | 8 +- ge/graph/passes/pass_utils.h | 2 + ge/graph/passes/subexpression_migration_pass.cc | 2 +- ge/graph/passes/switch_dead_branch_elimination.cc | 10 +- ge/graph/passes/switch_to_stream_switch_pass.cc | 2 + metadef | 2 +- tests/ut/ge/CMakeLists.txt | 1 + .../switch_dead_branch_elimination_unittest.cc | 163 +++++++++++++++++++++ 9 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index cd3509c7..4927e3aa 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -137,7 +137,6 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea return INTERNAL_ERROR; } stream_label = node->GetInDataNodes().at(0)->GetName(); - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); bool value = false; OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); diff --git a/ge/graph/passes/pass_utils.cc b/ge/graph/passes/pass_utils.cc index 3adfbde3..b827e88a 100644 --- a/ge/graph/passes/pass_utils.cc +++ b/ge/graph/passes/pass_utils.cc @@ -35,9 +35,9 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "utils/node_utils.h" namespace ge { - Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector &data, std::vector &v_output, const bool scalar_output) { Status ret = SUCCESS; @@ -246,6 +246,12 @@ NodePtr PassUtils::GetInDataNode(const ConstNodePtr &node, int index) { return src_node; } +NodePtr 
PassUtils::GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index) { + auto src_node = GetInDataNode(node, index); + + return NodeUtils::GetInNodeCrossSubgraph(src_node); +} + bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { return false; diff --git a/ge/graph/passes/pass_utils.h b/ge/graph/passes/pass_utils.h index fbfb3b47..bd506d09 100755 --- a/ge/graph/passes/pass_utils.h +++ b/ge/graph/passes/pass_utils.h @@ -30,6 +30,8 @@ class PassUtils { static NodePtr GetInDataNode(const ConstNodePtr &node, int index); + static NodePtr GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index); + static bool IsConstant(const ConstNodePtr &node); static Status SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, const NodePtr &src_node); diff --git a/ge/graph/passes/subexpression_migration_pass.cc b/ge/graph/passes/subexpression_migration_pass.cc index dc4d2185..05b7baa1 100755 --- a/ge/graph/passes/subexpression_migration_pass.cc +++ b/ge/graph/passes/subexpression_migration_pass.cc @@ -279,7 +279,7 @@ Status SubexpressionMigrationPass::GraphNodeMigration(const ComputeGraphPtr &gra const auto &in_anchor = in_anchors.at(i); const auto &base_node = in_anchor->GetOwnerNode(); GELOGD("Get Data direct node: %s", base_node->GetName().c_str()); - if (!base_node->GetHostNode()) { + if (!base_node->GetHostNode() || base_node->GetType() == SWITCH) { continue; } diff --git a/ge/graph/passes/switch_dead_branch_elimination.cc b/ge/graph/passes/switch_dead_branch_elimination.cc index 70105aea..20598f17 100644 --- a/ge/graph/passes/switch_dead_branch_elimination.cc +++ b/ge/graph/passes/switch_dead_branch_elimination.cc @@ -94,6 +94,12 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre GELOGE(FAILED, "parameter is null."); return FAILED; } + + // If two nodes aren't in same graph, get node's direct in_node instead of pred_node. 
+ if (node->GetOwnerComputeGraph() != pred_node->GetOwnerComputeGraph()) { + pred_node = PassUtils::GetInDataNode(node, kPredInputIndex); + } + // link pred's in control nodes to switch if (GraphUtils::CopyInCtrlEdges(pred_node, node) != GRAPH_SUCCESS) { return FAILED; @@ -131,7 +137,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) { return SUCCESS; } - auto pred_node = PassUtils::GetInDataNode(node, kPredInputIndex); + auto pred_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kPredInputIndex); if (pred_node == nullptr) { GELOGD("[%s] Pred input is null.", node->GetName().c_str()); return SUCCESS; @@ -143,7 +149,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) { return SUCCESS; } - auto input_node = PassUtils::GetInDataNode(node, kDataInputIndex); + auto input_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kDataInputIndex); if (input_node == nullptr) { GELOGD("[%s] Data input is null.", node->GetName().c_str()); return SUCCESS; diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 392968e7..d7fa8844 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -448,6 +448,8 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) // select first stream_switch NodePtr stream_switch = switch_list.front(); + // set stream_label + GE_CHK_STATUS_RET(SetStreamLabel(stream_switch, cast_node->GetName()), "Set stream label failed."); OpDescPtr switch_desc = stream_switch->GetOpDesc(); GE_CHECK_NOTNULL(switch_desc); switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"))); diff --git a/metadef b/metadef index f0dd9337..0c4602a4 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit f0dd933702e5224399e757e1b6174e49eb4e71fa +Subproject commit 0c4602a4615a9368b06633a5087e2114518f29ca diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 80636a20..12c33db7 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -667,6 +667,7 @@ set(PASS_TEST_FILES "graph/passes/merge_pass_unittest.cc" #"graph/passes/switch_pass_unittest.cc" "graph/passes/switch_logic_remove_pass_unittest.cc" + "graph/passes/switch_dead_branch_elimination_unittest.cc" "graph/passes/assert_pass_unittest.cc" "graph/passes/dropout_pass_unittest.cc" "graph/passes/unused_const_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc b/tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc new file mode 100644 index 00000000..c3f21251 --- /dev/null +++ b/tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include "common/ge_inner_error_codes.h" +#include "graph/passes/switch_dead_branch_elimination.h" +#include "graph_builder_utils.h" + +namespace ge { +class UtestSwitchDeadBranchElimination : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +namespace { +/* + * data1 const1 + * \ / + * case1 + * | + * relu1 + * | + * netoutput + */ +ut::GraphBuilder ParentGraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto case1 = builder.AddNode("case1", CASE, 2, 1); + auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + int32_t weight[1] = {1}; + GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); + GeTensorPtr tensor = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); + OpDescUtils::SetWeights(const1, {tensor}); + + builder.AddDataEdge(data1, 0, case1, 0); + builder.AddDataEdge(const1, 0, case1, 1); + builder.AddDataEdge(case1, 0, relu1, 0); + builder.AddDataEdge(relu1, 0, netoutput, 0); + return builder; +} + +/* + * data1 data2 + * \ / + * switch + * / \ + * relu1 relu2 + * \ / + * merge + * | + * netoutput + */ +ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) { + ut::GraphBuilder builder = ut::GraphBuilder(graph_name); + + string data1_name = "data1_" + std::to_string(num); + auto data1 = builder.AddNode(data1_name, "Data", 0, 1); + auto data1_desc = data1->GetOpDesc(); + EXPECT_NE(data1_desc, nullptr); + AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); + + string data2_name = "data2_" + std::to_string(num); + auto data2 = builder.AddNode(data2_name, "Data", 0, 1); + auto data2_desc = data2->GetOpDesc(); + EXPECT_NE(data2_desc, nullptr); + AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); + + string switch_name = "switch_" + std::to_string(num); + auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); + + string relu1_name = "relu1_" + std::to_string(num); + auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1); + + string relu2_name = "relu2_" + std::to_string(num); + auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1); + + string merge_name = "merge_" + std::to_string(num); + auto merge = builder.AddNode(merge_name, "Merge", 2, 1); + + string output_name = "output_" + std::to_string(num); + auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0); + + builder.AddDataEdge(data1, 0, switch1, 0); + builder.AddDataEdge(data2, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, relu1, 0); + builder.AddDataEdge(switch1, 1, relu2, 0); + builder.AddDataEdge(relu1, 0, merge, 0); + builder.AddDataEdge(relu2, 0, merge, 1); + builder.AddDataEdge(merge, 0, netoutput, 0); + + return builder; +} + +void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + + for (uint32_t i = 0; i < branch_num; ++i) { + string name = "Branch_Graph_" + std::to_string(i); + + auto builder_subgraph = SwitchSubgraphBuilder(name, i); + auto switch_subgraph = builder_subgraph.GetGraph(); + + case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); + case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); + + switch_subgraph->SetParentNode(case_node); + switch_subgraph->SetParentGraph(parent_graph); + 
EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); + } +} +} // namespace + + +TEST_F(UtestSwitchDeadBranchElimination, switch_dead_branch_elimination_across_case_success) { + auto builder = ParentGraphBuilder(); + auto parent_graph = builder.GetGraph(); + + AddCaseSubgraph(parent_graph, 2); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 2); + + SwitchDeadBranchElimination switch_pass; + for (auto &subgraph : subgraphs) { + auto switch_node = subgraph->FindFirstNodeMatchType("Switch"); + if (switch_node != nullptr) { + EXPECT_EQ(switch_pass.Run(switch_node), SUCCESS); + } + } + + auto all_nodes = parent_graph->GetAllNodes(); + EXPECT_EQ(all_nodes.size(), 17); + + for (auto &subgraph : subgraphs) { + EXPECT_EQ(subgraph->GetDirectNode().size(), 6); + EXPECT_EQ(subgraph->FindFirstNodeMatchType("Switch"), nullptr); + auto merge_node = subgraph->FindFirstNodeMatchType("Merge"); + EXPECT_NE(merge_node, nullptr); + auto merge_innode = merge_node->GetInDataNodes(); + EXPECT_EQ(merge_innode.size(), 1); + } +} +} // namespace ge From aad154cdf1d616b5c2992a4a0978171305eb5dd3 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 29 Mar 2021 10:16:20 +0800 Subject: [PATCH 8/9] Fix error of single_op memory free. --- ge/graph/manager/graph_caching_allocator.cc | 8 ++++++++ ge/graph/manager/graph_caching_allocator.h | 7 +++++++ ge/single_op/single_op_manager.cc | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/ge/graph/manager/graph_caching_allocator.cc b/ge/graph/manager/graph_caching_allocator.cc index 5822056d..cc8bd90d 100644 --- a/ge/graph/manager/graph_caching_allocator.cc +++ b/ge/graph/manager/graph_caching_allocator.cc @@ -356,6 +356,14 @@ void CachingAllocator::FreeBlocks() { (void) FreeCachedBlocks(); } +void CachingAllocator::TryFreeBlocks() { + GELOGI("Try free blocks."); + std::lock_guard lock(mutex_); + if (allocated_blocks_.empty()) { + (void) FreeCachedBlocks(); + } +} + void CachingAllocator::FreeBlockBins() { GELOGI("Free block bins."); std::lock_guard lock(mutex_); diff --git a/ge/graph/manager/graph_caching_allocator.h b/ge/graph/manager/graph_caching_allocator.h index 27563c2d..a9c3202a 100644 --- a/ge/graph/manager/graph_caching_allocator.h +++ b/ge/graph/manager/graph_caching_allocator.h @@ -94,6 +94,13 @@ class CachingAllocator { /// Status Free(uint8_t *memory_addr, uint32_t device_id = 0); + /// + /// @ingroup ge_graph + /// @brief try to free memory when no memory is referenced + /// @return void + /// + void TryFreeBlocks(); + private: /// diff --git a/ge/single_op/single_op_manager.cc b/ge/single_op/single_op_manager.cc index fddbeec2..bbad7a9d 100644 --- a/ge/single_op/single_op_manager.cc +++ b/ge/single_op/single_op_manager.cc @@ -19,6 +19,9 @@ #include #include +#include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/graph_caching_allocator.h" + namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOpManager::~SingleOpManager() { for (auto &it : stream_resources_) { @@ -67,6 +70,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::Release delete it->second; it->second = nullptr; (void)stream_resources_.erase(it); + MemManager::Instance().CachingInstance(RT_MEMORY_HBM).TryFreeBlocks(); return SUCCESS; } From 167621141b2f5bfbddd117d20706aea8461d5940 Mon Sep 17 00:00:00 2001 From: chuxing Date: Mon, 29 Mar 2021 10:46:17 +0800 Subject: [PATCH 9/9] hccl ops with same parallel group can not be execute parallelly --- 
ge/hybrid/model/hybrid_model_builder.cc | 232 +++++++++++++++------ ge/hybrid/model/hybrid_model_builder.h | 9 +- ge/hybrid/model/node_item.cc | 4 + ge/hybrid/model/node_item.h | 2 + .../compiledsubgraph/known_node_executor.cc | 43 ++-- .../compiledsubgraph/known_node_executor.h | 8 +- 6 files changed, 208 insertions(+), 90 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index f5cb5f7e..25dabd78 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -255,9 +255,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n (void) AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false); (void) AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false); - new_node->node_id = node_index; - new_node->op_desc->SetId(node_index); - node_index += 1; + new_node->node_id = static_cast(new_node->op_desc->GetId()); NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) || (executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) || @@ -273,16 +271,16 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt // not care result, if no this attr, stand for the op does not need force infershape (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); GELOGD("node [%s] is need do infershape , flag is %d", - op_desc->GetName().c_str(), - node_item.is_need_force_infershape); + op_desc->GetName().c_str(), + node_item.is_need_force_infershape); return SUCCESS; } Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { - std::set dependent_input_nodes; + std::set dependent_for_shape_inference; + std::set dependent_for_execution; auto &ge_node = node_item.node; - bool is_hccl_op = - NodeExecutorManager::GetInstance().ResolveExecutorType(*ge_node) == NodeExecutorManager::ExecutorType::HCCL; + bool is_hccl_op = node_item.IsHcclOp(); // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE. // Wait for these parent nodes before execution. 
@@ -297,29 +295,15 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s auto src_node_item = MutableNodeItem(src_node); GE_CHECK_NOTNULL(src_node_item); - if (is_hccl_op) { - GELOGD("[%s] Add input data dependent node [%s] due to engine type is HCCL", - node_item.NodeName().c_str(), - src_node_item->NodeName().c_str()); - src_node_item->has_observer = true; - node_item.dependents_for_execution.emplace_back(src_node); - node_item.has_observer = true; - for (auto &dst_node : ge_node->GetOutNodes()) { - if (dst_node == nullptr) { - continue; - } - - NodeItem *dst_node_item = nullptr; - GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item)); - dst_node_item->dependents_for_execution.emplace_back(ge_node); - } - } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { - GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", - node_item.NodeName().c_str(), - src_node_item->NodeName().c_str()); - + if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) { + GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d", + ge_node->GetName().c_str(), + ge_node->GetType().c_str(), + src_node->GetName().c_str(), + src_node->GetType().c_str(), + static_cast(src_node_item->shape_inference_type)); src_node_item->has_observer = true; - node_item.dependents_for_execution.emplace_back(src_node); + dependent_for_execution.emplace(src_node); } if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) { @@ -327,22 +311,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); src_node_item->has_observer = true; - dependent_input_nodes.emplace(src_node); + dependent_for_shape_inference.emplace(src_node); } } // cond or branch need to be prepared before the execution of IF or CASE if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { - const auto &in_anchor = ge_node->GetInDataAnchor(0); - GE_CHECK_NOTNULL(in_anchor); - const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_anchor); - auto src_node = peer_anchor->GetOwnerNode(); + auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input GE_CHECK_NOTNULL(src_node); auto src_node_item = MutableNodeItem(src_node); GE_CHECK_NOTNULL(src_node_item); - src_node_item->has_observer = true; - node_item.dependents_for_execution.emplace_back(src_node); + dependent_for_execution.emplace(src_node); GELOGD("[%s] Dependent added from %s for control op's cond/branch", node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); @@ -366,24 +345,32 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s GE_CHECK_NOTNULL(src_node); auto src_node_item = MutableNodeItem(src_node); src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); - src_node_item->has_observer = true; - - dependent_input_nodes.emplace(src_node); + dependent_for_shape_inference.emplace(src_node); GELOGD("[%s] Dependent added from output of [%s:%d]", node_item.NodeName().c_str(), src_node_item->NodeName().c_str(), peer_out_anchor->GetIdx()); } - for (const auto &dep_node : dependent_input_nodes) { + GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference)); + for (const auto &dep_node : dependent_for_shape_inference) { + auto src_node_item = MutableNodeItem(dep_node); + 
GE_CHECK_NOTNULL(src_node_item); + src_node_item->has_observer = true; node_item.dependents_for_shape_inference.emplace_back(dep_node); } - GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item)); + for (const auto &dep_node : dependent_for_execution) { + auto src_node_item = MutableNodeItem(dep_node); + GE_CHECK_NOTNULL(src_node_item); + src_node_item->has_observer = true; + node_item.dependents_for_execution.emplace_back(dep_node); + } + return SUCCESS; } -Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { +Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item, std::set &dependencies) { if (node_item.fused_subgraph == nullptr) { return SUCCESS; } @@ -413,17 +400,12 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { node_item.NodeName().c_str(), op_desc->GetName().c_str(), src_node_item->NodeName().c_str()); - src_node_item->has_observer = true; src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); - - auto &depends = node_item.dependents_for_shape_inference; - if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) { - depends.emplace_back(src_node); - GELOGD("[%s] Dependent added from output of [%s:%d]", - node_item.NodeName().c_str(), - src_node_item->NodeName().c_str(), - peer_out_anchor->GetIdx()); - } + dependencies.emplace(src_node); + GELOGD("[%s] Dependent added from output of [%s:%d]", + node_item.NodeName().c_str(), + src_node_item->NodeName().c_str(), + peer_out_anchor->GetIdx()); } return SUCCESS; @@ -770,9 +752,23 @@ Status HybridModelBuilder::LoadGraph() { GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - GE_DUMP(root_graph, "hybrid_merged_graph"); } + root_graph_ = root_graph; + // Reset node id by topological order across all subgraphs + int64_t index = 0; + for (const auto &node : root_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + auto parent_graph = node->GetOwnerComputeGraph(); + // No need to update nodes in known subgraph + if (parent_graph != nullptr && !parent_graph->GetGraphUnknownFlag()) { + continue; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + op_desc->SetId(index++); + } + GE_DUMP(root_graph, "hybrid_merged_graph"); GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); GELOGD("Done loading root graph successfully."); GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph"); @@ -810,6 +806,7 @@ Status HybridModelBuilder::LoadGraph() { } } + GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops"); GELOGI("Done loading all subgraphs successfully."); return SUCCESS; } @@ -1075,25 +1072,41 @@ Status HybridModelBuilder::InitWeights() { return SUCCESS; } +Status HybridModelBuilder::LoadTask(NodeItem &node_item) { + auto &node_ptr = node_item.node; + GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); + auto load_ret = node_item.node_executor->LoadTask(hybrid_model_, + node_ptr, + node_item.kernel_task); + if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { + GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); + return load_ret; + } + + GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); + return SUCCESS; +} + Status HybridModelBuilder::LoadTasks() { GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed."); + 
std::map> ordered_partitioned_calls; for (auto &it : hybrid_model_.node_items_) { auto &node_item = it.second; - auto &node_ptr = node_item->node; if (node_item->node_type == NETOUTPUT) { continue; } - - GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); - auto load_ret = node_item->node_executor->LoadTask(hybrid_model_, - node_ptr, - node_item->kernel_task); - if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { - GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); - return load_ret; + if (node_item->node_type == PARTITIONEDCALL) { + ordered_partitioned_calls[node_item->node_id][node_item->node_name] = node_item.get(); + continue; } + GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item)); + } - GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); + // HCCL operators need to be loaded in the same order across different processes + for (auto &it : ordered_partitioned_calls) { + for (auto &it2 : it.second) { + GE_CHK_STATUS_RET_NOLOG(LoadTask(*it2.second)); + } } return SUCCESS; @@ -1626,6 +1639,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem auto temp_graph = MakeShared("temp"); GE_CHECK_NOTNULL(temp_graph); auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); + wrapper_op_desc->SetId(parent_node_item->node_id); GeModelPtr ge_model = subgraph_models_[subgraph_name]; GE_CHECK_NOTNULL(ge_model); hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); @@ -2011,5 +2025,93 @@ Status HybridModelBuilder::CheckAicpuOpList() { "Launch check aicpu op type failed."); return SUCCESS; } + +Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { + const auto &node = node_item->node; + auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); + if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { + std::string parallel_group; + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { + GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str()); + parallel_group_to_nodes_[parallel_group].emplace(node_item); + std::set group{parallel_group}; + node_to_parallel_groups_[node_item].emplace(parallel_group); + } + } else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) { + std::set parallel_groups; + GELOGD("[%s] To collect parallel group for known-shaped subgraph", node_item->NodeName().c_str()); + for (const auto &subgraph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) { + GELOGD("[%s] Start to get parallel group from subgraph: %s", + node_item->NodeName().c_str(), + subgraph_name.c_str()); + auto subgraph = root_graph_->GetSubgraph(subgraph_name); + GE_CHECK_NOTNULL(subgraph); + for (const auto &sub_node : subgraph->GetAllNodes()) { + std::string parallel_group; + if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { + GELOGD("[%s::%s] Got parallel group = %s", + subgraph_name.c_str(), + sub_node->GetName().c_str(), + parallel_group.c_str()); + parallel_groups.emplace(parallel_group); + } + } + } + + if (!parallel_groups.empty()) { + for (const auto ¶llel_group : parallel_groups) { + parallel_group_to_nodes_[parallel_group].emplace(node_item); + GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str()); + } + node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups)); + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::ParseDependentByParallelGroup() { 
+  for (auto &it : hybrid_model_.node_items_) {
+    GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get()));
+  }
+  for (const auto &it : node_to_parallel_groups_) {
+    auto node_item = it.first;
+    auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
+    for (const auto &parallel_group : it.second) {
+      auto &dependent_nodes = parallel_group_to_nodes_[parallel_group];
+      NodeItem *nearest_dep_node = nullptr;
+      int max_id = -1;
+      for (auto &dep_node : dependent_nodes) {
+        if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) {
+          nearest_dep_node = dep_node;
+          max_id = dep_node->node_id;
+        }
+      }
+
+      if (nearest_dep_node != nullptr) {
+        GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str());
+        auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node);
+        if (src_engine_type == dst_executor_type) {
+          GELOGD("No need to add dependency for nodes with same executor type");
+          continue;
+        }
+        auto &deps = node_item->dependents_for_execution;
+        if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) {
+          GELOGD("%s->%s Already has dependency, skip it",
+                 nearest_dep_node->node->GetName().c_str(),
+                 node_item->NodeName().c_str());
+          continue;
+        }
+        nearest_dep_node->has_observer = true;
+        deps.emplace_back(nearest_dep_node->node);
+        GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]",
+               parallel_group.c_str(),
+               nearest_dep_node->NodeName().c_str(),
+               node_item->NodeName().c_str());
+      }
+    }
+  }
+  return SUCCESS;
+}
 }  // namespace hybrid
 }  // namespace ge
diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h
index 313d5ca6..a59a282a 100644
--- a/ge/hybrid/model/hybrid_model_builder.h
+++ b/ge/hybrid/model/hybrid_model_builder.h
@@ -57,14 +57,17 @@ class HybridModelBuilder {
   Status ValidateParams();
   Status LoadGraph();
   Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
+  Status LoadTask(NodeItem &node_item);
   Status LoadTasks();
   Status IdentifyVariableOutputs(NodeItem &node_item);
   Status IdentifySameInputs(NodeItem &node_item);
   Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
   Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
   Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
+  Status CollectParallelGroups(NodeItem *node_item);
   Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<std::string> &dependencies);
-  Status ParseDependentForFusedSubgraph(NodeItem &node_item);
+  Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<int> &dependencies);
+  Status ParseDependentByParallelGroup();
   Status IndexTaskDefs();
   Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model);
   Status IndexSpecialNodes();
@@ -97,12 +100,14 @@ class HybridModelBuilder {
   NodeItem *MutableNodeItem(const NodePtr &node);
 
   GeRootModelPtr ge_root_model_;
+  ComputeGraphPtr root_graph_;
   std::map<std::string, GeModelPtr> subgraph_models_;
   std::map<std::string, NodePtr> constant_op_nodes_;
+  std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;
+  std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_;
   HybridModel &hybrid_model_;
   std::map<NodeItem *, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_;
-  int node_index = 0;
 
   RuntimeParam &runtime_param_;
   VarManager *var_manager_ = nullptr;
diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc
index 805064be..06d654cf 100644
--- a/ge/hybrid/model/node_item.cc
+++ b/ge/hybrid/model/node_item.cc
@@ -251,6 +251,10 @@ bool NodeItem::IsControlOp() const {
   return ge::hybrid::IsControlOp(op_desc->GetType());
 }
 
+bool NodeItem::IsHcclOp() const {
+  return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
+}
+
 std::string NodeItem::DebugString() const {
   std::stringstream ss;
   ss << "Node: ";
diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h
index 631dbd9e..474a1da4 100644
--- a/ge/hybrid/model/node_item.h
+++ b/ge/hybrid/model/node_item.h
@@ -67,6 +67,8 @@ struct NodeItem {
   bool IsControlOp() const;
 
+  bool IsHcclOp() const;
+
   void SetToDynamic();
 
   std::string DebugString() const;
diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
index bb96c275..45882343 100644
--- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
+++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
@@ -95,13 +95,6 @@ Status KnownNodeTask::UpdateArgs(TaskContext &context) {
 Status KnownNodeTask::Init(TaskContext &context) {
   // allocate output mem
   GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed.");
-
-  // init davinicmodel
-  if (!load_flag_) {
-    davinci_model_->InitRuntimeParams();
-    GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed.");
-  }
-
   // allocate mem base
   void *buffer = nullptr;
   if (davinci_model_->TotalMemSize() != 0) {
@@ -129,23 +122,31 @@ Status KnownNodeTask::Init(TaskContext &context) {
       void *global_step = context.GetExecutionContext()->global_step;
       davinci_model_->SetKnownShapeGlobalStep(global_step);
     }
-    int32_t device_id = 0;
-    rtError_t rt_ret = rtGetDevice(&device_id);
-    if (rt_ret != RT_ERROR_NONE || device_id < 0) {
-      GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id);
-      return RT_ERROR_TO_GE_STATUS(rt_ret);
-    }
-    davinci_model_->SetDeviceId(device_id);
-    GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed.");
     load_flag_ = true;
-  } else {
-    GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
-        davinci_model_->Id(), davinci_model_->SubModelId()), "KnownNodeTask::Init destroy aicpu kernel failed.");
   }
+  GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
+                                                                    davinci_model_->Id(), davinci_model_->SubModelId()),
+                    "KnownNodeTask::Init destroy aicpu kernel failed.");
   GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName());
   return SUCCESS;
 }
 
+Status KnownNodeTask::InitDavinciModel() {
+  GELOGD("[Init][Model] start");
+  davinci_model_->InitRuntimeParams();
+  GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed");
+  int32_t device_id = 0;
+  GE_CHK_RT_RET(rtGetDevice(&device_id));
+  davinci_model_->SetDeviceId(static_cast<uint32_t>(device_id));
+  GE_CHK_STATUS_RET(DoInitDavinciModel(), "[Init][Model] Failed to init davinci model.");
+  GELOGD("[Init][Model] success");
+  return SUCCESS;
+}
+
+Status KnownNodeTask::DoInitDavinciModel() {
+  return davinci_model_->Init();
+}
+
 Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
   GELOGD("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName());
   RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start");
@@ -182,9 +183,11 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node
   GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed.");
-  task = MakeShared<KnownNodeTask>(davinci_model);
-  GE_CHECK_NOTNULL(task);
+  auto known_node_task = MakeShared<KnownNodeTask>(davinci_model);
+  GE_CHECK_NOTNULL(known_node_task);
+  GE_CHK_STATUS_RET_NOLOG(known_node_task->InitDavinciModel());
   GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str());
+  task = std::move(known_node_task);
   return SUCCESS;
 }
diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h
index 6e9740ad..5eed528a 100644
--- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h
+++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h
@@ -31,11 +31,15 @@ class KnownNodeTask : public NodeTask {
       : davinci_model_(davinci_model) {}
 
-  ~KnownNodeTask() {}
+  ~KnownNodeTask() = default;
 
   Status UpdateArgs(TaskContext &context) override;
   Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
   Status Init(TaskContext &context) override;
+  Status InitDavinciModel();
+
+ protected:
+  virtual Status DoInitDavinciModel();
 
  private:
   std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
   bool load_flag_ = false;
@@ -47,8 +51,6 @@ class KnownNodeExecutor : public NodeExecutor {
   Status PrepareTask(NodeTask &task, TaskContext &context) const;
   Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
   ~KnownNodeExecutor() {}
- private:
-  std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
 };
 }  // namespace hybrid
 }  // namespace ge
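
Note (illustration only, not part of the patch): ParseDependentByParallelGroup() makes each node depend on the node in the same parallel group that has the largest node_id still smaller than its own, i.e. its nearest preceding group member. A minimal standalone sketch of that selection rule, using hypothetical names and plain ints in place of NodeItem:

#include <optional>
#include <set>

// Hypothetical helper mirroring the "nearest preceding node" rule described above.
std::optional<int> NearestPrecedingNodeId(const std::set<int> &group_node_ids, int current_node_id) {
  std::optional<int> nearest;
  for (int id : group_node_ids) {
    // Keep the largest id that is still smaller than the current node's id.
    if (id < current_node_id && (!nearest.has_value() || id > *nearest)) {
      nearest = id;
    }
  }
  return nearest;
}

// Example: for a group with node ids {2, 5, 9}, the node with id 8 would depend on node 5.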