diff --git a/.clang-format b/.clang-format
index c931e8f0..dd8abe32 100644
--- a/.clang-format
+++ b/.clang-format
@@ -11,7 +11,7 @@ AlignTrailingComments: true
 AllowAllParametersOfDeclarationOnNextLine: true
 AllowShortBlocksOnASingleLine: false
 AllowShortCaseLabelsOnASingleLine: false
-AllowShortFunctionsOnASingleLine: All
+AllowShortFunctionsOnASingleLine: Empty
 AllowShortIfStatementsOnASingleLine: true
 AllowShortLoopsOnASingleLine: true
 AlwaysBreakAfterDefinitionReturnType: None
@@ -50,9 +50,8 @@ CommentPragmas: '^ IWYU pragma:'
 CompactNamespaces: false
 ConstructorInitializerAllOnOneLineOrOnePerLine: true
 ConstructorInitializerIndentWidth: 4
-ContinuationIndentWidth: 2
+ContinuationIndentWidth: 4
 Cpp11BracedListStyle: true
-DerivePointerAlignment: true
 DisableFormat: false
 ExperimentalAutoDetectBinPacking: false
 FixNamespaceComments: true
@@ -94,7 +93,7 @@ PenaltyBreakString: 1000
 PenaltyBreakTemplateDeclaration: 10
 PenaltyExcessCharacter: 1000000
 PenaltyReturnTypeOnItsOwnLine: 200
-PointerAlignment: Left
+PointerAlignment: Right
 RawStringFormats:
   - Language: Cpp
     Delimiters:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6e1786f4..5e58eeba 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -88,13 +88,12 @@ else ()
         find_module(hccl libhccl.so ${GE_LIB_PATH})
         find_module(adump_server libadump_server.a ${GE_LIB_PATH})
         find_module(runtime libruntime.so ${GE_LIB_PATH})
-        find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
         find_module(resource libresource.so ${GE_LIB_PATH})
         find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
         find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
-        #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
     else()
         find_module(slog libalog.so ${ASCEND_ATC_DIR})
+        find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR})
         find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
         if(PLATFORM STREQUAL "train")
             find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
@@ -107,7 +106,6 @@ else ()
     elseif(PLATFORM STREQUAL "inference")
         find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
         find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
-        find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
         find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
         if(PRODUCT STREQUAL "flr3")
         elseif(PRODUCT STREQUAL "flr1")
@@ -118,12 +116,11 @@ else ()
             find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
         endif()
     elseif(PLATFORM STREQUAL "all")
-        find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
-        find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
+        find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
+        find_module(runtime libruntime.so ${ASCEND_ATC_DIR})
         find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
-        find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
-        find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
-        find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
+        find_module(ascend_hal_stub libascend_hal.so ${ASCEND_ATC_DIR}/stub)
+        find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
     else()
         message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
     endif()
diff --git a/README_CN.md b/README_CN.md
index 0a1e9c09..48fe4216 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -34,18 +34,6 @@
 During training and inference the steps above run automatically; through these graph operations, GE converts the graph delivered by the front end into a graph form that can run efficiently on the Ascend AI processor.
-
-
-- [Installation](#installation)
-    - [Installing GE](#installing-ge)
-    - [Building from Source](#building-from-source)
-- [Community](#community)
-- [Contributing](#contributing)
-- [Release Notes](#release-notes)
-- [License](#license)
-
-
-
 # Installation
 
 ## Installing GE
@@ -54,45 +42,8 @@
 GE is embedded in the MindSpore package; once MindSpore is installed, GE is invoked in the form of three dynamic libraries.
 
 ## Building from Source
 
-GE can also be built from source. Before building, make sure you have an Ascend 910 AI processor environment and that the system meets the following requirements:
-
-- GCC >= 7.3.0
-- CMake >= 3.14.0
-- Autoconf >= 2.64
-- Libtool >= 2.4.6
-- Automake >= 1.15.1
-
-The build produces several dynamic libraries; they are linked into MindSpore for execution and cannot run on their own.
-
-1. Download the GE source code.
-
-   The GE source code is hosted on Gitee and can be downloaded as follows.
-   ```
-   git clone https://gitee.com/mindspore/graphengine.git
-   cd graphengine
-   ```
-
-2. Run the following command in the GE root directory to build.
-
-   ```
-   bash build.sh
-   ```
-
-   > - Before building, make sure the relevant environment variables are set correctly.
-   > - The `build.sh` script performs `git clone` operations; make sure the network connection is available and git is configured correctly.
-   > - By default, `build.sh` builds with 8 threads, which may fail on weaker machines. The thread count can be controlled with `-j{thread count}`, e.g. `bash build.sh -j4`.
-
-3. After the build completes, the corresponding dynamic libraries are generated in the output directory.
-
-For more help on the available options, use:
-```
-bash build.sh -h
-```
-To clean previous build artifacts:
-```
-rm -rf build/ output/
-bash build.sh
-```
+GE can also be built from source; refer to the following link to do so:
+[Personal development toolchain](https://gitee.com/mindspore/graphengine/blob/master/scripts/readme.md)
 
 ## Community
 
diff --git a/build.sh b/build.sh
index 96c46e1a..dbbf696b 100755
--- a/build.sh
+++ b/build.sh
@@ -144,7 +144,6 @@ build_graphengine()
     CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_UT=ON"
   fi
-
   if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then
     CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON"
   fi
@@ -176,7 +175,7 @@ build_graphengine()
     TARGET="ge_compiler atc_atc.bin ge_executor_shared ${TARGET}"
   elif [ "X$ENABLE_GE_ST" = "Xon" ]
   then
-    TARGET="ge_graph_dsl_test graph_engine_test"
+    TARGET="ge_graph_dsl_test ge_running_env_test graph_engine_test"
   elif [ "X$ENABLE_GE_UT" = "Xon" ]
   then
     TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest"
@@ -244,13 +243,13 @@ if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then
     mkdir -p ${OUTPUT_PATH}/plugin/opskernel
     cp ${BUILD_PATH}/tests/framework/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine
     cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config
-    cp ${BUILD_PATH}/tests/framework/libhost_cpu_engine.so ${OUTPUT_PATH}/plugin/opskernel
     cp ${BUILD_PATH}/tests/framework/libge_local_engine.so ${OUTPUT_PATH}/plugin/opskernel
-    cp ${BUILD_PATH}/tests/framework/stub_engine/libfe.so ${OUTPUT_PATH}/plugin/opskernel
     #prepare st execution bin
     cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH}
+    cp ${BUILD_PATH}/tests/framework/ge_running_env/tests/ge_running_env_test ${OUTPUT_PATH}
     cp ${BUILD_PATH}/tests/framework/ge_graph_dsl/tests/ge_graph_dsl_test ${OUTPUT_PATH}
     #execute st testcase
+    RUN_TEST_CASE=${OUTPUT_PATH}/ge_running_env_test && ${RUN_TEST_CASE}
     RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE}
     RUN_TEST_CASE=${OUTPUT_PATH}/ge_graph_dsl_test && ${RUN_TEST_CASE}
     if [[ "$?" -ne 0 ]]; then
@@ -355,13 +354,13 @@ generate_package()
     if [ "x${PLATFORM}" = "xtrain" ]
     then
-        tar -cf graphengine_lib.tar fwkacllib
+        tar -zcf graphengine_lib.tar fwkacllib
     elif [ "x${PLATFORM}" = "xinference" ]
     then
-        tar -cf graphengine_lib.tar acllib atc
+        tar -zcf graphengine_lib.tar acllib atc
     elif [ "x${PLATFORM}" = "xall" ]
     then
-        tar -cf graphengine_lib.tar fwkacllib acllib atc
+        tar -zcf graphengine_lib.tar fwkacllib acllib atc
     fi
 }
@@ -371,6 +370,6 @@ elif [ "X$MINDSPORE_MODE" = "Xon" ]
 then
     cd "${OUTPUT_PATH}"
     find ./ -name graphengine_lib.tar -exec rm {} \;
-    tar -cf graphengine_lib.tar lib
+    tar -zcf graphengine_lib.tar lib
 fi
 echo "---------------- GraphEngine package archive generated ----------------"
diff --git a/cmake/external_libs/gflags.cmake b/cmake/external_libs/gflags.cmake
index 50cfb2bc..b4b57dd7 100755
--- a/cmake/external_libs/gflags.cmake
+++ b/cmake/external_libs/gflags.cmake
@@ -10,12 +10,17 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
     message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
 endif()
 
-if (ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz")
-    set(MD5 "")
+if (GE_PB_PKG)
+    set(REQ_URL "${GE_PB_PKG}/libs/gflags/v2.2.2.tar.gz")
+    set(MD5 "1a865b93bacfa963201af3f75b7bd64c")
 else()
-    set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz")
-    set(MD5 "")
+    if (ENABLE_GITEE)
+        set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz")
+        set(MD5 "")
+    else()
+        set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz")
+        set(MD5 "1a865b93bacfa963201af3f75b7bd64c")
+    endif ()
 endif ()
 
 set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private")
diff --git a/cmake/external_libs/protobuf_shared.cmake b/cmake/external_libs/protobuf_shared.cmake
index 6334c8a3..dfdb0606 100755
--- a/cmake/external_libs/protobuf_shared.cmake
+++ b/cmake/external_libs/protobuf_shared.cmake
@@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
     message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
 endif()
 if (GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
 
@@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/
 set(INSTALL_BASE_DIR "")
 set(INSTALL_LIBRARY_DIR lib)
 
-install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL
+install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
 install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
diff --git a/cmake/external_libs/protobuf_static.cmake b/cmake/external_libs/protobuf_static.cmake
index 22f537cf..51f6ffbc 100755
--- a/cmake/external_libs/protobuf_static.cmake
+++ b/cmake/external_libs/protobuf_static.cmake
@@ -13,14 +13,14 @@ endif()
 
 if(GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
 
@@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
 ExternalProject_Add(protobuf_static_build
                     URL ${REQ_URL}
-                    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-                    #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
                     TLS_VERIFY OFF
                     CONFIGURE_COMMAND ${CMAKE_COMMAND}
                     -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
diff --git a/cmake/external_libs/protoc.cmake b/cmake/external_libs/protoc.cmake
index 421f2632..f16f5e22 100755
--- a/cmake/external_libs/protoc.cmake
+++ b/cmake/external_libs/protoc.cmake
@@ -13,14 +13,14 @@ endif()
 
 if(GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
 
@@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
 set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 ExternalProject_Add(protoc_build
                     URL ${REQ_URL}
-                    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-                    #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
                     TLS_VERIFY OFF
                     CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc /cmake
                     BUILD_COMMAND $(MAKE)
diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt
index 215d2832..f98297d8 100755
--- a/ge/CMakeLists.txt
+++ b/ge/CMakeLists.txt
@@ -2,7 +2,6 @@ if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
     add_subdirectory(common)
     add_subdirectory(plugin/engine)
     add_subdirectory(ge_local_engine)
-    add_subdirectory(executor)
     add_subdirectory(offline)
 elseif (ENABLE_D)
     add_subdirectory(common)
@@ -109,81 +108,42 @@ target_link_libraries(ge_proto_client PRIVATE
 endif ()
 
 ##################################################################
-set(TRAIN_SRC_LIST
"common/formats/format_transfers/datatype_transfer.cc" - "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "common/formats/format_transfers/format_transfer_fractal_nz.cc" - "common/formats/format_transfers/format_transfer_fractal_z.cc" - "common/formats/format_transfers/format_transfer_fractal_zz.cc" - "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" - "common/formats/format_transfers/format_transfer_fracz_nchw.cc" - "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" - "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/formats.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" - "common/ge/plugin_manager.cc" - "common/ge/op_tiling_manager.cc" - "common/helper/model_cache_helper.cc" - "common/profiling/profiling_manager.cc" - "common/dump/dump_manager.cc" +set(EXECUTOR_SRC_LIST + "common/dump/dump_op.cc" "common/dump/exception_dumper.cc" - "common/dump/dump_properties.cc" "common/dump/opdebug_register.cc" - "common/dump/dump_op.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" "common/profiling/ge_profiling.cc" - "common/profiling/ge_runner_profiling.cc" - "engine_manager/dnnengine_manager.cc" + "common/profiling/profiling_manager.cc" + "executor/ge_executor.cc" "ge_local_engine/engine/host_cpu_engine.cc" - "generator/ge_generator.cc" - "generator/generator_api.cc" - "graph/build/graph_builder.cc" - "graph/build/label_allocator.cc" - "graph/build/logical_stream_allocator.cc" - "graph/build/model_builder.cc" - "graph/build/run_context.cc" - "graph/build/stream_allocator.cc" - "graph/build/stream_graph_optimizer.cc" - "graph/build/task_generator.cc" - "graph/common/bcast.cc" - "graph/common/local_context.cc" - "graph/common/omg_util.cc" - "graph/common/transop_util.cc" + "graph/build/memory/var_mem_assign_util.cc" "graph/execute/graph_execute.cc" - "graph/label/case_label_maker.cc" - "graph/label/if_label_maker.cc" - "graph/label/label_maker.cc" - "graph/label/partitioned_call_label_maker.cc" - "graph/label/while_label_maker.cc" + "graph/execute/model_executor.cc" "graph/load/graph_loader.cc" + "graph/load/model_manager/aipp_utils.cc" "graph/load/model_manager/cpu_queue_schedule.cc" "graph/load/model_manager/data_dumper.cc" "graph/load/model_manager/data_inputer.cc" "graph/load/model_manager/davinci_model.cc" "graph/load/model_manager/model_manager.cc" "graph/load/model_manager/model_utils.cc" - "graph/load/model_manager/aipp_utils.cc" "graph/load/model_manager/task_info/end_graph_task_info.cc" - "graph/load/model_manager/task_info/model_exit_task_info.cc" "graph/load/model_manager/task_info/event_record_task_info.cc" "graph/load/model_manager/task_info/event_wait_task_info.cc" + "graph/load/model_manager/task_info/ffts_task_info.cc" "graph/load/model_manager/task_info/fusion_start_task_info.cc" "graph/load/model_manager/task_info/fusion_stop_task_info.cc" - "graph/load/model_manager/task_info/hccl_task_info.cc" + #"graph/load/model_manager/task_info/hccl_task_info.cc" # Just for runner. 
"graph/load/model_manager/task_info/kernel_ex_task_info.cc" "graph/load/model_manager/task_info/kernel_task_info.cc" + "graph/load/model_manager/task_info/label_goto_ex_task_info.cc" "graph/load/model_manager/task_info/label_set_task_info.cc" "graph/load/model_manager/task_info/label_switch_by_index_task_info.cc" - "graph/load/model_manager/task_info/label_goto_ex_task_info.cc" "graph/load/model_manager/task_info/memcpy_addr_async_task_info.cc" "graph/load/model_manager/task_info/memcpy_async_task_info.cc" + "graph/load/model_manager/task_info/model_exit_task_info.cc" "graph/load/model_manager/task_info/profiler_trace_task_info.cc" "graph/load/model_manager/task_info/stream_active_task_info.cc" "graph/load/model_manager/task_info/stream_switch_task_info.cc" @@ -192,66 +152,22 @@ set(TRAIN_SRC_LIST "graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/model_manager/task_info/task_info.cc" "graph/load/model_manager/tbe_handle_store.cc" - "graph/load/model_manager/zero_copy_task.cc" "graph/load/model_manager/zero_copy_offset.cc" - "graph/manager/graph_context.cc" - "graph/manager/graph_manager.cc" + "graph/load/model_manager/zero_copy_task.cc" + "graph/manager/graph_caching_allocator.cc" "graph/manager/graph_manager_utils.cc" "graph/manager/graph_mem_allocator.cc" - "graph/manager/graph_caching_allocator.cc" - "graph/manager/session_scope_mem_allocator.cc" + "graph/manager/graph_mem_manager.cc" "graph/manager/graph_var_manager.cc" + "graph/manager/host_mem_allocator.cc" "graph/manager/host_mem_manager.cc" + #"graph/manager/memory_api.cc" # Just for runner. "graph/manager/rdma_pool_allocator.cc" - "graph/manager/host_mem_allocator.cc" - "graph/manager/graph_mem_manager.cc" - "graph/manager/memory_api.cc" - "graph/manager/model_manager/event_manager.cc" + "graph/manager/session_scope_mem_allocator.cc" "graph/manager/trans_var_data_utils.cc" "graph/manager/util/debug.cc" - "graph/manager/util/hcom_util.cc" - "graph/manager/util/rt_context_util.cc" - "graph/manager/util/variable_accelerate_ctrl.cc" - "graph/optimize/graph_optimize.cc" - "graph/optimize/mem_rw_conflict_optimize.cc" - "graph/optimize/summary_optimize.cc" - "graph/partition/engine_place.cc" - "graph/partition/graph_partition.cc" - "graph/passes/addn_pass.cc" - "graph/passes/aicpu_constant_folding_pass.cc" - "graph/passes/assert_pass.cc" - "graph/passes/input_output_connection_identify_pass.cc" - "graph/passes/atomic_addr_clean_pass.cc" - "graph/passes/mark_same_addr_pass.cc" - "graph/passes/mark_graph_unknown_status_pass.cc" - "graph/passes/mark_node_unknown_shape_pass.cc" - "graph/passes/mark_agnostic_pass.cc" - "graph/partition/dynamic_shape_partition.cc" - "graph/partition/stage_partition.cc" - "graph/passes/base_pass.cc" - "graph/passes/bitcast_pass.cc" - "graph/passes/cast_remove_pass.cc" - "graph/passes/cast_translate_pass.cc" - "graph/passes/common_subexpression_elimination_pass.cc" - "graph/passes/transop_symmetry_elimination_pass.cc" - "graph/passes/compile_nodes_pass.cc" - "graph/passes/constant_folding_pass.cc" - "graph/passes/constant_fuse_same_pass.cc" - "graph/passes/fuse_data_nodes_with_common_input_pass.cc" - "graph/passes/remove_same_const_pass.cc" - "graph/passes/useless_control_out_remove_pass.cc" - "graph/passes/control_trigger_pass.cc" - "graph/passes/dimension_adjust_pass.cc" - "graph/passes/dimension_compute_pass.cc" - "graph/passes/dropout_pass.cc" - "graph/passes/hccl_group_pass.cc" - "graph/passes/hccl_tailing_optimization_pass.cc" - "graph/passes/enter_pass.cc" - 
"graph/passes/assign_remove_pass.cc" - "graph/passes/inplace_support_check_pass.cc" - "graph/passes/flow_ctrl_pass.cc" - "graph/passes/global_step_insert_pass.cc" - "host_kernels/transpose_kernel.cc" + #"graph/manager/util/hcom_util.cc" # Just for runner. + "graph/passes/pass_utils.cc" "host_kernels/add_kernel.cc" "host_kernels/broadcast_args_kernel.cc" "host_kernels/broadcast_gradient_args_kernel.cc" @@ -259,7 +175,6 @@ set(TRAIN_SRC_LIST "host_kernels/concat_offset_kernel.cc" "host_kernels/concat_v2_kernel.cc" "host_kernels/dynamic_stitch_kernel.cc" - "host_kernels/identity_kernel.cc" "host_kernels/empty_kernel.cc" "host_kernels/expanddims_kernel.cc" "host_kernels/fill_kernel.cc" @@ -267,6 +182,7 @@ set(TRAIN_SRC_LIST "host_kernels/floormod_kernel.cc" "host_kernels/gather_v2_kernel.cc" "host_kernels/greater_kernel.cc" + "host_kernels/identity_kernel.cc" "host_kernels/kernel_utils.cc" "host_kernels/maximum_kernel.cc" "host_kernels/mul_kernel.cc" @@ -275,6 +191,7 @@ set(TRAIN_SRC_LIST "host_kernels/range_kernel.cc" "host_kernels/rank_kernel.cc" "host_kernels/reduce_prod_kernel.cc" + "host_kernels/reformat_kernel.cc" "host_kernels/reshape_kernel.cc" "host_kernels/rsqrt_kernel.cc" "host_kernels/shape_kernel.cc" @@ -283,439 +200,294 @@ set(TRAIN_SRC_LIST "host_kernels/slice_d_kernel.cc" "host_kernels/slice_kernel.cc" "host_kernels/squeeze_kernel.cc" - "host_kernels/unsqueeze_kernel.cc" "host_kernels/ssd_prior_box_kernel.cc" "host_kernels/strided_slice_kernel.cc" "host_kernels/sub_kernel.cc" "host_kernels/transdata_kernel.cc" + "host_kernels/transpose_kernel.cc" "host_kernels/unpack_kernel.cc" - "host_kernels/reformat_kernel.cc" - "graph/passes/folding_pass.cc" - "graph/passes/get_original_format_pass.cc" - "graph/passes/guarantee_const_pass.cc" - "graph/passes/hccl_memcpy_pass.cc" - "graph/passes/hccl_continuous_memcpy_pass.cc" - "graph/passes/identity_pass.cc" - "graph/passes/ref_identity_delete_op_pass.cc" - "graph/passes/infershape_pass.cc" - "graph/passes/iterator_op_pass.cc" - "graph/passes/link_gen_mask_nodes_pass.cc" - "graph/passes/merge_pass.cc" - "graph/passes/multi_batch_pass.cc" - "graph/passes/multi_batch_clone_pass.cc" - "graph/passes/subexpression_migration_pass.cc" - "graph/passes/subgraph_const_migration_pass.cc" - "graph/passes/unused_args_clean_pass.cc" - "graph/passes/net_output_pass.cc" - "graph/passes/next_iteration_pass.cc" - "graph/passes/no_use_reshape_remove_pass.cc" - "graph/passes/pass_manager.cc" - "graph/passes/pass_utils.cc" - "graph/passes/permute_pass.cc" - "graph/passes/placeholder_with_default_pass.cc" - "graph/passes/prevent_gradient_pass.cc" - "graph/passes/print_op_pass.cc" - "graph/passes/prune_pass.cc" - "graph/passes/ctrl_edge_transfer_pass.cc" - "graph/passes/replace_with_empty_const_pass.cc" - "graph/passes/reshape_remove_pass.cc" - "graph/passes/reshape_recovery_pass.cc" - "graph/passes/resource_pair_add_control_pass.cc" - "graph/passes/resource_pair_remove_control_pass.cc" - "graph/passes/same_transdata_breadth_fusion_pass.cc" - "graph/passes/save_pass.cc" - "graph/passes/shape_operate_op_remove_pass.cc" - "graph/passes/snapshot_pass.cc" - "graph/passes/stop_gradient_pass.cc" - "graph/passes/subgraph_pass.cc" - "graph/passes/data_pass.cc" - "graph/passes/switch_data_edges_bypass.cc" - "graph/passes/switch_logic_remove_pass.cc" - "graph/passes/merge_to_stream_merge_pass.cc" - "graph/passes/merge_input_memcpy_pass.cc" - "graph/passes/switch_to_stream_switch_pass.cc" - "graph/passes/mark_force_unknown_for_cond_pass.cc" - 
"graph/passes/attach_stream_label_pass.cc" - "graph/passes/switch_dead_branch_elimination.cc" - "graph/passes/replace_transshape_pass.cc" - "graph/passes/transop_breadth_fusion_pass.cc" - "graph/passes/transop_depth_fusion_pass.cc" - "graph/passes/transop_nearby_allreduce_fusion_pass.cc" - "graph/passes/transop_without_reshape_fusion_pass.cc" - "graph/passes/transpose_transdata_pass.cc" - "graph/passes/unused_const_pass.cc" - "graph/passes/var_is_initialized_op_pass.cc" - "graph/passes/parallel_concat_start_op_pass.cc" - "graph/passes/cond_pass.cc" - "graph/passes/cond_remove_pass.cc" - "graph/passes/for_pass.cc" - "graph/passes/variable_op_pass.cc" - "graph/passes/variable_prepare_op_pass.cc" - "graph/passes/variable_ref_delete_op_pass.cc" - "graph/passes/variable_ref_useless_control_out_delete_pass.cc" - "graph/passes/end_of_sequence_add_control_pass.cc" - "graph/passes/memcpy_addr_async_pass.cc" - "graph/passes/parallel_group_pass.cc" - "graph/passes/set_input_output_offset_pass.cc" - "graph/passes/buffer_pool_memory_pass.cc" - "graph/preprocess/graph_preprocess.cc" - "graph/preprocess/insert_op/ge_aipp_op.cc" - "graph/preprocess/insert_op/util_insert_aipp_op.cc" - "graph/preprocess/multi_batch_options.cc" - "graph/preprocess/multi_batch_copy_graph.cc" - "init/gelib.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" - "opskernel_manager/ops_kernel_manager.cc" - "opskernel_manager/ops_kernel_builder_manager.cc" - "session/inner_session.cc" - "session/session_manager.cc" - "single_op/single_op.cc" - "single_op/single_op_manager.cc" - "single_op/single_op_model.cc" - "single_op/stream_resource.cc" - "single_op/task/build_task_utils.cc" - "single_op/task/op_task.cc" - "single_op/task/tbe_task_builder.cc" - "single_op/task/aicpu_task_builder.cc" - "single_op/task/aicpu_kernel_task_builder.cc" - "single_op/task/rts_kernel_task_builder.cc" - "hybrid/common/tensor_value.cc" + "host_kernels/unsqueeze_kernel.cc" "hybrid/common/npu_memory_allocator.cc" - "hybrid/executor/rt_callback_manager.cc" - "hybrid/executor/node_state.cc" - "hybrid/executor/node_done_manager.cc" - "hybrid/executor/hybrid_profiler.cc" + "hybrid/common/tensor_value.cc" + "hybrid/executor/hybrid_execution_context.cc" + "hybrid/executor/hybrid_model_async_executor.cc" "hybrid/executor/hybrid_model_executor.cc" "hybrid/executor/hybrid_model_pipeline_executor.cc" - "hybrid/executor/hybrid_model_async_executor.cc" - "hybrid/executor/hybrid_execution_context.cc" + "hybrid/executor/hybrid_profiler.cc" + "hybrid/executor/node_done_manager.cc" + "hybrid/executor/node_state.cc" + "hybrid/executor/rt_callback_manager.cc" "hybrid/executor/subgraph_context.cc" "hybrid/executor/subgraph_executor.cc" - "hybrid/executor/worker/task_compile_engine.cc" - "hybrid/executor/worker/shape_inference_engine.cc" "hybrid/executor/worker/execution_engine.cc" + "hybrid/executor/worker/shape_inference_engine.cc" + "hybrid/executor/worker/task_compile_engine.cc" + "hybrid/hybrid_davinci_model.cc" + "hybrid/model/graph_item.cc" "hybrid/model/hybrid_model.cc" "hybrid/model/hybrid_model_builder.cc" "hybrid/model/node_item.cc" - "hybrid/model/graph_item.cc" "hybrid/node_executor/aicore/aicore_node_executor.cc" "hybrid/node_executor/aicore/aicore_op_task.cc" "hybrid/node_executor/aicore/aicore_task_builder.cc" - "hybrid/node_executor/aicore/aicore_task_compiler.cc" "hybrid/node_executor/aicpu/aicpu_ext_info.cc" "hybrid/node_executor/aicpu/aicpu_node_executor.cc" "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" + 
"hybrid/node_executor/controlop/control_op_executor.cc" "hybrid/node_executor/ge_local/ge_local_node_executor.cc" + #"hybrid/node_executor/hccl/hccl_node_executor.cc" # Just for runner. "hybrid/node_executor/host_cpu/host_cpu_node_executor.cc" - "hybrid/node_executor/controlop/control_op_executor.cc" + "hybrid/node_executor/node_executor.cc" "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" - "hybrid/node_executor/hccl/hccl_node_executor.cc" "hybrid/node_executor/rts/rts_node_executor.cc" "hybrid/node_executor/rts/rts_node_task.cc" "hybrid/node_executor/rts/rts_task_factory.cc" - "hybrid/node_executor/node_executor.cc" "hybrid/node_executor/task_context.cc" - "hybrid/hybrid_davinci_model.cc" - "executor/ge_executor.cc" - "client/ge_api.cc" + "opskernel_manager/ops_kernel_builder_manager.cc" + "single_op/single_op.cc" + "single_op/single_op_manager.cc" + "single_op/single_op_model.cc" + "single_op/stream_resource.cc" + "single_op/task/aicpu_kernel_task_builder.cc" + "single_op/task/aicpu_task_builder.cc" + "single_op/task/build_task_utils.cc" + "single_op/task/op_task.cc" + "single_op/task/rts_kernel_task_builder.cc" + "single_op/task/tbe_task_builder.cc" +) + +################################################################## +set(COMPILER_SRC_LIST "analyzer/analyzer.cc" - "ir_build/ge_ir_build.cc" - "ir_build/attr_options/utils.cc" - "ir_build/attr_options/keep_dtype_option.cc" - "ir_build/attr_options/weight_compress_option.cc" - "ir_build/option_utils.cc" - "graph/build/memory/memory_assigner.cc" - "graph/build/memory/graph_mem_assigner.cc" + "common/dump/dump_op.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" + "common/profiling/profiling_manager.cc" + "engine_manager/dnnengine_manager.cc" + "ge_local_engine/engine/host_cpu_engine.cc" + "ge_opt_info/ge_opt_info.cc" + "generator/ge_generator.cc" + "generator/generator_api.cc" + "graph/build/graph_builder.cc" + "graph/build/label_allocator.cc" + "graph/build/logical_stream_allocator.cc" "graph/build/memory/binary_block_mem_assigner.cc" "graph/build/memory/block_mem_assigner.cc" + "graph/build/memory/buffer_pool_mem_assigner.cc" + "graph/build/memory/graph_mem_assigner.cc" "graph/build/memory/hybrid_mem_assigner.cc" "graph/build/memory/max_block_mem_assigner.cc" + "graph/build/memory/memory_assigner.cc" "graph/build/memory/var_mem_assign_util.cc" - "graph/build/memory/buffer_pool_mem_assigner.cc" -) - -set(INFER_SRC_LIST - "graph/manager/trans_var_data_utils.cc" - "common/fp16_t.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/formats/format_transfers/datatype_transfer.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_fractal_z.cc" - "common/formats/format_transfers/format_transfer_fractal_nz.cc" - "common/formats/format_transfers/format_transfer_fractal_zz.cc" - "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "common/formats/format_transfers/format_transfer_fracz_nchw.cc" - "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" - "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" - 
"common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "common/formats/formats.cc" - "common/profiling/profiling_manager.cc" - "common/dump/dump_properties.cc" - "common/dump/exception_dumper.cc" - "common/dump/dump_manager.cc" - "common/dump/dump_op.cc" - "common/dump/opdebug_register.cc" - "common/dump/dump_server.cc" - "common/helper/model_cache_helper.cc" - "ge_local_engine/engine/host_cpu_engine.cc" - "common/ge/plugin_manager.cc" - "common/ge/op_tiling_manager.cc" - "init/gelib.cc" - "session/inner_session.cc" - "session/session_manager.cc" - "engine_manager/dnnengine_manager.cc" - "opskernel_manager/ops_kernel_manager.cc" - "opskernel_manager/ops_kernel_builder_manager.cc" + "graph/build/model_builder.cc" + "graph/build/run_context.cc" + "graph/build/stream_allocator.cc" + "graph/build/stream_graph_optimizer.cc" + "graph/build/task_generator.cc" + "graph/label/case_label_maker.cc" + "graph/label/if_label_maker.cc" + "graph/label/label_maker.cc" + "graph/label/partitioned_call_label_maker.cc" + "graph/label/while_label_maker.cc" + "graph/load/model_manager/model_utils.cc" + "graph/manager/graph_caching_allocator.cc" + "graph/manager/graph_context.cc" "graph/manager/graph_manager.cc" "graph/manager/graph_manager_utils.cc" - "graph/manager/graph_context.cc" - "graph/preprocess/graph_preprocess.cc" - "graph/preprocess/multi_batch_options.cc" - "graph/preprocess/multi_batch_copy_graph.cc" - "graph/execute/graph_execute.cc" - "graph/load/graph_loader.cc" + "graph/manager/graph_mem_allocator.cc" + "graph/manager/graph_mem_manager.cc" + "graph/manager/graph_var_manager.cc" + "graph/manager/host_mem_allocator.cc" + "graph/manager/host_mem_manager.cc" + "graph/manager/rdma_pool_allocator.cc" + "graph/manager/session_scope_mem_allocator.cc" + "graph/manager/trans_var_data_utils.cc" + "graph/manager/util/debug.cc" + "graph/manager/util/rt_context_util.cc" + "graph/manager/util/variable_accelerate_ctrl.cc" "graph/optimize/graph_optimize.cc" "graph/optimize/mem_rw_conflict_optimize.cc" "graph/optimize/summary_optimize.cc" - "graph/build/graph_builder.cc" + "graph/partition/dynamic_shape_partition.cc" "graph/partition/engine_place.cc" "graph/partition/graph_partition.cc" - "graph/partition/dynamic_shape_partition.cc" "graph/partition/stage_partition.cc" - "generator/ge_generator.cc" - "generator/generator_api.cc" - "graph/manager/graph_var_manager.cc" - "graph/manager/host_mem_manager.cc" - "graph/manager/rdma_pool_allocator.cc" - "graph/manager/host_mem_allocator.cc" - "graph/manager/graph_mem_allocator.cc" - "graph/manager/graph_caching_allocator.cc" - "graph/manager/session_scope_mem_allocator.cc" - "graph/manager/graph_mem_manager.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" - "graph/common/transop_util.cc" - "graph/passes/pass_manager.cc" - "graph/passes/resource_pair_add_control_pass.cc" - "graph/passes/resource_pair_remove_control_pass.cc" - "graph/passes/pass_utils.cc" + "graph/passes/addn_pass.cc" + "graph/passes/aicpu_constant_folding_pass.cc" + "graph/passes/assert_pass.cc" + "graph/passes/assign_remove_pass.cc" + "graph/passes/atomic_addr_clean_pass.cc" + "graph/passes/attach_stream_label_pass.cc" "graph/passes/base_pass.cc" "graph/passes/bitcast_pass.cc" - "graph/passes/constant_folding_pass.cc" - "graph/passes/aicpu_constant_folding_pass.cc" - "graph/passes/reshape_remove_pass.cc" - 
"graph/passes/reshape_recovery_pass.cc" - "graph/passes/transop_breadth_fusion_pass.cc" - "graph/passes/transop_depth_fusion_pass.cc" - "graph/passes/transop_nearby_allreduce_fusion_pass.cc" - "graph/passes/same_transdata_breadth_fusion_pass.cc" - "graph/passes/transop_without_reshape_fusion_pass.cc" + "graph/passes/buffer_pool_memory_pass.cc" + "graph/passes/cast_remove_pass.cc" + "graph/passes/cast_translate_pass.cc" + "graph/passes/common_subexpression_elimination_pass.cc" "graph/passes/compile_nodes_pass.cc" - "graph/passes/variable_prepare_op_pass.cc" - "graph/passes/variable_ref_delete_op_pass.cc" - "graph/passes/variable_ref_useless_control_out_delete_pass.cc" - "graph/passes/subgraph_pass.cc" - "graph/passes/data_pass.cc" - "graph/passes/net_output_pass.cc" - "graph/passes/replace_transshape_pass.cc" + "graph/passes/cond_pass.cc" + "graph/passes/cond_remove_pass.cc" + "graph/passes/constant_folding_pass.cc" "graph/passes/constant_fuse_same_pass.cc" - "graph/passes/fuse_data_nodes_with_common_input_pass.cc" - "graph/passes/print_op_pass.cc" - "graph/passes/no_use_reshape_remove_pass.cc" - "graph/passes/iterator_op_pass.cc" - "graph/passes/input_output_connection_identify_pass.cc" - "graph/passes/atomic_addr_clean_pass.cc" - "graph/passes/mark_same_addr_pass.cc" - "graph/passes/mark_graph_unknown_status_pass.cc" - "graph/passes/mark_node_unknown_shape_pass.cc" - "graph/passes/mark_agnostic_pass.cc" - "graph/common/omg_util.cc" - "graph/common/bcast.cc" - "graph/common/local_context.cc" - "graph/passes/dimension_compute_pass.cc" + "graph/passes/control_trigger_pass.cc" + "graph/passes/ctrl_edge_transfer_pass.cc" + "graph/passes/data_pass.cc" "graph/passes/dimension_adjust_pass.cc" - "graph/passes/get_original_format_pass.cc" - "graph/passes/shape_operate_op_remove_pass.cc" - "graph/passes/assert_pass.cc" + "graph/passes/dimension_compute_pass.cc" "graph/passes/dropout_pass.cc" - "graph/passes/infershape_pass.cc" - "graph/passes/unused_const_pass.cc" - "graph/passes/permute_pass.cc" - "graph/passes/ctrl_edge_transfer_pass.cc" "graph/passes/end_of_sequence_add_control_pass.cc" - "host_kernels/broadcast_gradient_args_kernel.cc" - "host_kernels/greater_kernel.cc" - "host_kernels/gather_v2_kernel.cc" - "host_kernels/maximum_kernel.cc" - "host_kernels/floormod_kernel.cc" - "host_kernels/floordiv_kernel.cc" - "host_kernels/range_kernel.cc" - "host_kernels/shape_kernel.cc" - "host_kernels/size_kernel.cc" - "host_kernels/shape_n_kernel.cc" - "host_kernels/rank_kernel.cc" - "host_kernels/broadcast_args_kernel.cc" - "host_kernels/fill_kernel.cc" - "host_kernels/empty_kernel.cc" - "host_kernels/expanddims_kernel.cc" - "host_kernels/reshape_kernel.cc" - "host_kernels/squeeze_kernel.cc" - "host_kernels/unsqueeze_kernel.cc" - "host_kernels/kernel_utils.cc" - "host_kernels/cast_kernel.cc" - "host_kernels/transdata_kernel.cc" - "host_kernels/unpack_kernel.cc" - "host_kernels/transpose_kernel.cc" - "host_kernels/permute_kernel.cc" - "host_kernels/pack_kernel.cc" - "host_kernels/concat_v2_kernel.cc" - "host_kernels/concat_offset_kernel.cc" - "host_kernels/strided_slice_kernel.cc" - "host_kernels/ssd_prior_box_kernel.cc" - "host_kernels/add_kernel.cc" - "host_kernels/sub_kernel.cc" - "host_kernels/mul_kernel.cc" - "host_kernels/reduce_prod_kernel.cc" - "host_kernels/rsqrt_kernel.cc" - "host_kernels/slice_kernel.cc" - "host_kernels/slice_d_kernel.cc" - "host_kernels/dynamic_stitch_kernel.cc" - "host_kernels/identity_kernel.cc" - "host_kernels/reformat_kernel.cc" - "graph/passes/stop_gradient_pass.cc" - 
"graph/passes/prevent_gradient_pass.cc" - "graph/passes/identity_pass.cc" - "graph/passes/ref_identity_delete_op_pass.cc" - "graph/passes/placeholder_with_default_pass.cc" - "graph/passes/snapshot_pass.cc" - "graph/passes/guarantee_const_pass.cc" - "graph/passes/var_is_initialized_op_pass.cc" - "graph/passes/parallel_concat_start_op_pass.cc" + "graph/passes/enter_pass.cc" + "graph/passes/flow_ctrl_pass.cc" "graph/passes/folding_pass.cc" - "graph/passes/cast_translate_pass.cc" - "graph/passes/prune_pass.cc" - "graph/passes/merge_to_stream_merge_pass.cc" - "graph/passes/merge_input_memcpy_pass.cc" - "graph/passes/switch_to_stream_switch_pass.cc" + "graph/passes/for_pass.cc" + "graph/passes/fuse_data_nodes_with_common_input_pass.cc" + "graph/passes/get_original_format_pass.cc" + "graph/passes/global_step_insert_pass.cc" + "graph/passes/guarantee_const_pass.cc" + "graph/passes/hccl_continuous_memcpy_pass.cc" + "graph/passes/hccl_group_pass.cc" + "graph/passes/hccl_memcpy_pass.cc" + "graph/passes/hccl_tailing_optimization_pass.cc" + "graph/passes/identity_pass.cc" + "graph/passes/infer_base_pass.cc" + "graph/passes/infer_value_range_pass.cc" + "graph/passes/infershape_pass.cc" + "graph/passes/inplace_support_check_pass.cc" + "graph/passes/input_output_connection_identify_pass.cc" + "graph/passes/iterator_op_pass.cc" + "graph/passes/link_gen_mask_nodes_pass.cc" + "graph/passes/mark_agnostic_pass.cc" "graph/passes/mark_force_unknown_for_cond_pass.cc" - "graph/passes/attach_stream_label_pass.cc" - "graph/passes/multi_batch_pass.cc" + "graph/passes/mark_graph_unknown_status_pass.cc" + "graph/passes/mark_node_unknown_shape_pass.cc" + "graph/passes/mark_same_addr_pass.cc" + "graph/passes/memcpy_addr_async_pass.cc" + "graph/passes/merge_input_memcpy_pass.cc" + "graph/passes/merge_pass.cc" + "graph/passes/merge_to_stream_merge_pass.cc" "graph/passes/multi_batch_clone_pass.cc" - "graph/passes/subexpression_migration_pass.cc" - "graph/passes/subgraph_const_migration_pass.cc" - "graph/passes/unused_args_clean_pass.cc" + "graph/passes/multi_batch_pass.cc" + "graph/passes/net_output_pass.cc" "graph/passes/next_iteration_pass.cc" - "graph/passes/control_trigger_pass.cc" - "graph/passes/cond_pass.cc" - "graph/passes/cond_remove_pass.cc" - "graph/passes/for_pass.cc" - "graph/passes/enter_pass.cc" - "graph/passes/assign_remove_pass.cc" - "graph/passes/inplace_support_check_pass.cc" - "graph/passes/addn_pass.cc" - "graph/passes/common_subexpression_elimination_pass.cc" + "graph/passes/no_use_reshape_remove_pass.cc" + "graph/passes/parallel_concat_start_op_pass.cc" + "graph/passes/parallel_group_pass.cc" + "graph/passes/pass_manager.cc" + "graph/passes/pass_utils.cc" + "graph/passes/permute_pass.cc" + "graph/passes/placeholder_with_default_pass.cc" + "graph/passes/prevent_gradient_pass.cc" + "graph/passes/print_op_pass.cc" + "graph/passes/prune_pass.cc" + "graph/passes/ref_identity_delete_op_pass.cc" "graph/passes/remove_same_const_pass.cc" - "graph/passes/useless_control_out_remove_pass.cc" - "graph/passes/transop_symmetry_elimination_pass.cc" + "graph/passes/replace_transshape_pass.cc" + "graph/passes/replace_with_empty_const_pass.cc" + "graph/passes/reshape_recovery_pass.cc" + "graph/passes/reshape_remove_pass.cc" + "graph/passes/resource_pair_add_control_pass.cc" + "graph/passes/resource_pair_remove_control_pass.cc" + "graph/passes/same_transdata_breadth_fusion_pass.cc" "graph/passes/save_pass.cc" + "graph/passes/set_input_output_offset_pass.cc" + "graph/passes/shape_operate_op_remove_pass.cc" + 
"graph/passes/snapshot_pass.cc" + "graph/passes/stop_gradient_pass.cc" + "graph/passes/subexpression_migration_pass.cc" + "graph/passes/subgraph_const_migration_pass.cc" + "graph/passes/subgraph_pass.cc" + "graph/passes/switch_data_edges_bypass.cc" "graph/passes/switch_dead_branch_elimination.cc" "graph/passes/switch_logic_remove_pass.cc" - "graph/passes/switch_data_edges_bypass.cc" - "graph/passes/merge_pass.cc" - "graph/passes/variable_op_pass.cc" - "graph/passes/cast_remove_pass.cc" + "graph/passes/switch_to_stream_switch_pass.cc" + "graph/passes/transop_breadth_fusion_pass.cc" + "graph/passes/transop_depth_fusion_pass.cc" + "graph/passes/transop_nearby_allreduce_fusion_pass.cc" + "graph/passes/transop_symmetry_elimination_pass.cc" + "graph/passes/transop_without_reshape_fusion_pass.cc" "graph/passes/transpose_transdata_pass.cc" - "graph/passes/hccl_memcpy_pass.cc" - "graph/passes/hccl_continuous_memcpy_pass.cc" - "graph/passes/flow_ctrl_pass.cc" - "graph/passes/global_step_insert_pass.cc" - "graph/passes/link_gen_mask_nodes_pass.cc" - "graph/passes/replace_with_empty_const_pass.cc" - "graph/passes/hccl_group_pass.cc" - "graph/passes/hccl_tailing_optimization_pass.cc" - "graph/passes/memcpy_addr_async_pass.cc" - "graph/passes/set_input_output_offset_pass.cc" - "graph/passes/parallel_group_pass.cc" - "graph/passes/buffer_pool_memory_pass.cc" - "graph/manager/model_manager/event_manager.cc" - "graph/manager/util/rt_context_util.cc" - "graph/manager/util/variable_accelerate_ctrl.cc" - "graph/manager/util/debug.cc" - "graph/load/model_manager/model_manager.cc" - "graph/load/model_manager/data_inputer.cc" - "graph/load/model_manager/davinci_model.cc" - "graph/load/model_manager/model_utils.cc" - "graph/load/model_manager/aipp_utils.cc" - "graph/load/model_manager/tbe_handle_store.cc" - "graph/load/model_manager/cpu_queue_schedule.cc" - "graph/load/model_manager/zero_copy_task.cc" - "graph/load/model_manager/zero_copy_offset.cc" - "graph/load/model_manager/data_dumper.cc" - "graph/load/model_manager/task_info/task_info.cc" - "graph/load/model_manager/task_info/event_record_task_info.cc" - "graph/load/model_manager/task_info/event_wait_task_info.cc" - "graph/load/model_manager/task_info/fusion_start_task_info.cc" - "graph/load/model_manager/task_info/fusion_stop_task_info.cc" - "graph/load/model_manager/task_info/kernel_ex_task_info.cc" - "graph/load/model_manager/task_info/kernel_task_info.cc" - "graph/load/model_manager/task_info/label_set_task_info.cc" - "graph/load/model_manager/task_info/label_switch_by_index_task_info.cc" - "graph/load/model_manager/task_info/label_goto_ex_task_info.cc" - "graph/load/model_manager/task_info/memcpy_async_task_info.cc" - "graph/load/model_manager/task_info/memcpy_addr_async_task_info.cc" - "graph/load/model_manager/task_info/profiler_trace_task_info.cc" - "graph/load/model_manager/task_info/stream_active_task_info.cc" - "graph/load/model_manager/task_info/stream_switch_task_info.cc" - "graph/load/model_manager/task_info/stream_switchn_task_info.cc" - "graph/load/model_manager/task_info/end_graph_task_info.cc" - "graph/load/model_manager/task_info/model_exit_task_info.cc" - "graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" - "graph/load/model_manager/task_info/super_kernel/super_kernel.cc" - "hybrid/hybrid_davinci_model_stub.cc" - "ir_build/ge_ir_build.cc" - "ir_build/attr_options/utils.cc" - "ir_build/attr_options/keep_dtype_option.cc" - "ir_build/attr_options/weight_compress_option.cc" - "ir_build/option_utils.cc" + 
"graph/passes/unused_args_clean_pass.cc" + "graph/passes/unused_const_pass.cc" + "graph/passes/useless_control_out_remove_pass.cc" + "graph/passes/var_is_initialized_op_pass.cc" + "graph/passes/variable_op_pass.cc" + "graph/passes/variable_prepare_op_pass.cc" + "graph/passes/variable_ref_delete_op_pass.cc" + "graph/passes/variable_ref_useless_control_out_delete_pass.cc" + "graph/preprocess/graph_preprocess.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" + "graph/preprocess/multi_batch_copy_graph.cc" + "graph/preprocess/multi_batch_options.cc" + "host_kernels/add_kernel.cc" + "host_kernels/broadcast_args_kernel.cc" + "host_kernels/broadcast_gradient_args_kernel.cc" + "host_kernels/cast_kernel.cc" + "host_kernels/concat_offset_kernel.cc" + "host_kernels/concat_v2_kernel.cc" + "host_kernels/dynamic_stitch_kernel.cc" + "host_kernels/empty_kernel.cc" + "host_kernels/expanddims_kernel.cc" + "host_kernels/fill_kernel.cc" + "host_kernels/floordiv_kernel.cc" + "host_kernels/floormod_kernel.cc" + "host_kernels/gather_v2_kernel.cc" + "host_kernels/greater_kernel.cc" + "host_kernels/identity_kernel.cc" + "host_kernels/kernel_utils.cc" + "host_kernels/maximum_kernel.cc" + "host_kernels/mul_kernel.cc" + "host_kernels/pack_kernel.cc" + "host_kernels/permute_kernel.cc" + "host_kernels/range_kernel.cc" + "host_kernels/rank_kernel.cc" + "host_kernels/reduce_prod_kernel.cc" + "host_kernels/reformat_kernel.cc" + "host_kernels/reshape_kernel.cc" + "host_kernels/rsqrt_kernel.cc" + "host_kernels/shape_kernel.cc" + "host_kernels/shape_n_kernel.cc" + "host_kernels/size_kernel.cc" + "host_kernels/slice_d_kernel.cc" + "host_kernels/slice_kernel.cc" + "host_kernels/squeeze_kernel.cc" + "host_kernels/ssd_prior_box_kernel.cc" + "host_kernels/strided_slice_kernel.cc" + "host_kernels/sub_kernel.cc" + "host_kernels/transdata_kernel.cc" + "host_kernels/transpose_kernel.cc" + "host_kernels/unpack_kernel.cc" + "host_kernels/unsqueeze_kernel.cc" "hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "graph/build/model_builder.cc" - "graph/build/task_generator.cc" - "graph/build/stream_allocator.cc" - "graph/build/logical_stream_allocator.cc" - "graph/build/stream_graph_optimizer.cc" - "graph/build/run_context.cc" - "graph/build/label_allocator.cc" - "graph/label/label_maker.cc" - "graph/label/if_label_maker.cc" - "graph/label/case_label_maker.cc" - "graph/label/while_label_maker.cc" - "graph/label/partitioned_call_label_maker.cc" - "analyzer/analyzer.cc" - "graph/build/memory/memory_assigner.cc" - "graph/build/memory/graph_mem_assigner.cc" - "graph/build/memory/binary_block_mem_assigner.cc" - "graph/build/memory/block_mem_assigner.cc" - "graph/build/memory/hybrid_mem_assigner.cc" - "graph/build/memory/max_block_mem_assigner.cc" - "graph/build/memory/var_mem_assign_util.cc" - "graph/build/memory/buffer_pool_mem_assigner.cc" + "init/gelib.cc" + "ir_build/attr_options/keep_dtype_option.cc" + "ir_build/attr_options/utils.cc" + "ir_build/attr_options/weight_compress_option.cc" + "ir_build/ge_ir_build.cc" + "ir_build/option_utils.cc" + "opskernel_manager/ops_kernel_builder_manager.cc" + "opskernel_manager/ops_kernel_manager.cc" +) + +set(RUNNER_SRC_LIST + "client/ge_api.cc" + "session/inner_session.cc" + "session/session_manager.cc" + "common/profiling/ge_runner_profiling.cc" + "graph/manager/memory_api.cc" + "graph/manager/util/hcom_util.cc" + "graph/load/model_manager/task_info/hccl_task_info.cc" + "hybrid/node_executor/hccl/hccl_node_executor.cc" + 
"hybrid/node_executor/aicore/aicore_task_compiler.cc" ) if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) message("CMAKE_CXX_COMPILER_VERSION = ${CMAKE_CXX_COMPILER_VERSION}") ############ libge_runner.so ############ add_library(ge_runner SHARED - ${TRAIN_SRC_LIST} + ${EXECUTOR_SRC_LIST} + ${COMPILER_SRC_LIST} + ${RUNNER_SRC_LIST} $,msprofiler_fwk,msprofiler_fwk_object>> ) @@ -752,29 +524,29 @@ target_compile_options(ge_runner PRIVATE target_include_directories(ge_runner SYSTEM PRIVATE ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/external ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${METADEF_DIR} ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/external - ${GE_CODE_DIR}/../inc/cce ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external ${GE_CODE_DIR}/../abl/adump/external - #### blue zone + ${GE_CODE_DIR}/../abl/licctrl + ${GE_CODE_DIR}/../ace/comop/inc + ${GE_CODE_DIR}/../ace/comop/inc/external + $<$>:${GE_DEPEND_DIR}/inc> + $<$>:$> + $<$>:$> + #### blue zone #### ${ASCEND_DIR}/driver/include ${ASCEND_DIR}/fwkacllib/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info> ) target_link_options(ge_runner PRIVATE @@ -783,6 +555,11 @@ target_link_options(ge_runner PRIVATE target_link_libraries(ge_runner PRIVATE $ + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> adump_server static_mmpa ge_proto_common @@ -797,6 +574,7 @@ target_link_libraries(ge_runner PRIVATE runtime error_manager ascend_hal_stub + opt_feature -Wl,--as-needed json -lrt @@ -805,7 +583,7 @@ target_link_libraries(ge_runner PRIVATE ############ libge_compiler.so ############ add_library(ge_compiler SHARED - ${INFER_SRC_LIST} + ${COMPILER_SRC_LIST} ) add_dependencies(ge_compiler @@ -817,7 +595,6 @@ target_compile_definitions(ge_compiler PRIVATE REUSE_MEMORY=1 FMK_SUPPORT_DUMP FMK_HOST_INFER - COMPILE_OMG_PACKAGE google=ascend_private FUNC_VISIBILITY $<$:ONLY_COMPILE_OPEN_SRC> @@ -833,29 +610,29 @@ target_compile_options(ge_compiler PRIVATE target_include_directories(ge_compiler SYSTEM PRIVATE ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/external ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${METADEF_DIR} ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/external - ${GE_CODE_DIR}/../inc/cce ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external ${GE_CODE_DIR}/../abl/adump/external + ${GE_CODE_DIR}/../abl/licctrl + ${GE_CODE_DIR}/../ace/comop/inc + ${GE_CODE_DIR}/../ace/comop/inc/external + $<$>:${GE_DEPEND_DIR}/inc> + $<$>:$> + $<$>:$> #### blue zone #### ${ASCEND_DIR}/driver/include ${ASCEND_DIR}/fwkacllib/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info> ) target_link_options(ge_compiler 
PRIVATE @@ -864,6 +641,11 @@ target_link_options(ge_compiler PRIVATE target_link_libraries(ge_compiler PRIVATE $ + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> static_mmpa ge_proto_common -Wl,--no-as-needed @@ -874,13 +656,146 @@ target_link_libraries(ge_compiler PRIVATE c_sec error_manager slog - runtime_compile + runtime + opt_feature -Wl,--as-needed json -lrt -ldl ) +######## libge_executor.a ######## +add_library(ge_executor STATIC + ${EXECUTOR_SRC_LIST} +) + +add_dependencies(ge_executor + graphengine_protos +) + +target_compile_options(ge_executor PRIVATE + $<$,$>:-fvisibility=hidden -O2 -Werror -Wno-deprecated-declarations -fno-common> + $<$,$>:/MTd> + $<$,$>:/MT> + $<$:-Werror=unused-variable> + $<$:-Werror=unused-const-variable -Werror=format> +) + +target_compile_definitions(ge_executor PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + DAVINCI_SUPPORT_PROFILING + google=ascend_private + $,OS_TYPE=WIN,OS_TYPE=0> + $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> + $<$:ONLY_COMPILE_OPEN_SRC> + LOG_CPP +) + +target_include_directories(ge_executor SYSTEM PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/graphengine_protos + #### yellow zone #### + ${GE_CODE_DIR}/../ace/comop/inc + ${GE_CODE_DIR}/../ace/comop/inc/external + $<$>:${GE_DEPEND_DIR}/inc> + $<$>:$> + $<$>:$> + #### blue zone #### + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> +) + +target_link_libraries(ge_executor PRIVATE + $ + $<$>:$> + $<$>:$> + $<$>:$> + json + ascend_protobuf_static + c_sec + $<$>:-lrt> + -ldl +) + +######## libge_executor.so ######## +add_library(ge_executor_shared SHARED + ${EXECUTOR_SRC_LIST} +) + +add_dependencies(ge_executor_shared + graphengine_protos +) + +target_compile_options(ge_executor_shared PRIVATE + -fno-common + -Werror + -O2 + -Wno-deprecated-declarations + -fvisibility=hidden +) + +target_compile_definitions(ge_executor_shared PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + DAVINCI_SUPPORT_PROFILING + google=ascend_private + FUNC_VISIBILITY + $<$:ONLY_COMPILE_OPEN_SRC> +) + +target_include_directories(ge_executor_shared PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/graphengine_protos + #### yellow zone #### + ${GE_CODE_DIR}/../ace/comop/inc + ${GE_CODE_DIR}/../ace/comop/inc/external + $<$>:${GE_DEPEND_DIR}/inc> + #### blue zone #### + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> +) + +target_link_options(ge_executor_shared PRIVATE + -Wl,-Bsymbolic + -Wl,--exclude-libs,ALL +) + +target_link_libraries(ge_executor_shared PRIVATE + $ + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> + -Wl,--no-as-needed + ge_common + runtime + slog + graph + register + error_manager + ascend_protobuf + c_sec + -Wl,--as-needed + json + $<$>:-lrt> + -ldl +) + +set_target_properties(ge_executor_shared PROPERTIES + OUTPUT_NAME ge_executor +) + ############ libascendcl.so ############ file(GENERATE OUTPUT ${CMAKE_BINARY_DIR}/dummy.c CONTENT "") #add_library(dummy_obj OBJECT ${CMAKE_BINARY_DIR}/dummy.c) @@ -998,18 +913,14 @@ set_target_properties(atc_stub_ge_compiler PROPERTIES ) target_include_directories(atc_stub_ge_compiler PRIVATE - ${GE_CODE_DIR} ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer ${GE_CODE_DIR}/inc 
${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common ${GE_CODE_DIR}/inc/external ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc #### yellow zone #### - ${GE_CODE_DIR}/../inc/cce + ${GE_CODE_DIR}/../inc ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external ${GE_CODE_DIR}/../abl/adump/external #### blue zone #### @@ -1039,18 +950,14 @@ set_target_properties(fwk_stub_ge_runner PROPERTIES ) target_include_directories(fwk_stub_ge_runner PRIVATE - ${GE_CODE_DIR} ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/external ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc #### yellow zone #### - ${GE_CODE_DIR}/../inc/cce + ${GE_CODE_DIR}/../inc ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external ${GE_CODE_DIR}/../abl/adump/external #### blue zone #### @@ -1084,7 +991,7 @@ add_custom_command( set(INSTALL_BASE_DIR "") set(INSTALL_LIBRARY_DIR lib) -install(TARGETS ge_runner ge_compiler opensrc_ascendcl OPTIONAL +install(TARGETS ge_runner ge_compiler ge_executor_shared opensrc_ascendcl OPTIONAL LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} ) diff --git a/ge/analyzer/analyzer.cc b/ge/analyzer/analyzer.cc index 95036267..97b59411 100755 --- a/ge/analyzer/analyzer.cc +++ b/ge/analyzer/analyzer.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "analyzer.h" +#include "analyzer/analyzer.h" #include #include diff --git a/ge/client/ge_api.cc b/ge/client/ge_api.cc index 1aa4c41d..e4a016b3 100644 --- a/ge/client/ge_api.cc +++ b/ge/client/ge_api.cc @@ -14,10 +14,10 @@ * limitations under the License. */ -#include "ge/ge_api.h" +#include "external/ge/ge_api.h" #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "framework/common/debug/ge_log.h" #include "common/ge/datatype_util.h" #include "proto/ge_api.pb.h" @@ -29,7 +29,7 @@ #include "graph/opsproto_manager.h" #include "graph/utils/type_utils.h" #include "graph/manager/util/rt_context_util.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "register/op_registry.h" #include "common/ge/tbe_plugin_manager.h" #include "common/util/error_manager/error_manager.h" @@ -47,6 +47,7 @@ const int32_t kMaxStrLen = 128; static bool g_ge_initialized = false; static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use +static std::shared_ptr g_session_manager; namespace ge { void GetOpsProtoPath(std::string &opsproto_path) { @@ -70,8 +71,7 @@ Status CheckOptionsValid(const std::map &options) { auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); if (job_id_iter != options.end()) { if (job_id_iter->second.length() > kMaxStrLen) { - GELOGE(PARAM_INVALID, "[Check][JobId]Failed," - "the job_id [%s] string length: %zu > max string length: %d", + GELOGE(PARAM_INVALID, "[Check][JobId]Failed, the job_id [%s] string length: %zu > max string length: %d", job_id_iter->second.c_str(), job_id_iter->second.length(), kMaxStrLen); REPORT_INPUT_ERROR("E10051", std::vector({"id", "length"}), std::vector({job_id_iter->second, @@ -95,8 +95,7 @@ Status GEInitializeImpl(const std::map &options) { std::string path_base = ge::GELib::GetPath(); auto ret = ErrorManager::GetInstance().Init(path_base); if (ret != SUCCESS) { - GELOGE(GE_CLI_INIT_FAILED, - "[Init][PathBase]Init failed when pass param path_base:%s", path_base.c_str()); + 
GELOGE(GE_CLI_INIT_FAILED, "[Init][PathBase]Init failed when pass param path_base:%s", path_base.c_str()); REPORT_CALL_ERROR("E19999", "Init failed when pass param path_base:%s", path_base.c_str()); return ret; } @@ -117,11 +116,9 @@ Status GEInitializeImpl(const std::map &options) { bool is_proto_init = manager->Initialize(option_tmp); GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize"); if (!is_proto_init) { - GELOGE(GE_CLI_INIT_FAILED, - "[Init][OpsProtoPath]Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid.", + GELOGE(GE_CLI_INIT_FAILED, "[Init][OpsProtoPath]Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid.", opsproto_path.c_str()); - REPORT_CALL_ERROR("E19999", "Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid", - opsproto_path.c_str()); + REPORT_CALL_ERROR("E19999", "Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid", opsproto_path.c_str()); return FAILED; } @@ -148,6 +145,22 @@ Status GEInitializeImpl(const std::map &options) { return FAILED; } + ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); + GELOGI("sessionManager initial."); + GE_TIMESTAMP_START(SessionManagerInitialize); + g_session_manager = MakeShared(); + if (g_session_manager == nullptr) { + GELOGE(GE_CLI_INIT_FAILED, "[Init][Create]SessionManager failed"); + return FAILED; + } + ret = g_session_manager->Initialize(options); + GE_TIMESTAMP_END(SessionManagerInitialize, "InnerInitialize::SessionManagerInitialize"); + if (ret != SUCCESS) { + GELOGE(ret, "[Init][SessionManager] GE session manager initial failed."); + REPORT_CALL_ERROR("E19999", "SessionManager initialize failed."); + return ret; + } + // 7.check return status, return if (!g_ge_initialized) { // Initialize success, first time calling initialize @@ -173,8 +186,7 @@ Status GEInitialize(const std::map &options) { for (auto &option : options) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { GELOGE(FAILED, "[Check][Param]Options invalid, first or second option is nullptr."); - REPORT_INNER_ERROR("E19999", "Check parameter's options invalid," - "the first or second option is nullptr."); + REPORT_INNER_ERROR("E19999", "Check parameter's options invalid, the first or second option is nullptr."); return FAILED; } std::string key = option.first.GetString(); @@ -217,6 +229,12 @@ Status GEFinalize() { ret = middle_ret; } } + + GELOGI("SessionManager finalization."); + if (g_session_manager != nullptr) { + (void)g_session_manager->Finalize(); // always success. 
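The hunks above replace GELib::GetInstance()->SessionManagerObj() with a file-scope g_session_manager that GEInitializeImpl creates and GEFinalize tears down. A condensed, compilable sketch of that lifecycle, with SessionManager reduced to the two calls the diff makes and int standing in for Status:

```
#include <map>
#include <memory>
#include <string>

// Stand-ins for the real types; SessionManager is reduced to the two calls
// the hunks above make. The real code uses GE's MakeShared rather than
// std::make_shared, and Status codes rather than int.
struct SessionManager {
  int Initialize(const std::map<std::string, std::string> &options) { (void)options; return 0; }
  int Finalize() { return 0; }
};

static bool g_ge_initialized = false;
static std::shared_ptr<SessionManager> g_session_manager;

int GEInitializeImpl(const std::map<std::string, std::string> &options) {
  g_session_manager = std::make_shared<SessionManager>();
  if (g_session_manager == nullptr) {
    return -1;  // the real code logs [Init][Create]SessionManager failed
  }
  if (g_session_manager->Initialize(options) != 0) {
    return -1;  // logged and reported as E19999 in the real code
  }
  g_ge_initialized = true;  // set on the success path of the real function
  return 0;
}

int GEFinalize() {
  if (g_session_manager != nullptr) {
    (void)g_session_manager->Finalize();  // always succeeds, per the comment above
  }
  g_ge_initialized = false;
  return 0;
}
```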
+ } + middle_ret = TBEPluginManager::Instance().Finalize(); if (middle_ret != SUCCESS) { ret = middle_ret; @@ -251,28 +269,18 @@ std::string GEGetWarningMsg() { Session::Session(const std::map &options) { ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); GELOGT(TRACE_INIT, "Start to construct session."); - ErrorManager::GetInstance().GenWorkStreamIdDefault(); // check init status sessionId_ = 0; if (!g_ge_initialized) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Construct][Session]Failed because lack GEInitialize call before."); - REPORT_INNER_ERROR("E19999", - "Creating session failed because lack GEInitialize call before."); - return; - } - // call Initialize - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Construct][Session]Failed, GELib instance is nullptr or it is not InitFlag"); + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return; } GELOGT(TRACE_RUNNING, "Creating session"); uint64_t session_id = 0; - Status ret = instance_ptr->SessionManagerObj().CreateSession(options, session_id); + Status ret = g_session_manager->CreateSession(options, session_id); GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); // check return status, return, update session id if success @@ -288,32 +296,21 @@ Session::Session(const std::map &options) { Session::Session(const std::map &options) { ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); GELOGT(TRACE_INIT, "Session Constructor start"); - ErrorManager::GetInstance().GenWorkStreamIdDefault(); // check init status sessionId_ = 0; if (!g_ge_initialized) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Construct][Session]Failed because lack GEInitialize call before."); - REPORT_INNER_ERROR("E19999", - "Creating session failed because lack GEInitialize call before."); + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return; } // call Initialize - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Construct][Session]Failed, the GELib instance is nullptr or is not InitFlag"); - return; - } - GELOGT(TRACE_RUNNING, "Creating session"); std::map str_options; for (auto &option : options) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { GELOGE(FAILED, "[Construct][Session]Failed, the first or second option is nullptr."); - REPORT_INNER_ERROR("E19999", "Creating session's options invalid," - "the first or second option is nullptr."); + REPORT_INNER_ERROR("E19999", "Creating session's options invalid, the first or second option is nullptr."); return; } std::string key = option.first.GetString(); @@ -321,7 +318,7 @@ Session::Session(const std::map &options) { str_options[key] = val; } uint64_t session_id = 0; - Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id); + Status ret = g_session_manager->CreateSession(str_options, session_id); GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); // check return status, return, update session id if success @@ -350,19 +347,12 @@ Session::~Session() { try { 
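With the constructor and destructor rewritten as above, a Session's whole lifetime is brokered by g_session_manager, and every entry point now fails fast unless GEInitialize ran first. A client-side sketch against the public ge_api.h surface; empty options and graph_id 1 are placeholders, and AddGraph/RunGraph are the overloads refactored in the hunks below:

```
#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"

int RunOnce(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs) {
  std::map<std::string, std::string> options;  // placeholder: no options set
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return -1;  // every Session call below fails fast without this, per the new guards
  }
  ge::Session session(options);      // CreateSession through g_session_manager
  (void)session.AddGraph(1, graph);  // graph_id 1 is arbitrary; see the hunks below
  std::vector<ge::Tensor> outputs;
  ge::Status ret = session.RunGraph(1, inputs, outputs);
  (void)ge::GEFinalize();            // now also finalizes the session manager
  return (ret == ge::SUCCESS) ? 0 : -1;
}
```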
uint64_t session_id = sessionId_; // call DestroySession - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGW("GE is not yet initialized or is finalized."); - return; - } GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); - GELOGT(TRACE_RUNNING, "Destroying session"); - ret = instance_ptr->SessionManagerObj().DestroySession(session_id); + ret = g_session_manager->DestroySession(session_id); } catch (google::protobuf::FatalException &e) { - GELOGE(GE_CLI_SESS_DESTROY_FAILED, "[Destruct][Session]Failed " - "because get fatalException."); + GELOGE(GE_CLI_SESS_DESTROY_FAILED, "[Destruct][Session]Failed because get fatalException."); REPORT_CALL_ERROR("E19999", "Destruct session failed, get fatal exception"); } @@ -377,9 +367,7 @@ Session::~Session() { // Add Graph Status Session::AddGraph(uint32_t graph_id, const Graph &graph) { - ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); std::map options; - ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); return AddGraph(graph_id, graph, options); } @@ -388,20 +376,16 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Add][Graph]Failed because GELib instance is nullptr or it is not InitFlag."); - REPORT_INNER_ERROR("E19999", - "AddGraph Failed, GELib instance is nullptr or it is not InitFlag."); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGD("Adding graph to session"); - Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, options); + Status ret = g_session_manager->AddGraph(sessionId_, graph_id, graph, options); if (ret != SUCCESS) { - GELOGE(ret, - "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", - ret, sessionId_, graph_id); + GELOGE(ret, "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); return FAILED; } GELOGD("AddGraph finished in Session."); @@ -409,37 +393,31 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options) { +Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options) { ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); GELOGT(TRACE_INIT, "Start to add graph in Session. 
graph_id: %u, session_id: %lu.", graph_id, sessionId_); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag."); - REPORT_INNER_ERROR("E19999", - "AddGraph Failed, GELib instance is nullptr or it is not InitFlag."); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGD("Adding graph to session"); std::map str_options; for (auto &option : options) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { GELOGE(FAILED, "[Add][Graph]Failed, the first or second option is nullptr."); - REPORT_INNER_ERROR("E19999", - "Add Graph Failed, the first or second option is nullptr."); + REPORT_INNER_ERROR("E19999", "Add Graph Failed, the first or second option is nullptr."); return FAILED; } std::string key = option.first.GetString(); std::string val = option.second.GetString(); str_options[key] = val; } - Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options); + Status ret = g_session_manager->AddGraph(sessionId_, graph_id, graph, str_options); if (ret != SUCCESS) { - GELOGE(ret, - "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", - ret, sessionId_, graph_id); + GELOGE(ret, "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); return FAILED; } GELOGD("AddGraph finished in Session."); @@ -447,8 +425,6 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, } Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { - ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); - ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); std::map options; return AddGraphWithCopy(graph_id, graph, options); } @@ -459,24 +435,20 @@ Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph, ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); GELOGT(TRACE_INIT, "Start to add graph in Session. 
graph_id: %u, session_id: %lu.", graph_id, sessionId_); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag."); - REPORT_INNER_ERROR("E19999", - "AddGraph Failed, GELib instance is nullptr or is not InitFlag."); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + std::map str_options; for (auto it = options.begin(); it != options.end(); ++it) { str_options.insert({it->first.GetString(), it->second.GetString()}); } GELOGD("Adding graph to session"); - Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options); + Status ret = g_session_manager->AddGraphWithCopy(sessionId_, graph_id, graph, str_options); if (ret != SUCCESS) { - GELOGE(ret, - "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", - ret, sessionId_, graph_id); + GELOGE(ret, "[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); return FAILED; } GELOGD("AddGraph finished in Session."); @@ -487,29 +459,21 @@ Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph, Status Session::RemoveGraph(uint32_t graph_id) { ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); GELOGT(TRACE_INIT, "Session RemoveGraph start"); - ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); // call RemoveGraph - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (!instance_ptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Remove][Graph]Failed, GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); - REPORT_INNER_ERROR("E19999", - "RemoveGraph Failed, GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } GELOGT(TRACE_RUNNING, "Removing Graph from session"); - Status ret = instance_ptr->SessionManagerObj().RemoveGraph(sessionId_, graph_id); + Status ret = g_session_manager->RemoveGraph(sessionId_, graph_id); // check return status, return if (ret != SUCCESS) { - GELOGE(ret, - "[Remove][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", - ret, sessionId_, graph_id); - REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, " - "session_id:%lu, graph_id:%u", ret, sessionId_, graph_id); + GELOGE(ret, "[Remove][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); + REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, session_id:%lu, graph_id:%u", + ret, sessionId_, graph_id); return FAILED; } GELOGT(TRACE_STOP, "Session RemoveGraph finished"); @@ -568,29 +532,21 @@ void PrintOutputResult(std::vector &outputs) { Status Session::RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector &outputs) { ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); 
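The AscendString overloads above and below all repeat the same null-checked conversion before delegating. A standalone sketch of that loop; the AscendString class here is a stand-in reduced to the one accessor the diff relies on:

```
#include <cstring>
#include <map>
#include <string>

// Stand-in for ge::AscendString, reduced to the accessor the diff relies on.
class AscendString {
 public:
  explicit AscendString(const char *s) : str_(s) {}
  const char *GetString() const { return str_; }
  bool operator<(const AscendString &rhs) const {
    return std::strcmp(str_ == nullptr ? "" : str_,
                       rhs.str_ == nullptr ? "" : rhs.str_) < 0;
  }
 private:
  const char *str_;
};

// Mirrors the conversion loop repeated in the AscendString overloads:
// reject any option whose key or value is a null string, then copy.
bool ConvertOptions(const std::map<AscendString, AscendString> &in,
                    std::map<std::string, std::string> &out) {
  for (const auto &option : in) {
    if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
      return false;  // the real code logs the failure and returns FAILED
    }
    out[option.first.GetString()] = option.second.GetString();
  }
  return true;
}
```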
GELOGT(TRACE_INIT, "Session RunGraph start"); - ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); - std::vector graph_inputs = inputs; - // call RunGraph - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Run][Graph]Failed, GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); - REPORT_INNER_ERROR("E19999", - "RunGraph Failed, GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + + // call RunGraph GELOGT(TRACE_RUNNING, "Running Graph"); - Status ret = instance_ptr->SessionManagerObj().RunGraph(sessionId_, graph_id, graph_inputs, outputs); + Status ret = g_session_manager->RunGraph(sessionId_, graph_id, inputs, outputs); // check return status if (ret != SUCCESS) { - GELOGE(ret, - "[Run][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", - ret, sessionId_, graph_id); - REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, " - "session_id:%lu, graph_id:%u", ret, sessionId_, graph_id); + GELOGE(ret, "[Run][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); + REPORT_CALL_ERROR("E19999", "Remove graph failed, error code:%u, session_id:%lu, graph_id:%u", + ret, sessionId_, graph_id); return FAILED; } @@ -609,30 +565,15 @@ Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const s std::vector &outputs) { ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); GELOGT(TRACE_INIT, "Start to run graph with stream async."); - ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Run][Graph]Run graph with stream async failed, the GELib instance is nullptr," - "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); - REPORT_INNER_ERROR("E19999", - "Run graph with stream async failed, the GELib instance is nullptr" - "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); - return FAILED; - } - if (!instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Run][Graph]Run graph with stream asyn failed, the GELib instance is not init," - "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); - REPORT_INNER_ERROR("E19999", - "Run graph with stream asyn failed, the GELib instance is not init," - "session id = %lu, graph id = %u, stream = %p.", sessionId_, graph_id, stream); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGT(TRACE_RUNNING, "Run Graph Run graph with stream asyn."); - Status ret = instance_ptr->SessionManagerObj().RunGraphWithStreamAsync(sessionId_, graph_id, stream, inputs, - outputs); + Status ret = g_session_manager->RunGraphWithStreamAsync(sessionId_, graph_id, stream, inputs, outputs); if (ret != SUCCESS) { GELOGE(ret, "[Run][Graph]Run graph with stream 
asyn Failed," "error code = %u, session id = %lu, graph id = %u, stream = %p.", ret, sessionId_, graph_id, stream); @@ -648,40 +589,46 @@ Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const s // Register Call Back Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { ErrorManager::GetInstance().GenWorkStreamIdDefault(); - return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); + return FAILED; + } + + return g_session_manager->RegisterCallBackFunc(sessionId_, key, callback); } Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) { ErrorManager::GetInstance().GenWorkStreamIdDefault(); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); + return FAILED; + } + std::string str_key; if (key != nullptr) { str_key = key; } - return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback); + return g_session_manager->RegisterCallBackFunc(sessionId_, str_key, callback); } // Build Graph Status Session::BuildGraph(uint32_t graph_id, const std::vector &inputs) { ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Build][Graph]Failed, the GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); - REPORT_INNER_ERROR("E19999", - "Build graph failed, the GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGT(TRACE_RUNNING, "Building Graph"); - Status ret = instance_ptr->SessionManagerObj().BuildGraph(sessionId_, graph_id, inputs); + Status ret = g_session_manager->BuildGraph(sessionId_, graph_id, inputs); if (ret != SUCCESS) { - GELOGE(ret, - "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", - ret, sessionId_, graph_id); - REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, " - "session_id:%lu, graph_id:%u", ret, sessionId_, graph_id); + GELOGE(ret, "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); + REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, session_id:%lu, graph_id:%u", + ret, sessionId_, graph_id); return FAILED; } return SUCCESS; @@ -691,24 +638,18 @@ Status Session::BuildGraph(uint32_t graph_id, const std::vector Status Session::BuildGraph(uint32_t graph_id, const std::vector &inputs) { ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); - std::shared_ptr instance_ptr 
= ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Build][Graph]Failed, the GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); - REPORT_INNER_ERROR("E19999", - "Build graph failed, the GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGT(TRACE_RUNNING, "Building Graph"); - Status ret = instance_ptr->SessionManagerObj().BuildGraph(sessionId_, graph_id, inputs); + Status ret = g_session_manager->BuildGraph(sessionId_, graph_id, inputs); if (ret != SUCCESS) { - GELOGE(ret, - "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", - ret, sessionId_, graph_id); - REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, " - "session_id:%lu, graph_id:%u", ret, sessionId_, graph_id); + GELOGE(ret, "[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); + REPORT_CALL_ERROR("E19999", "Build graph failed , error code:%u, session_id:%lu, graph_id:%u", + ret, sessionId_, graph_id); return FAILED; } return SUCCESS; @@ -719,26 +660,22 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector & RunAsyncCallback callback) { ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id); - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Run][Graph]RunGraphAsyncFailed, the GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); - REPORT_INNER_ERROR("E19999", - "RunGraphAsync Failed, the GELib instance is nullptr or is not InitFlag, " - "session_id %lu, graph_id %u", sessionId_, graph_id); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGT(TRACE_RUNNING, "Run Graph Asynchronously"); GELOGW( "The callback function will not be checked. 
Please ensure that the implementation of the function is trusted."); - Status ret = ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback); + Status ret = g_session_manager->RunGraphAsync(sessionId_, graph_id, inputs, callback); if (ret != SUCCESS) { GELOGE(ret, "[Run][Graph]RunGraphAsync Failed, error code:%u, session_id:%lu, graph_id:%u.", ret, sessionId_, graph_id); - REPORT_CALL_ERROR("E19999", "RunGraphAsync Failed, error code:%u, session_id:%lu, " - "graph_id:%u", ret, sessionId_, graph_id); + REPORT_CALL_ERROR("E19999", "RunGraphAsync Failed, error code:%u, session_id:%lu, graph_id:%u", + ret, sessionId_, graph_id); return FAILED; } return SUCCESS; @@ -748,16 +685,14 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector & Status Session::GetVariables(const std::vector &var_names, std::vector &var_values) { ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); ErrorManager::GetInstance().GenWorkStreamIdDefault(); - auto instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag."); - REPORT_INNER_ERROR("E19999", - "GetVariables failed, the GELib instance is nullptr or is not InitFlag."); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGT(TRACE_RUNNING, "Get Variables"); - Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, var_names, var_values); + Status ret = g_session_manager->GetVariables(sessionId_, var_names, var_values); if (ret != SUCCESS) { GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_); return FAILED; @@ -769,14 +704,12 @@ Status Session::GetVariables(const std::vector &var_names, std::vec Status Session::GetVariables(const std::vector &var_names, std::vector &var_values) { ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); ErrorManager::GetInstance().GenWorkStreamIdDefault(); - auto instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, - "[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag."); - REPORT_INNER_ERROR("E19999", - "GetVariables failed, the GELib instance is nullptr or is not InitFlag."); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); return FAILED; } + GELOGT(TRACE_RUNNING, "Get Variables"); std::vector str_var_names; for (auto &var_name : var_names) { @@ -787,17 +720,22 @@ Status Session::GetVariables(const std::vector &var_names, std::ve } str_var_names.emplace_back(var_name.GetString()); } - Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values); + Status ret = g_session_manager->GetVariables(sessionId_, str_var_names, var_values); if (ret != SUCCESS) { GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_); - REPORT_CALL_ERROR("E19999", "Get variables failed, error code:%u, session_id:%lu.", - ret, sessionId_); 
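The GELOGW above is worth taking seriously when using the asynchronous path. A usage sketch, assuming RunAsyncCallback keeps its usual shape of a Status plus the output tensors:

```
#include <cstdint>
#include <vector>
#include "ge/ge_api.h"

// As the GELOGW above says, GE does not vet the callback body, so keep it
// small and exception-free. Callback signature assumed: Status plus outputs.
void LaunchAsync(ge::Session &session, uint32_t graph_id,
                 const std::vector<ge::Tensor> &inputs) {
  ge::RunAsyncCallback callback = [](ge::Status status, std::vector<ge::Tensor> &outputs) {
    if (status != ge::SUCCESS) {
      return;  // outputs are only meaningful on success
    }
    (void)outputs;  // consume results here
  };
  (void)session.RunGraphAsync(graph_id, inputs, callback);
}
```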
+ REPORT_CALL_ERROR("E19999", "Get variables failed, error code:%u, session_id:%lu.", ret, sessionId_); return FAILED; } return SUCCESS; } bool Session::IsGraphNeedRebuild(uint32_t graph_id) { - return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Construct][Session]Failed because lack GEInitialize call before."); + REPORT_INNER_ERROR("E19999", "Creating session failed because lack GEInitialize call before."); + return false; + } + + return g_session_manager->IsGraphNeedRebuild(sessionId_, graph_id); } } // namespace ge diff --git a/ge/client/proto/ge_api.proto b/ge/client/proto/ge_api.proto deleted file mode 100644 index 26d705fe..00000000 --- a/ge/client/proto/ge_api.proto +++ /dev/null @@ -1 +0,0 @@ -../../proto/ge_api.proto \ No newline at end of file diff --git a/ge/client/proto/ge_ir.proto b/ge/client/proto/ge_ir.proto deleted file mode 100644 index c0ef3071..00000000 --- a/ge/client/proto/ge_ir.proto +++ /dev/null @@ -1,193 +0,0 @@ -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. - DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ - DT_VARIANT = 26; // variant type - DT_BF16 = 27; // bf16 type - DT_INT4 = 28; // int4 type -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. 
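The per-module proto stubs deleted from here on (ge_api.proto, ge_ir.proto, and the rest) drop only the duplicated copies; the messages still compile through the central graphengine_protos target, so the generated C++ API is unchanged. A sketch of building the TensorDescriptor defined in the ge_ir.proto text below, assuming standard protoc codegen for package ge.proto and an assumed generated-header path:

```
#include <cstdint>
#include "proto/ge_ir.pb.h"  // assumed path under ${CMAKE_BINARY_DIR}/proto/graphengine_protos

ge::proto::TensorDescriptor MakeDesc() {
  ge::proto::TensorDescriptor desc;
  desc.set_name("x");                     // field 1: optional tensor name
  desc.set_dtype(ge::proto::DT_FLOAT16);  // field 2: tensor datatype
  desc.set_layout("NCHW");                // field 4: tensor format
  for (const int64_t dim : {1, 3, 224, 224}) {
    desc.mutable_shape()->add_dim(dim);   // field 3, ShapeDef: repeated int64 dim
  }
  return desc;
}
```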
The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs -{ - string name = 1; - map attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. op_name:index - - map attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef - - map attr = 11; // Extended field -} - diff --git a/ge/client/proto/insert_op.proto b/ge/client/proto/insert_op.proto deleted file mode 100644 index 7d708865..00000000 --- a/ge/client/proto/insert_op.proto +++ /dev/null @@ -1,140 +0,0 @@ -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPPģʽ־̬AIPPͶ̬AIPP - AippMode aipp_mode = 1; - - // related_input_rankΪΪͣ÷Χ>=0, <=DataӵĸĬֵΪ0 - // ʶģ͵ĵڼAIPPģ룬ҪԵ2AIPPrelated_input_rankΪ1 - uint32 related_input_rank = 2; - - // related_input_name is optional and the top name of data node which inserts aipp - string 
related_input_name = 6; - - // input_edge_idxΪѡΪͣ÷ΧΪ>=0 - // øòãڶDataӲͬͬAIPPòûãĬ϶related_input_rankָģAIPP - // ֵ <= Dataߵĸ - repeated uint32 input_edge_idx = 3; - - // [Begin] ̬AIPPþ̬AIPPʱЧ - uint32 max_src_image_size = 4; - - // Ƿ֧תĬϲ֧֣֧תʱжĿռʧ - bool support_rotation = 5; - - // [End] ̬AIPP - - - // [Begin] ̬AIPPö̬AIPPʱЧ - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - float padding_value = 72; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] ̬AIPP - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; //̬batch - resolution = 1; //ֱ̬ʣչ - } - - MultiShapeMode mode = 1; //ģʽ - uint32 related_input_rank = 2; //Ӳ뵽ĸ - - - repeated uint32 batch_list = 11; //batch_listֵbatch_listĸ28֮ -} diff --git a/ge/client/proto/om.proto b/ge/client/proto/om.proto deleted file mode 100755 index e15e5f80..00000000 --- a/ge/client/proto/om.proto +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
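insert_op.proto, deleted above, configures AIPP preprocessing; its Chinese comments arrive here mojibake-damaged, but the field definitions are intact. A sketch of filling AippOpParams for a fixed-size RGB input, again assuming standard protoc codegen for package domi and an assumed generated-header name:

```
#include "proto/insert_op.pb.h"  // assumed generated-header name

domi::InsertNewOps DescribeAipp() {
  domi::InsertNewOps ops;
  domi::AippOpParams *aipp = ops.add_aipp_op();
  aipp->set_related_input_rank(0);  // apply AIPP to the model's first Data input
  aipp->set_input_format(domi::AippOpParams::RGB888_U8);
  aipp->set_src_image_size_w(224);
  aipp->set_src_image_size_h(224);
  aipp->set_csc_switch(true);       // enable color-space conversion
  return ops;
}
```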
See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - 
bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams { - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. 
- CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/ge/client/proto/task.proto b/ge/client/proto/task.proto deleted file mode 100644 index 0da5631e..00000000 --- a/ge/client/proto/task.proto +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
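task.proto, whose deletion starts here, describes the serialized task list an offline model carries; the fields used in this sketch come from the ModelTaskDef, TaskDef, and KernelDef messages in the text that follows. Standard protoc codegen for package domi is assumed, and the kernel names are hypothetical:

```
#include "proto/task.pb.h"  // assumed generated-header name

domi::ModelTaskDef MakeTaskList() {
  domi::ModelTaskDef model_task;
  model_task.set_version("1.0");  // placeholder version string
  model_task.set_stream_num(1);
  domi::TaskDef *task = model_task.add_task();
  task->set_stream_id(0);
  domi::KernelDef *kernel = task->mutable_kernel();
  kernel->set_so_name("libcustom_op.so");  // hypothetical kernel library
  kernel->set_kernel_name("MyKernel");     // hypothetical entry point
  kernel->set_block_dim(1);
  return model_task;
}
```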
See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; - KernelDefWithHandle kernel_with_handle = 40; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelDefWithHandle { - KernelContext context = 1; - - uint64 handle = 10; - string dev_func = 11; - uint32 block_dim = 12; - uint32 args_size = 13; - bytes args = 14; - bytes sm_desc = 15; - string original_kernel_key = 16; - string node_info = 17; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - 
uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 7974a46d..99d6ead3 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -1,48 +1,55 @@ set(SRC_LIST - "context/ctx.cc" - "model_saver.cc" - "ge/datatype_util.cc" - "helper/om_file_helper.cc" - "helper/model_helper.cc" - "../model/ge_model.cc" - "../model/ge_root_model.cc" - "auth/file_saver.cc" - "fp16_t.cc" - "math/fp16_math.cc" - "debug/memory_dumper.cc" - "formats/utils/formats_trans_utils.cc" - "dump/dump_properties.cc" - "formats/format_transfers/datatype_transfer.cc" - "formats/format_transfers/format_transfer_transpose.cc" - "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "formats/format_transfers/format_transfer_fractal_z.cc" - "formats/format_transfers/format_transfer_fractal_nz.cc" - "formats/format_transfers/format_transfer_fractal_zz.cc" - "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "formats/format_transfers/format_transfer_fracz_nchw.cc" - "formats/format_transfers/format_transfer_fracz_nhwc.cc" - "formats/format_transfers/format_transfer_fracz_hwcn.cc" - "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "formats/formats.cc" - "ge_format_util.cc" - "fmk_error_codes.cc" - "util.cc" - "properties_manager.cc" - "types.cc" - "model_parser/model_parser.cc" - "kernel_store.cc" - "tbe_kernel_store.cc" - "cust_aicpu_kernel_store.cc" - "op/attr_value_util.cc" - "op/ge_op_utils.cc" - "thread_pool.cc" - "ge/tbe_plugin_manager.cc" + "${GE_CODE_DIR}/ge/common/auth/file_saver.cc" + "${GE_CODE_DIR}/ge/common/bcast.cc" + "${GE_CODE_DIR}/ge/common/context/ctx.cc" + "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" + "${GE_CODE_DIR}/ge/common/debug/memory_dumper.cc" + "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" + "${GE_CODE_DIR}/ge/common/dump/dump_properties.cc" + "${GE_CODE_DIR}/ge/common/fmk_error_codes.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/datatype_transfer.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fractal_z.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" + 
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_transpose.cc" + "${GE_CODE_DIR}/ge/common/formats/formats.cc" + "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" + "${GE_CODE_DIR}/ge/common/fp16_t.cc" + "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" + "${GE_CODE_DIR}/ge/common/ge/op_tiling_manager.cc" + "${GE_CODE_DIR}/ge/common/ge/plugin_manager.cc" + "${GE_CODE_DIR}/ge/common/ge/tbe_plugin_manager.cc" + "${GE_CODE_DIR}/ge/common/ge_format_util.cc" + "${GE_CODE_DIR}/ge/common/helper/model_helper.cc" + "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" + "${GE_CODE_DIR}/ge/common/kernel_store.cc" + "${GE_CODE_DIR}/ge/common/local_context.cc" + "${GE_CODE_DIR}/ge/common/math/fp16_math.cc" + "${GE_CODE_DIR}/ge/common/model/ge_model.cc" + "${GE_CODE_DIR}/ge/common/model/ge_root_model.cc" + "${GE_CODE_DIR}/ge/common/model_parser/model_parser.cc" + "${GE_CODE_DIR}/ge/common/model_saver.cc" + "${GE_CODE_DIR}/ge/common/omg_util.cc" + "${GE_CODE_DIR}/ge/common/op/attr_value_util.cc" + "${GE_CODE_DIR}/ge/common/op/ge_op_utils.cc" + "${GE_CODE_DIR}/ge/common/properties_manager.cc" + "${GE_CODE_DIR}/ge/common/tbe_kernel_store.cc" + "${GE_CODE_DIR}/ge/common/thread_pool.cc" + "${GE_CODE_DIR}/ge/common/transop_util.cc" + "${GE_CODE_DIR}/ge/common/types.cc" + "${GE_CODE_DIR}/ge/common/util.cc" ) if (NOT ENABLE_D AND NOT ENABLE_ACL) @@ -63,7 +70,7 @@ target_compile_definitions(ge_common PRIVATE ) target_compile_options(ge_common PRIVATE - -fvisibility=hidden + -fvisibility=default -O2 -Werror -Wno-deprecated-declarations @@ -72,24 +79,18 @@ target_compile_options(ge_common PRIVATE target_include_directories(ge_common PRIVATE ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/common - ${GE_CODE_DIR}/ge/common/op ${GE_CODE_DIR}/inc/external ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### - ${GE_DEPEND_DIR}/inc - ${GE_DEPEND_DIR}/inc/cce + $<$>:${GE_DEPEND_DIR}/inc> #### blue zone #### - #${GE_DEPEND_DIR}/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> ) target_link_options(ge_common PRIVATE @@ -98,6 +99,10 @@ target_link_options(ge_common PRIVATE target_link_libraries(ge_common PRIVATE $ + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> static_mmpa -Wl,--no-as-needed graph @@ -139,28 +144,26 @@ target_compile_options(ge_common_static PRIVATE target_include_directories(ge_common_static PRIVATE ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/common - ${GE_CODE_DIR}/ge/common/op ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/external ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### - ${GE_DEPEND_DIR}/inc - ${GE_DEPEND_DIR}/inc/cce + $<$>:${GE_DEPEND_DIR}/inc> #### blue zone #### - #${GE_DEPEND_DIR}/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> ) 
target_link_libraries(ge_common_static PRIVATE $ + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> ascend_protobuf_static json c_sec @@ -187,7 +190,7 @@ target_compile_definitions(ge_common PRIVATE ) target_compile_options(ge_common PRIVATE - -fvisibility=hidden + -fvisibility=default -O2 -Werror -Wno-deprecated-declarations @@ -196,15 +199,11 @@ target_compile_options(ge_common PRIVATE target_include_directories(ge_common PRIVATE ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/common - ${GE_CODE_DIR}/ge/common/op ${GE_CODE_DIR}/inc/external ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos ${GE_CODE_DIR}/third_party/fwkacllib/inc diff --git a/ge/common/auth/file_saver.cc b/ge/common/auth/file_saver.cc index 57ab901b..d6f24497 100755 --- a/ge/common/auth/file_saver.cc +++ b/ge/common/auth/file_saver.cc @@ -238,7 +238,7 @@ Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header, return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::CheckPath(const std::string &file_path) { +Status FileSaver::CheckPath(const std::string &file_path) { // Determine file path length if (file_path.size() >= MMPA_MAX_PATH) { GELOGE(FAILED, "[Check][FilePath]Failed, file path's length:%zu > mmpa_max_path:%d", @@ -271,8 +271,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::CheckPath(con return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -FileSaver::SaveToFile(const string &file_path, const ge::ModelData &model, const ModelFileHeader *model_file_header) { +Status FileSaver::SaveToFile(const string &file_path, const ge::ModelData &model, + const ModelFileHeader *model_file_header) { if (file_path.empty() || model.model_data == nullptr || model.model_len == 0) { GELOGE(FAILED, "[Save][File]Incorrect input param, " "file_path is empty or model_data is nullptr or model_len is 0"); @@ -301,19 +301,18 @@ FileSaver::SaveToFile(const string &file_path, const ge::ModelData &model, const return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, ModelPartitionTable &model_partition_table, - const std::vector &partition_datas) { +Status FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, + ModelPartitionTable &model_partition_table, + const std::vector &partition_datas) { const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_table, partition_datas); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", file_path.c_str(), file_header.length); return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, - vector &model_partition_tables, - const vector> &all_partition_datas) { +Status FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, + vector &model_partition_tables, + const vector> &all_partition_datas) { const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_tables, all_partition_datas); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", file_path.c_str(), file_header.length); @@ -372,8 +371,7 @@ Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFi return ret; } 
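The `dump_manager.cc` hunks below replace `dump_properties_map_.emplace(kInferSessionId, dump_properties)` with `dump_properties_map_[kInferSessionId] = dump_properties`. That is a behavioral fix rather than a style change: `std::map::emplace` is a no-op when the key already exists, so a repeated `SetDumpConf` call could never refresh the stored properties, while `operator[]` assigns unconditionally. A self-contained sketch of the difference (illustrative only, not GE code):

```
#include <cstdio>
#include <map>
#include <string>

int main() {
  std::map<int, std::string> props;
  props.emplace(0, "first");
  props.emplace(0, "second");             // no-op: key 0 already present
  std::printf("%s\n", props[0].c_str());  // prints "first"

  props[0] = "second";                    // operator[] overwrites
  std::printf("%s\n", props[0].c_str());  // prints "second"
  return 0;
}
```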
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, - int len) { +Status FileSaver::SaveToFile(const string &file_path, const void *data, int len) { if (data == nullptr || len <= 0) { GELOGE(FAILED, "[Check][Param]Failed, model_data is null or the " "length[%d] is less than 1.", len); diff --git a/ge/common/base64.h b/ge/common/base64.h index a537e585..22a78b46 100644 --- a/ge/common/base64.h +++ b/ge/common/base64.h @@ -20,8 +20,8 @@ #include #include -#include "debug/ge_log.h" -#include "ge_error_codes.h" +#include "framework/common/debug/ge_log.h" +#include "external/ge/ge_error_codes.h" namespace ge { namespace { diff --git a/ge/graph/common/bcast.cc b/ge/common/bcast.cc similarity index 99% rename from ge/graph/common/bcast.cc rename to ge/common/bcast.cc index b36b50b2..a4e8d1a1 100644 --- a/ge/graph/common/bcast.cc +++ b/ge/common/bcast.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/common/bcast.h" +#include "common/bcast.h" #include diff --git a/ge/graph/common/bcast.h b/ge/common/bcast.h similarity index 100% rename from ge/graph/common/bcast.h rename to ge/common/bcast.h diff --git a/ge/common/context/ctx.cc b/ge/common/context/ctx.cc index 9fe2f8c7..8e138ade 100755 --- a/ge/common/context/ctx.cc +++ b/ge/common/context/ctx.cc @@ -18,7 +18,7 @@ using ge::OmgContext; namespace domi { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { +OmgContext &GetContext() { static OmgContext context; return context; } diff --git a/ge/common/cust_aicpu_kernel_store.h b/ge/common/cust_aicpu_kernel_store.h index 033a636b..38124587 100755 --- a/ge/common/cust_aicpu_kernel_store.h +++ b/ge/common/cust_aicpu_kernel_store.h @@ -21,7 +21,7 @@ namespace ge { -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY CustAICPUKernelStore : public KernelStore { +class CustAICPUKernelStore : public KernelStore { public: CustAICPUKernelStore(); ~CustAICPUKernelStore() {} diff --git a/ge/common/debug/memory_dumper.cc b/ge/common/debug/memory_dumper.cc index 78ef2daa..f4a49440 100644 --- a/ge/common/debug/memory_dumper.cc +++ b/ge/common/debug/memory_dumper.cc @@ -30,13 +30,12 @@ const int kInvalidFd = (-1); } // namespace namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} +MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::~MemoryDumper() { Close(); } +MemoryDumper::~MemoryDumper() { Close(); } // Dump the data to the file -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::DumpToFile(const char *filename, void *data, - int64_t len) { +Status MemoryDumper::DumpToFile(const char *filename, void *data, int64_t len) { #ifdef FMK_SUPPORT_DUMP GE_CHECK_NOTNULL(filename); GE_CHECK_NOTNULL(data); @@ -81,7 +80,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::DumpToFile } // Open file -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Open(const char *filename) { +Status MemoryDumper::Open(const char *filename) { GE_CHK_BOOL_RET_STATUS(filename != nullptr, FAILED, "Incorrect parameter. 
filename is nullptr"); // Try to remove file first for reduce the close time by overwriting way @@ -104,7 +103,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Open(const } // Dump the data to file -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Dump(void *data, uint32_t len) const { +Status MemoryDumper::Dump(void *data, uint32_t len) const { GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "Incorrect parameter. data is nullptr"); #ifdef FMK_SUPPORT_DUMP diff --git a/ge/common/dump/dump_manager.cc b/ge/common/dump/dump_manager.cc index a6944fc6..da8160ff 100644 --- a/ge/common/dump/dump_manager.cc +++ b/ge/common/dump/dump_manager.cc @@ -15,6 +15,7 @@ */ #include "common/dump/dump_manager.h" + #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" @@ -26,14 +27,14 @@ const uint64_t kInferSessionId = 0; const uint32_t kAllOverflow = 3; } // namespace namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetInstance() { +DumpManager &DumpManager::GetInstance() { static DumpManager instance; return instance; } bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) { if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) { - dump_properties_map_.emplace(kInferSessionId, dump_properties); + dump_properties_map_[kInferSessionId] = dump_properties; GELOGI("Dump does not open"); return false; } @@ -41,7 +42,7 @@ bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) && dump_config.dump_debug == kDumpoff) { dump_properties.ClearDumpPropertyValue(); - dump_properties_map_.emplace(kInferSessionId, dump_properties); + dump_properties_map_[kInferSessionId] = dump_properties; return false; } if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) { @@ -74,7 +75,7 @@ void DumpManager::SetDumpList(const DumpConfig &dump_config, DumpProperties &dum Status DumpManager::SetNormalDumpConf(const DumpConfig &dump_config, DumpProperties &dump_properties) { if (dump_config.dump_status == kDumpOn) { - GELOGI("Only do normal dump process, dump status is %s.", dump_config.dump_status.c_str()); + GELOGI("Only do normal dump process, dump status is %s", dump_config.dump_status.c_str()); dump_properties.SetDumpStatus(dump_config.dump_status); std::string dump_op_switch = dump_config.dump_op_switch; dump_properties.SetDumpOpSwitch(dump_op_switch); @@ -104,8 +105,8 @@ Status DumpManager::SetNormalDumpConf(const DumpConfig &dump_config, DumpPropert Status DumpManager::SetDumpPath(const DumpConfig &dump_config, DumpProperties &dump_properties) { std::string dump_path = dump_config.dump_path; if (dump_path.empty()) { - GELOGE(PARAM_INVALID, "[Check][DumpPath]It is empty"); - REPORT_INNER_ERROR("E19999", "Dump path check is empty"); + GELOGE(PARAM_INVALID, "[Check][DumpPath]It is empty."); + REPORT_INNER_ERROR("E19999", "Dump path check is empty."); return PARAM_INVALID; } if (dump_path[dump_path.size() - 1] != '/') { @@ -117,7 +118,7 @@ Status DumpManager::SetDumpPath(const DumpConfig &dump_config, DumpProperties &d return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpManager::SetDumpConf(const DumpConfig &dump_config) { +Status DumpManager::SetDumpConf(const DumpConfig &dump_config) { DumpProperties dump_properties; if (!NeedDoDump(dump_config, dump_properties)) { GELOGD("No need do dump process."); @@ 
-131,8 +132,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpManager::SetDumpConf return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const DumpProperties &DumpManager::GetDumpProperties( - uint64_t session_id) { +const DumpProperties &DumpManager::GetDumpProperties(uint64_t session_id) { std::lock_guard lock(mutex_); auto iter = dump_properties_map_.find(session_id); if (iter != dump_properties_map_.end()) { @@ -142,13 +142,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const DumpProperties &DumpManag return default_properties; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpManager::AddDumpProperties( - uint64_t session_id, const DumpProperties &dump_properties) { +void DumpManager::AddDumpProperties(uint64_t session_id, const DumpProperties &dump_properties) { std::lock_guard lock(mutex_); dump_properties_map_.emplace(session_id, dump_properties); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpManager::RemoveDumpProperties(uint64_t session_id) { +void DumpManager::RemoveDumpProperties(uint64_t session_id) { std::lock_guard lock(mutex_); auto iter = dump_properties_map_.find(session_id); if (iter != dump_properties_map_.end()) { diff --git a/ge/common/dump/dump_manager.h b/ge/common/dump/dump_manager.h index fa96de93..69152bcf 100644 --- a/ge/common/dump/dump_manager.h +++ b/ge/common/dump/dump_manager.h @@ -20,7 +20,7 @@ #include #include "common/dump/dump_properties.h" -#include "common/ge_types.h" +#include "framework/common/ge_types.h" namespace ge { class DumpManager { diff --git a/ge/common/dump/dump_op.h b/ge/common/dump/dump_op.h index b664495a..73922cb3 100755 --- a/ge/common/dump/dump_op.h +++ b/ge/common/dump/dump_op.h @@ -19,7 +19,7 @@ #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/properties_manager.h" #include "proto/op_mapping.pb.h" #include "runtime/stream.h" diff --git a/ge/common/dump/dump_properties.cc b/ge/common/dump/dump_properties.cc index 08bddf43..3bed76d9 100644 --- a/ge/common/dump/dump_properties.cc +++ b/ge/common/dump/dump_properties.cc @@ -18,9 +18,10 @@ #include #include +#include #include "common/ge/ge_util.h" -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/ge_types.h" @@ -37,17 +38,186 @@ const uint32_t kAtomicOverflow = (0x1 << 1); const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); } // namespace namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { +void DumpProperties::Split(const std::string &s, std::vector &result, const char *delchar) { + if (s.empty()) { + return; + } + result.clear(); + + char *buffer = new (std::nothrow)char[s.size() + 1]; + if (buffer == nullptr) { + GELOGE(FAILED, "[Split][string] failed while malloc memory, string value is:%s", s.c_str()); + REPORT_CALL_ERROR("E19999", "Memory malloc may fail when split string, get fatal exception, " + "string value is:%s", s.c_str()); + return; + } + buffer[s.size()] = '\0'; + errno_t e = strcpy_s(buffer, s.size() + 1, s.c_str()); + if (e != EOK) { + delete[] buffer; + return; + } + char *context = nullptr; + char *p = strtok_s(buffer, delchar, &context); + while (p != nullptr) { + result.emplace_back(p); + p = strtok_s(nullptr, delchar, &context); + } + delete[] buffer; +} + +Status DumpProperties::CheckDumpStep(const std::string &dump_step) { 
+ std::string modified_dump_step = dump_step + "|"; + std::smatch result; + std::vector<std::string> match_vecs; + std::regex pattern(R"((\d{1,}-\d{1,}\||\d{1,}\|)+)"); + if (regex_match(modified_dump_step, result, pattern)) { + Split(result.str(), match_vecs, "|"); + if (match_vecs.empty()) { + REPORT_CALL_ERROR("E19999", "Split may get fatal exception, dump_step:%s.", dump_step.c_str()); + GELOGE(FAILED, "[Check][Param] failed. Split may get fatal exception, ge.exec.dumpStep:%s.", dump_step.c_str()); + return FAILED; + } + // 100 is the max sets of dump steps. + if (match_vecs.size() > 100) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.dumpStep", + dump_step.c_str(), + " is not supported, only dump <= 100 sets of data is supported"})); + GELOGE(PARAM_INVALID, "[Check][Param] get dump_step value:%s, " + "dump_step only supports dumping <= 100 sets of data.", dump_step.c_str()); + return PARAM_INVALID; + } + for (const auto &match_vec : match_vecs) { + std::vector<std::string> vec_after_split; + Split(match_vec, vec_after_split, "-"); + if (vec_after_split.empty()) { + REPORT_CALL_ERROR("E19999", "Split may get fatal exception."); + GELOGE(FAILED, "[Check][Param] failed, split may get fatal exception."); + return FAILED; + } + if (vec_after_split.size() > 1) { + if (std::atoi(vec_after_split[0].c_str()) >= std::atoi(vec_after_split[1].c_str())) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.dumpStep", + dump_step.c_str(), + " is not supported. " + "In range steps, the first step must be less than the second step, correct example:'0|5|10-20'"})); + GELOGE(PARAM_INVALID, "[Check][Param] get dump_step value:%s, " + "in range steps, the first step must be less than the second step, correct example:'0|5|10-20'", dump_step.c_str()); + return PARAM_INVALID; + } + } + } + } else { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.dumpStep", + dump_step.c_str(), + " is not supported, correct example:'0|5|10|50-100'."})); + GELOGE(PARAM_INVALID, "[Check][Param] get dump_step value:%s, " + "dump_step string style is invalid, correct example:'0|5|10|50-100'.", dump_step.c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status DumpProperties::CheckDumpMode(const std::string &dump_mode) { + const std::set<std::string> dump_mode_list = {"input", "output", "all"}; + std::set<std::string>::iterator iter; + + if ((iter = dump_mode_list.find(dump_mode)) == dump_mode_list.end()) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.dumpMode", + dump_mode.c_str(), + " is not supported, should be one of the following:[input, output, all]"})); + GELOGE(PARAM_INVALID, "[Check][Param] the dump_mode:%s, is not supported, " + "should be one of the following:[input, output, all].", dump_mode.c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status DumpProperties::CheckDumpPath(const std::string &input) { + if (mmIsDir(input.c_str()) != EN_OK) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.dumpPath", + input.c_str(), + " is not a directory."})); + GELOGE(PARAM_INVALID, "[Check][Param] the path:%s, is not a directory.", input.c_str()); + return PARAM_INVALID; + } + char trusted_path[MMPA_MAX_PATH] = { "\0" }; + if (mmRealPath(input.c_str(), trusted_path, MMPA_MAX_PATH) != EN_OK) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.dumpPath", +
input.c_str(), + " is an invalid dumpPath."})); + GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str()); + return PARAM_INVALID; + } + if (mmAccess2(trusted_path, M_R_OK | M_W_OK) != EN_OK) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.dumpPath", + input.c_str(), + " doesn't have read, write permissions."})); + GELOGE(PARAM_INVALID, "[Check][Param] the path:%s, doesn't have read, write permissions.", input.c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status DumpProperties::CheckEnableDump(const std::string &input) { + std::set<std::string> enable_dump_option_list = {"1", "0"}; + auto it = enable_dump_option_list.find(input); + if (it == enable_dump_option_list.end()) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.enableDump", + input.c_str(), + " only supports 1 or 0."})); + GELOGE(PARAM_INVALID, "[Check][Param] Unsupported ge.exec.enableDump or ge.exec.enableDumpDebug format:%s, " + "only 1 or 0 is supported.", input.c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +DumpProperties::DumpProperties(const DumpProperties &other) { CopyFrom(other); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( - const DumpProperties &other) { +DumpProperties &DumpProperties::operator=(const DumpProperties &other) { CopyFrom(other); return *this; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() { +Status DumpProperties::SetDumpOptions() { + if (enable_dump_ == kEnableFlag) { + std::string dump_step; + if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS && !dump_step.empty()) { + GE_CHK_STATUS_RET(CheckDumpStep(dump_step), "[Check][dump_step] failed."); + GELOGI("Get dump step %s successfully", dump_step.c_str()); + SetDumpStep(dump_step); + } + string dump_mode = "output"; + if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { + GELOGI("Get dump mode %s successfully", dump_mode.c_str()); + GE_CHK_STATUS_RET(CheckDumpMode(dump_mode), "[Check][dump_mode] failed."); + SetDumpMode(dump_mode); + } + AddPropertyValue(DUMP_ALL_MODEL, {}); + } + return SUCCESS; +} + +Status DumpProperties::InitByOptions() { enable_dump_.clear(); enable_dump_debug_.clear(); dump_path_.clear(); @@ -57,17 +227,32 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOpti is_infer_op_debug_ = false; op_debug_mode_ = 0; - std::string enable_dump; + std::string enable_dump = std::to_string(false); (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); enable_dump_ = enable_dump; + if (!enable_dump_.empty()) { + GE_CHK_STATUS_RET(CheckEnableDump(enable_dump_), "[Check][enable_dump] failed."); + } - std::string enable_dump_debug; + std::string enable_dump_debug = std::to_string(false); (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); enable_dump_debug_ = enable_dump_debug; - + if (!enable_dump_debug_.empty()) { + GE_CHK_STATUS_RET(CheckEnableDump(enable_dump_debug_), "[Check][enable_dump_debug] failed."); + } + if ((enable_dump_ == kEnableFlag) && (enable_dump_debug_ == kEnableFlag)) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ + "ge.exec.enableDump and ge.exec.enableDumpDebug", + enable_dump_ + ", " + enable_dump_debug, + "ge.exec.enableDump and ge.exec.enableDumpDebug cannot be set to 1 at the same time."})); + GELOGE(FAILED,
"ge.exec.enableDump and ge.exec.enableDumpDebug cannot be both set to 1 at the same time."); + return FAILED; + } if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { std::string dump_path; if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { + GE_CHK_STATUS_RET(CheckDumpPath(dump_path), "Check dump path failed."); if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { dump_path = dump_path + "/"; } @@ -75,30 +260,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOpti GELOGI("Get dump path %s successfully", dump_path.c_str()); SetDumpPath(dump_path); } else { - GELOGW("Dump path is not set"); + REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), + std::vector({ + "ge.exec.dumpPath", + dump_path, + "ge.exec.dumpPath is not set."})); + GELOGE(FAILED, "[Check][dump_path] failed. Dump path is not set."); + return FAILED; } } - if (enable_dump_ == kEnableFlag) { - std::string dump_step; - if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { - GELOGI("Get dump step %s successfully", dump_step.c_str()); - SetDumpStep(dump_step); - } - string dump_mode; - if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { - GELOGI("Get dump mode %s successfully", dump_mode.c_str()); - SetDumpMode(dump_mode); - } - AddPropertyValue(DUMP_ALL_MODEL, {}); - } + GE_CHK_STATUS_RET(SetDumpOptions(), "SetDumpOptions failed."); + + GE_CHK_STATUS_RET(SetDumpDebugOptions(), "SetDumpDebugOptions failed."); - SetDumpDebugOptions(); + return SUCCESS; } // The following is the new dump scenario of the fusion operator -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( - const std::string &model, const std::set &layers) { +void DumpProperties::AddPropertyValue(const std::string &model, const std::set &layers) { for (const std::string &layer : layers) { GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); } @@ -106,18 +286,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropert model_dump_properties_map_[model] = layers; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { +void DumpProperties::DeletePropertyValue(const std::string &model) { auto iter = model_dump_properties_map_.find(model); if (iter != model_dump_properties_map_.end()) { model_dump_properties_map_.erase(iter); } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpPropertyValue() { +void DumpProperties::ClearDumpPropertyValue() { model_dump_properties_map_.clear(); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpInfo() { +void DumpProperties::ClearDumpInfo() { enable_dump_.clear(); enable_dump_debug_.clear(); dump_path_.clear(); @@ -130,7 +310,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpI op_debug_mode_ = 0; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetAllDumpModel() const { +std::set DumpProperties::GetAllDumpModel() const { std::set model_list; for (auto &iter : model_dump_properties_map_) { model_list.insert(iter.first); @@ -139,8 +319,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpPrope return model_list; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetPropertyValue( - const std::string &model) const { +std::set 
DumpProperties::GetPropertyValue(const std::string &model) const { auto iter = model_dump_properties_map_.find(model); if (iter != model_dump_properties_map_.end()) { return iter->second; @@ -148,8 +327,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpPrope return {}; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( - const std::string &model, const std::string &om_name, const std::string &op_name) const { +bool DumpProperties::IsLayerNeedDump(const std::string &model, const std::string &om_name, + const std::string &op_name) const { // if dump all GELOGD("model name is %s om name is %s op is %s in layer need dump", model.c_str(), om_name.c_str(), op_name.c_str()); if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { @@ -169,67 +348,66 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNee return model_iter->second.find(op_name) != model_iter->second.end(); } - GELOGD("Model %s is not seated to be dump.", model.c_str()); + GELOGD("Model %s is not seated to be dump", model.c_str()); return false; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { +void DumpProperties::SetDumpPath(const std::string &path) { dump_path_ = path; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpPath() const { +const std::string &DumpProperties::GetDumpPath() const { return dump_path_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { +void DumpProperties::SetDumpStep(const std::string &step) { dump_step_ = step; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpStep() const { +const std::string &DumpProperties::GetDumpStep() const { return dump_step_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { +void DumpProperties::SetDumpMode(const std::string &mode) { dump_mode_ = mode; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpMode() const { +const std::string &DumpProperties::GetDumpMode() const { return dump_mode_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStatus(const std::string &status) { +void DumpProperties::SetDumpStatus(const std::string &status) { dump_status_ = status; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpStatus() const { +const std::string &DumpProperties::GetDumpStatus() const { return dump_status_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitInferOpDebug() { +void DumpProperties::InitInferOpDebug() { is_infer_op_debug_ = true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetOpDebugMode(const uint32_t &op_debug_mode) { +void DumpProperties::SetOpDebugMode(const uint32_t &op_debug_mode) { op_debug_mode_ = op_debug_mode; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpOpSwitch( - const std::string &dump_op_switch) { +void DumpProperties::SetDumpOpSwitch(const std::string &dump_op_switch) { dump_op_switch_ = dump_op_switch; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpOpSwitch() const { +const std::string &DumpProperties::GetDumpOpSwitch() const { return dump_op_switch_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool 
DumpProperties::IsSingleOpNeedDump() const { +bool DumpProperties::IsSingleOpNeedDump() const { if (dump_op_switch_ == kDumpStatusOpen) { return true; } return false; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsDumpOpen() const { +bool DumpProperties::IsDumpOpen() const { if (enable_dump_ == kEnableFlag || dump_status_ == kDumpStatusOpen) { return true; } @@ -253,14 +431,14 @@ void DumpProperties::CopyFrom(const DumpProperties &other) { } } -void DumpProperties::SetDumpDebugOptions() { +Status DumpProperties::SetDumpDebugOptions() { if (enable_dump_debug_ == kEnableFlag) { std::string dump_debug_mode; if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { - GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); + GELOGD("Get ge.exec.dumpDebugMode %s successfully.", dump_debug_mode.c_str()); } else { - GELOGW("Dump debug mode is not set."); - return; + GELOGW("ge.exec.dumpDebugMode is not set."); + return SUCCESS; } if (dump_debug_mode == OP_DEBUG_AICORE) { @@ -276,10 +454,17 @@ void DumpProperties::SetDumpDebugOptions() { is_train_op_debug_ = true; op_debug_mode_ = kAllOverflow; } else { - GELOGW("ge.exec.dumpDebugMode is invalid."); + REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), + std::vector({ + "ge.exec.dumpDebugMode", + dump_debug_mode, + "ge.exec.dumpDebugMode is invalid."})); + GELOGE(PARAM_INVALID, "[Set][DumpDebugOptions] failed, ge.exec.dumpDebugMode is invalid."); + return PARAM_INVALID; } } else { - GELOGI("ge.exec.enableDumpDebug is false or is not set."); + GELOGI("ge.exec.enableDumpDebug is false or is not set"); } + return SUCCESS; } } // namespace ge diff --git a/ge/common/dump/dump_properties.h b/ge/common/dump/dump_properties.h index 98487491..cbfc362d 100644 --- a/ge/common/dump/dump_properties.h +++ b/ge/common/dump/dump_properties.h @@ -23,6 +23,7 @@ #include namespace ge { +using Status = uint32_t; class DumpProperties { public: DumpProperties() = default; @@ -33,7 +34,7 @@ class DumpProperties { DumpProperties &operator=(const DumpProperties &dump); - void InitByOptions(); + Status InitByOptions(); void AddPropertyValue(const std::string &model, const std::set &layers); @@ -95,7 +96,20 @@ class DumpProperties { private: void CopyFrom(const DumpProperties &other); - void SetDumpDebugOptions(); + Status SetDumpDebugOptions(); + + Status SetDumpOptions(); + + void Split(const std::string &s, std::vector &result, const char *delchar); + + Status CheckDumpStep(const std::string &dump_step); + + Status CheckDumpMode(const std::string &dump_mode); + + Status CheckDumpPath(const std::string &input); + + Status CheckEnableDump(const std::string &input); + std::string enable_dump_; std::string enable_dump_debug_; diff --git a/ge/common/dump/exception_dumper.cc b/ge/common/dump/exception_dumper.cc index c8ec3d35..c41da551 100644 --- a/ge/common/dump/exception_dumper.cc +++ b/ge/common/dump/exception_dumper.cc @@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector &ex uint64_t proto_size = dump_data.ByteSizeLong(); std::unique_ptr proto_msg(new (std::nothrow) char[proto_size]); + GE_CHECK_NOTNULL(proto_msg); bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); if (!ret || proto_size == 0) { REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); diff --git a/ge/common/dump/opdebug_register.cc b/ge/common/dump/opdebug_register.cc index 816455a0..41c85ff6 100644 --- a/ge/common/dump/opdebug_register.cc +++ 
b/ge/common/dump/opdebug_register.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "opdebug_register.h" +#include "common/dump/opdebug_register.h" namespace { const size_t kOpDebugMemorySize = 2048UL; diff --git a/ge/common/dump/opdebug_register.h b/ge/common/dump/opdebug_register.h index 1826287d..5b927b67 100644 --- a/ge/common/dump/opdebug_register.h +++ b/ge/common/dump/opdebug_register.h @@ -18,8 +18,8 @@ #define GE_COMMON_DUMP_OPDEBUG_REGISTER_H_ #include -#include "common/debug/ge_log.h" -#include "common/debug/log.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" #include "graph/load/model_manager/data_dumper.h" namespace ge { diff --git a/ge/common/executor.h b/ge/common/executor.h new file mode 100644 index 00000000..7f1d7ef9 --- /dev/null +++ b/ge/common/executor.h @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GE_COMMON_EXECUTOR_H +#define GE_COMMON_EXECUTOR_H + +#include "external/ge/ge_api_types.h" +#include "graph/ge_local_context.h" +#include "graph/manager/graph_manager_utils.h" + +namespace ge { +struct RunArgs { + GraphNodePtr graph_node; + GraphId graph_id; + uint64_t session_id; + struct error_message::Context error_context; + std::vector input_tensor; + GeRootModelPtr ge_root_model; + GEThreadLocalContext context; + RunAsyncCallback callback; +}; + +class Executor { + public: + /// + /// @ingroup ge + /// @brief Load model from graph. + /// @param [in] GeRootModel: root model of graph compiled. + /// @param [in] GraphNode: node of graph. + /// @return Status result of function + /// + virtual Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) = 0; + + /// + /// @ingroup ge + /// @brief Unload model. + /// @param [in] GeRootModel: root model of graph compiled. + /// @param [in] graph_id: graph identifier. + /// @return Status result of function + /// + virtual Status UnloadGraph(const GeRootModelPtr &ge_root_model, uint32_t graph_id) = 0; + + /// + /// @ingroup ge + /// @brief Push model execution params to queue. + /// @param [in] RunArgs for model execution. + /// @return Status result of function + /// + virtual Status PushGraph(const RunArgs &args) = 0; + + /// + /// @ingroup ge + /// @brief Run graph in synchronous mode. + /// @param [in] graph_node: node of graph. + /// @param [in] graph_id: graph identifier. + /// @param [in] inputs: input data for the graph running. + /// @param [out] outputs: output data of the graph running + /// @return Status result of function + /// + virtual Status RunGraph(const GraphNodePtr &graph_node, GraphId graph_id, + const std::vector &inputs, std::vector &outputs) = 0; + + /// + /// @ingroup ge + /// @brief Run graph in synchronous mode with a specified stream. + /// @param [in] graph_node: node of graph. + /// @param [in] graph_id: graph identifier. + /// @param [in] stream: Stream for model running. + /// @param [in] inputs: input data for the graph running.
+ /// @param [out] outputs: output data of the graph running + /// @return Status result of function + /// + virtual Status RunGraphWithStream(const GraphNodePtr &graph_node, GraphId graph_id, rtStream_t stream, + const std::vector &inputs, std::vector &outputs) = 0; +}; +} +#endif // GE_COMMON_EXECUTOR_H diff --git a/ge/common/fmk_error_codes.cc b/ge/common/fmk_error_codes.cc index ddb8089d..180af0e2 100755 --- a/ge/common/fmk_error_codes.cc +++ b/ge/common/fmk_error_codes.cc @@ -17,19 +17,18 @@ #include "framework/common/fmk_error_codes.h" namespace domi { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY StatusFactory *StatusFactory::Instance() { +StatusFactory *StatusFactory::Instance() { static StatusFactory instance; return &instance; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void StatusFactory::RegisterErrorNo(uint32_t err, - const std::string &desc) { +void StatusFactory::RegisterErrorNo(uint32_t err, const std::string &desc) { if (err_desc_.find(err) != err_desc_.end()) { return; } err_desc_[err] = desc; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string StatusFactory::GetErrDesc(uint32_t err) { +std::string StatusFactory::GetErrDesc(uint32_t err) { auto iter_find = err_desc_.find(err); if (iter_find == err_desc_.end()) { return ""; diff --git a/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc index ce271c6d..aae95584 100644 --- a/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc +++ b/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc @@ -123,6 +123,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index 24be6023..4f597e32 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { return CheckShapeValid(shape, kDimSize4D); default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + - " and FORMAT_FRACTAL_NZ is not supported."; + " and FORMAT_FRACTAL_NZ is not supported."; GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } @@ -185,6 +185,7 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con auto src_offset = (src_h_head + w1_idx * w0) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size * w0)); if (ret != EOK) { @@ -202,6 +203,7 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con auto src_offset = (src_h_head + src_w_idx) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? 
dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { @@ -267,6 +269,7 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con auto dst_offset = (dst_h_head + w1_idx * w0) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size * w0)); if (ret != EOK) { @@ -285,6 +288,7 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con auto dst_offset = (dst_h_head + dst_w_idx) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index ddce348b..882a2a68 100644 --- a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -19,7 +19,7 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/formats/utils/formats_definitions.h" #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" @@ -226,6 +226,7 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); errno_t ret = EOK; if (need_pad_zero) { ret = memset_s(dst.get() + offset, static_cast(protected_size), 0, static_cast(size)); @@ -390,6 +391,7 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); errno_t ret = EOK; if (pad_zero) { @@ -474,6 +476,7 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); errno_t ret = EOK; if (pad_zero) { diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc index 1cb142b3..14315084 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc @@ -193,6 +193,7 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? 
dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size * w0)); if (ret != EOK) { @@ -213,6 +214,7 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { @@ -284,6 +286,7 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size * w0)); if (ret != EOK) { @@ -304,6 +307,7 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc old mode 100755 new mode 100644 index f6af7534..ed3a062c --- a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc @@ -17,6 +17,7 @@ #include "common/formats/format_transfers/format_transfer_fracz_hwcn.h" #include + #include #include "common/formats/utils/formats_definitions.h" @@ -35,8 +36,8 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { auto dst_shape = args.dst_shape; if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { std::string error = "Dose not support trans format from " + - FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + - FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); + FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + + FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return ACL_ERROR_GE_FORMAT_INVALID; } @@ -52,15 +53,13 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { if (!CheckShapeValid(src_shape, kFracZDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", ShapeToString(src_shape).c_str()); - REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", - ShapeToString(src_shape).c_str()); + REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", ShapeToString(src_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", ShapeToString(dst_shape).c_str()); - REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", - ShapeToString(dst_shape).c_str()); + REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(dst_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } int64_t c0 = GetCubeSizeByDataType(args.src_data_type); @@ -71,9 +70,8 @@ Status 
CheckArgsForFracZToHwcn(const TransArgs &args) { int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast(kNiSize)); if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { - std::string error = "Failed to check relationship between src shape" + - FmtToStr(ShapeToString(src_shape)) + " and dst shape" + - FmtToStr(ShapeToString(dst_shape)); + std::string error = "Failed to check relationship between src shape" + FmtToStr(ShapeToString(src_shape)) + + " and dst shape" + FmtToStr(ShapeToString(dst_shape)); GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } @@ -128,6 +126,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto dst_offset = dst_idx * size; auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc index aaeca490..58073397 100755 --- a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc @@ -130,6 +130,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto dst_offset = dst_idx * size; auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc index 1e71ea09..3122f137 100755 --- a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc @@ -128,6 +128,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size auto dst_offset = dst_idx * size; auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc index cb7f889b..c597cde0 100755 --- a/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc @@ -149,6 +149,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? 
total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); int64_t c_idx = c0_idx + c1_idx * c0; int64_t src_idx = h_idx * wcn + w_idx * cn + c_idx * n + n_idx; auto src_offset = src_idx * size; diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc index 09ff45d9..c442bee9 100755 --- a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc +++ b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc @@ -129,6 +129,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc index e9e41cd1..603ddffa 100755 --- a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc +++ b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc @@ -129,6 +129,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc b/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc index 5efe486c..88de4d14 100644 --- a/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc +++ b/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc @@ -23,7 +23,7 @@ #include "common/formats/utils/formats_definitions.h" #include "common/formats/utils/formats_trans_utils.h" -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/type_utils.h" diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc index ea2b1d7f..5cab311d 100755 --- a/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc @@ -144,6 +144,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); int64_t cIdx = c0_idx + c1_idx * c0; int64_t srcIdx = n_idx * chw + cIdx * hw + h_idx * w + w_idx; auto src_offset = srcIdx * size; diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc index 518790b6..939c967c 100755 --- a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc @@ -149,6 +149,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in auto protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) ? 
total_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); int64_t c_idx = c0_idx + c1_idx * c0; int64_t src_idx = n_idx * hwc + h_idx * wc + w_idx * c + c_idx; auto src_offset = src_idx * size; diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.cc b/ge/common/formats/format_transfers/format_transfer_transpose.cc index 54c5444b..9a4d3fd6 100755 --- a/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -171,6 +171,7 @@ Status Transpose(const uint8_t *src, const std::vector &src_shape, Data auto protected_size = dst_size - dst_offset_bytes < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - dst_offset_bytes : static_cast(SECUREC_MEM_MAX_LEN); + GE_CHECK_GE(protected_size, 0); auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast(protected_size), src + src_offset, static_cast(data_size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.h b/ge/common/formats/format_transfers/format_transfer_transpose.h index 7fa19ff0..b608777c 100755 --- a/ge/common/formats/format_transfers/format_transfer_transpose.h +++ b/ge/common/formats/format_transfers/format_transfer_transpose.h @@ -33,7 +33,6 @@ Status TransposeWithShapeCheck(const uint8_t *src, const std::vector &s Status GetPermByForamt(Format src_format, Format dst_format, std::vector &perm); - class FormatTransferTranspose : public FormatTransfer { public: Status TransFormat(const TransArgs &args, TransResult &result) override; diff --git a/ge/common/formats/formats.cc b/ge/common/formats/formats.cc index 9e97a4d2..5a454d60 100755 --- a/ge/common/formats/formats.cc +++ b/ge/common/formats/formats.cc @@ -17,6 +17,7 @@ #include "common/formats/formats.h" #include + #include #include #include @@ -32,7 +33,7 @@ namespace ge { namespace formats { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { +Status TransFormat(const TransArgs &args, TransResult &result) { auto transfer = BuildFormatTransfer(args); if (transfer == nullptr) { std::string error = "Failed to trans data from format " + @@ -56,11 +57,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg return transfer->TransFormat(args, result); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_format, - const std::vector &src_shape, - DataType data_type, - Format dst_format, - std::vector &dst_shape) { +Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape) { formats::TransArgs args; args.src_format = src_format; args.dst_format = dst_format; @@ -76,7 +74,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form return transfer->TransShape(src_format, src_shape, data_type, dst_format, dst_shape); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { +Status TransDataType(const CastArgs &args, TransResult &result) { auto transfer = BuildDataTypeTransfer(args); if (transfer == nullptr) { std::string error = "Failed to trans data from datatype " + @@ -95,11 +93,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr return transfer->TransDataType(args, result); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransFormatSupport(const TransArgs &args) { +bool IsTransFormatSupport(const TransArgs 
&args) { return FormatTransferExists(args); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransDataTypeSupport(const CastArgs &args) { +bool IsTransDataTypeSupport(const CastArgs &args) { return DataTypeTransferExists(args); } } // namespace formats diff --git a/ge/common/formats/utils/formats_trans_utils.cc b/ge/common/formats/utils/formats_trans_utils.cc index 052951ce..63ad424f 100755 --- a/ge/common/formats/utils/formats_trans_utils.cc +++ b/ge/common/formats/utils/formats_trans_utils.cc @@ -41,14 +41,32 @@ int64_t GetCubeSizeByDataType(DataType data_type) { } } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const GeShape &shape) { +std::string ShapeToString(const GeShape &shape) { return ShapeToString(shape.GetDims()); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const std::vector &shape) { +std::string ShapeToString(const std::vector &shape) { return JoinToString(shape); } +std::string RangeToString(const std::vector> &ranges) { + bool first = true; + std::stringstream ss; + ss << "["; + for (const auto &range : ranges) { + if (first) { + first = false; + } else { + ss << ","; + } + ss << "{"; + ss << range.first << "," << range.second; + ss << "}"; + } + ss << "]"; + return ss.str(); +} + int64_t GetItemNumByShape(const std::vector &shape) { int64_t num = 1; for (auto dim : shape) { diff --git a/ge/common/formats/utils/formats_trans_utils.h b/ge/common/formats/utils/formats_trans_utils.h index 848e8b3a..64f9f820 100755 --- a/ge/common/formats/utils/formats_trans_utils.h +++ b/ge/common/formats/utils/formats_trans_utils.h @@ -54,6 +54,8 @@ std::string ShapeToString(const GeShape &shape); std::string ShapeToString(const std::vector &shape); +std::string RangeToString(const std::vector> &ranges); + int64_t GetItemNumByShape(const std::vector &shape); bool CheckShapeValid(const std::vector &shape, const int64_t expect_dims); diff --git a/ge/common/fp16_t.cc b/ge/common/fp16_t.cc index 2f94323d..adb55dfb 100755 --- a/ge/common/fp16_t.cc +++ b/ge/common/fp16_t.cc @@ -1180,20 +1180,40 @@ fp16_t &fp16_t::operator=(const double &d_val) { } // convert -fp16_t::operator float() const { return Fp16ToFloat(val); } -fp16_t::operator double() const { return Fp16ToDouble(val); } -fp16_t::operator int8_t() const { return Fp16ToInt8(val); } -fp16_t::operator uint8_t() const { return Fp16ToUInt8(val); } -fp16_t::operator int16_t() const { return Fp16ToInt16(val); } -fp16_t::operator uint16_t() const { return Fp16ToUInt16(val); } -fp16_t::operator int32_t() const { return Fp16ToInt32(val); } -fp16_t::operator uint32_t() const { return Fp16ToUInt32(val); } +fp16_t::operator float() const { + return Fp16ToFloat(val); +} +fp16_t::operator double() const { + return Fp16ToDouble(val); +} +fp16_t::operator int8_t() const { + return Fp16ToInt8(val); +} +fp16_t::operator uint8_t() const { + return Fp16ToUInt8(val); +} +fp16_t::operator int16_t() const { + return Fp16ToInt16(val); +} +fp16_t::operator uint16_t() const { + return Fp16ToUInt16(val); +} +fp16_t::operator int32_t() const { + return Fp16ToInt32(val); +} +fp16_t::operator uint32_t() const { + return Fp16ToUInt32(val); +} // Cannot be used, just in order to solve the compile error -fp16_t::operator int64_t() const { return 0; } +fp16_t::operator int64_t() const { + return 0; +} // Cannot be used, just in order to solve the compile error -fp16_t::operator uint64_t() const { return 0; } +fp16_t::operator uint64_t() const { + return 0; +} -FMK_FUNC_HOST_VISIBILITY 
diff --git a/ge/common/fp16_t.cc b/ge/common/fp16_t.cc
index 2f94323d..adb55dfb 100755
--- a/ge/common/fp16_t.cc
+++ b/ge/common/fp16_t.cc
@@ -1180,20 +1180,40 @@ fp16_t &fp16_t::operator=(const double &d_val) {
 }
 // convert
-fp16_t::operator float() const { return Fp16ToFloat(val); }
-fp16_t::operator double() const { return Fp16ToDouble(val); }
-fp16_t::operator int8_t() const { return Fp16ToInt8(val); }
-fp16_t::operator uint8_t() const { return Fp16ToUInt8(val); }
-fp16_t::operator int16_t() const { return Fp16ToInt16(val); }
-fp16_t::operator uint16_t() const { return Fp16ToUInt16(val); }
-fp16_t::operator int32_t() const { return Fp16ToInt32(val); }
-fp16_t::operator uint32_t() const { return Fp16ToUInt32(val); }
+fp16_t::operator float() const {
+  return Fp16ToFloat(val);
+}
+fp16_t::operator double() const {
+  return Fp16ToDouble(val);
+}
+fp16_t::operator int8_t() const {
+  return Fp16ToInt8(val);
+}
+fp16_t::operator uint8_t() const {
+  return Fp16ToUInt8(val);
+}
+fp16_t::operator int16_t() const {
+  return Fp16ToInt16(val);
+}
+fp16_t::operator uint16_t() const {
+  return Fp16ToUInt16(val);
+}
+fp16_t::operator int32_t() const {
+  return Fp16ToInt32(val);
+}
+fp16_t::operator uint32_t() const {
+  return Fp16ToUInt32(val);
+}
 // Cannot be used, just in order to solve the compile error
-fp16_t::operator int64_t() const { return 0; }
+fp16_t::operator int64_t() const {
+  return 0;
+}
 // Cannot be used, just in order to solve the compile error
-fp16_t::operator uint64_t() const { return 0; }
+fp16_t::operator uint64_t() const {
+  return 0;
+}
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() {
+int fp16_t::IsInf() {
   if ((val & kFp16AbsMax) == kFp16ExpMask) {
     if (val & kFp16SignMask) {
       return -1;
@@ -1205,12 +1225,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() {
   }
 }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY float fp16_t::ToFloat() const { return Fp16ToFloat(val); }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY double fp16_t::ToDouble() const { return Fp16ToDouble(val); }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int8_t fp16_t::ToInt8() const { return Fp16ToInt8(val); }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint8_t fp16_t::ToUInt8() const { return Fp16ToUInt8(val); }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int16_t fp16_t::ToInt16() const { return Fp16ToInt16(val); }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint16_t fp16_t::ToUInt16() const { return Fp16ToUInt16(val); }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int32_t fp16_t::ToInt32() const { return Fp16ToInt32(val); }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t fp16_t::ToUInt32() const { return Fp16ToUInt32(val); }
+float fp16_t::ToFloat() const {
+  return Fp16ToFloat(val);
+}
+double fp16_t::ToDouble() const {
+  return Fp16ToDouble(val);
+}
+int8_t fp16_t::ToInt8() const {
+  return Fp16ToInt8(val);
+}
+uint8_t fp16_t::ToUInt8() const {
+  return Fp16ToUInt8(val);
+}
+int16_t fp16_t::ToInt16() const {
+  return Fp16ToInt16(val);
+}
+uint16_t fp16_t::ToUInt16() const {
+  return Fp16ToUInt16(val);
+}
+int32_t fp16_t::ToInt32() const {
+  return Fp16ToInt32(val);
+}
+uint32_t fp16_t::ToUInt32() const {
+  return Fp16ToUInt32(val);
+}
 } // namespace ge
diff --git a/ge/common/ge/datatype_util.h b/ge/common/ge/datatype_util.h
index e42b25a7..82c8d259 100644
--- a/ge/common/ge/datatype_util.h
+++ b/ge/common/ge/datatype_util.h
@@ -20,7 +20,7 @@
 #include
 #include
-#include "graph/types.h"
+#include "external/graph/types.h"
 namespace ge {
 static const int32_t kGeSizeFloat = sizeof(float);
@@ -42,7 +42,7 @@ static std::map<ge::DataType, int32_t> CONST_OPDATA_TYPE_SIZE_MAP = {
     {ge::DT_UINT8, kGeSizeUint8}, {ge::DT_UINT16, kGeSizeUint16}, {ge::DT_UINT32, kGeSizeUint32},
     {ge::DT_UINT64, kGeSizeUint64}, {ge::DT_DOUBLE, kGeSizeDouble}, {ge::DT_BOOL, kGeSizeBool}};
-class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY DataTypeUtil {
+class DataTypeUtil {
  public:
   static bool DataTypeTranslatable(const ge::DataType &src_out_data_type, const ge::DataType &dst_in_data_type);
   static const std::vector<ge::DataType> &GetTranslatableDataTypesBySrc(const ge::DataType &src_out_data_type);
diff --git a/ge/common/ge/plugin_manager.h b/ge/common/ge/plugin_manager.h
index 8c351e62..0869704f 100755
--- a/ge/common/ge/plugin_manager.h
+++ b/ge/common/ge/plugin_manager.h
@@ -26,8 +26,8 @@
 #include
 #include
-#include "common/ge_inner_error_codes.h"
-#include "engine/dnnengine.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "framework/engine/dnnengine.h"
 #include "framework/common/debug/ge_log.h"
 #include "mmpa/mmpa_api.h"
diff --git a/ge/common/ge/tbe_plugin_manager.cc b/ge/common/ge/tbe_plugin_manager.cc
index 94ba8a9a..3680a8bb 100755
--- a/ge/common/ge/tbe_plugin_manager.cc
+++ b/ge/common/ge/tbe_plugin_manager.cc
@@ -42,7 +42,7 @@ const int kBaseInt = 10;
 std::map<string, string> TBEPluginManager::options_ = {};
 // Get Singleton Instance
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginManager &TBEPluginManager::Instance() {
+TBEPluginManager &TBEPluginManager::Instance() {
   static TBEPluginManager instance_ptr_;
   return instance_ptr_;
 }
@@ -61,7 +61,7 @@ Status TBEPluginManager::ClearHandles_() {
   return ret;
 }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginManager::Finalize() {
+Status TBEPluginManager::Finalize() {
   Status ret = ClearHandles_();
   return ret;
 }
@@ -104,7 +104,15 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff
   }
 }
-void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) {
+void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list,
+                                    string &caffe_parser_path, uint32_t recursive_depth) {
+  static const uint32_t max_recursive_depth = 20;  // For recursive depth protection
+
+  if (recursive_depth >= max_recursive_depth) {
+    GELOGW("Recursive depth has reached %u, please check input!", recursive_depth);
+    return;
+  }
+
   // Path, change to absolute path
   string real_path = RealPath(path.c_str());
   // Plugin path does not exist
@@ -138,7 +146,7 @@ void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_lis
       ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff,
                         aicpu_host_so_suff);
     } else {
-      FindParserSo(full_name, file_list, caffe_parser_path);
+      FindParserSo(full_name, file_list, caffe_parser_path, recursive_depth + 1);
     }
   }
   mmScandirFree(entries, ret);
@@ -199,7 +207,6 @@ void TBEPluginManager::LoadCustomOpLib() {
   }
 }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY
 void TBEPluginManager::LoadPluginSo(const std::map<string, string> &options) {
   vector<string> file_list;
   string caffe_parser_path;
@@ -238,7 +245,6 @@ void TBEPluginManager::LoadPluginSo(const std::map<string, string> &options) {
   }
 }
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY
 void TBEPluginManager::InitPreparation(const std::map<string, string> &options) {
   options_.insert(options.begin(), options.end());
   // Load TBE plugin
diff --git a/ge/common/ge/tbe_plugin_manager.h b/ge/common/ge/tbe_plugin_manager.h
index 4bd8c6e3..eada3e64 100755
--- a/ge/common/ge/tbe_plugin_manager.h
+++ b/ge/common/ge/tbe_plugin_manager.h
@@ -57,7 +57,8 @@ class TBEPluginManager {
   static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
                                 const string &caffe_parser_so_suff, const string &aicpu_so_suff,
                                 const string &aicpu_host_so_suff);
-  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);
+  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path,
+                           uint32_t recursive_depth = 0);
   static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
   static void GetCustomOpPath(std::string &customop_path);
   void LoadCustomOpLib();
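FindParserSo now carries an explicit depth counter so that a cyclic or pathologically deep plugin directory tree cannot overflow the stack. A stand-alone sketch of the same pattern, using std::filesystem instead of the mmpa mmScandir wrappers (CollectSoFiles is a hypothetical name, not a GE API):

// Depth-guarded recursive directory scan, mirroring the guard added above.
#include <cstdint>
#include <filesystem>
#include <string>
#include <system_error>
#include <vector>

namespace fs = std::filesystem;

void CollectSoFiles(const fs::path &dir, std::vector<std::string> &file_list, uint32_t depth = 0) {
  static const uint32_t kMaxRecursiveDepth = 20;  // same bound as the diff above
  if (depth >= kMaxRecursiveDepth) {
    return;  // refuse to recurse further instead of risking stack overflow
  }
  std::error_code ec;
  for (const auto &entry : fs::directory_iterator(dir, ec)) {
    if (entry.is_directory()) {
      CollectSoFiles(entry.path(), file_list, depth + 1);  // bump the counter per level
    } else if (entry.path().extension() == ".so") {
      file_list.push_back(entry.path().string());
    }
  }
}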
diff --git a/ge/graph/common/ge_call_wrapper.h b/ge/common/ge_call_wrapper.h
similarity index 100%
rename from ge/graph/common/ge_call_wrapper.h
rename to ge/common/ge_call_wrapper.h
diff --git a/ge/common/ge_format_util.cc b/ge/common/ge_format_util.cc
index d0240224..0ffa686f 100755
--- a/ge/common/ge_format_util.cc
+++ b/ge/common/ge_format_util.cc
@@ -15,12 +15,10 @@
  */
 #include "framework/common/ge_format_util.h"
-#include "formats/formats.h"
+#include "common/formats/formats.h"
 
 namespace ge {
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status GeFormatUtil::TransShape(const TensorDesc &src_desc,
-                                                                               Format dst_format,
-                                                                               std::vector<int64_t> &dst_shape) {
+Status GeFormatUtil::TransShape(const TensorDesc &src_desc, Format dst_format, std::vector<int64_t> &dst_shape) {
   return formats::TransShape(src_desc.GetFormat(), src_desc.GetShape().GetDims(), src_desc.GetDataType(), dst_format,
                              dst_shape);
 }
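GeFormatUtil::TransShape simply forwards to formats::TransShape with the source tensor's format, dims, and data type. A hypothetical caller (the NC1HWC0 expectation follows the usual padding rule C1 = ceil(C / C0) with C0 = 16 for FP16; treat the exact values as illustrative, not output captured from GE):

// Hypothetical use of the wrapper reformatted above.
#include <cstdint>
#include <vector>
#include "framework/common/ge_format_util.h"

void Example() {
  // 1x3x224x224 FP16 tensor described in NCHW.
  ge::TensorDesc src_desc(ge::Shape({1, 3, 224, 224}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
  std::vector<int64_t> dst_shape;
  if (ge::GeFormatUtil::TransShape(src_desc, ge::FORMAT_NC1HWC0, dst_shape) == ge::SUCCESS) {
    // Expect dst_shape == {1, 1, 224, 224, 16}: C = 3 pads up to one C0 = 16 block.
  }
}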
a/ge/common/helper/model_cache_helper.cc b/ge/common/helper/model_cache_helper.cc
deleted file mode 100755
index 9cd88ef1..00000000
--- a/ge/common/helper/model_cache_helper.cc
+++ /dev/null
@@ -1,1714 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common/helper/model_cache_helper.h"
-
-#include
-#include
-#include
-
-#include "common/model_parser/model_parser.h"
-#include "framework/common/helper/model_helper.h"
-#include "graph/detail/model_serialize_imp.h"
-#include "graph/utils/graph_utils.h"
-#include "graph/utils/tensor_utils.h"
-#include "init/gelib.h"
-#include "proto/ge_ir.pb.h"
-
-using namespace std;
-
-namespace {
-const char *const kTbeKernelInfoStoreName = "AIcoreEngine";
-const char *const kGraphName = "temp_name";
-// Keys of json
-const char *const kNodeNum = "nodeNum";
-const char *const kEdgeNum = "edgeNum";
-const char *const kGraphHash = "graphHash";
-const char *const kNodeHash = "nodeHash";
-const char *const kHash = "hash";
-const char *const kSessionId = "sessionId";
-const char *const kDeviceId = "deviceId";
-const char *const kJobId = "jobId";
-const char *const kGraphMemMaxSize = "graphMemMaxSize";
-const char *const kVarMemMaxSize = "varMemMaxSize";
-const char *const kVarMemLogicBase = "varMemLogicBase";
-const char *const kUseMaxMemSize = "useMaxMemSize";
-const char *const kMemResourceMap = "memResourceMap";
-const char *const kMemType = "memType";
-const char *const kTotalSize = "totalSize";
-const char *const kVarMemSize = "varMemSize";
-const char *const kVarResource = "varResource";
-const char *const kVarAddrMgrMap = "varAddrMgrMap";
-const char *const kName = "name";
-const char *const kAddress = "address";
-const char *const kOffset = "offset";
-const char *const kMemoryType = "memoryType";
-const char *const kTensorDesc = "tensorDesc";
-const char *const kDataType = "dataType";
-const char *const kShape = "shape";
-const char *const kLayout = "layout";
-const char *const kOriginDataType = "originDataType";
-const char *const kOriginShape = "originShape";
-const char *const kOriginLayout = "originLayout";
-const char *const kRealDimCnt = "realDimCnt";
-const char *const kCurVarTensorDescMap = "curVarTensorDescMap";
-const char *const kTransRoads = "transRoads";
-const char *const kTransRoad = "transRoad";
-const char *const kNodeType = "nodeType";
-const char *const kInputTensorDesc = "inputTensorDesc";
-const char *const kOutputTensorDesc = "outputTensorDesc";
-const char *const kChangedGraphId = "changedGraphId";
-const char *const kAllocatedGraphId = "allocatedGraphId";
-const char *const kGraphId = "graphId";
-const char *const kVarBroadcastInfo = "varBroadcastInfo";
-const char *const kBroadcastName = "broadcastName";
-const char *const kIdx = "idx";
-const char *const kInputOffset = "inputOffset";
-const char *const kInputSize = "inputSize";
-const char *const kOutputOffset = "outputOffset";
-const char *const kOutputSize = "outputSize";
-// Suffix of
cache files -const char *const kBeforeVarManagerSuffix = "_before_build_var_manager.json"; -const char *const kAfterVarManagerSuffix = "_after_build_var_manager.json"; -const char *const kManifestSuffix = ".manifest"; -const char *const kOmSuffix = ".om"; -} // namespace - -namespace ge { -map ModelCacheHelper::graph_id_run_times_; -ModelCacheHelper::ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph) - : session_id_(session_id), - graph_id_(graph_id), - compute_graph_(compute_graph), - is_cache_path_valid_for_output(false) { - if (graph_id_run_times_.count(graph_id) == 0) { - graph_id_run_times_[graph_id] = 1; - } else { - graph_id_run_times_[graph_id] = graph_id_run_times_[graph_id] + 1; - } - for (const auto &node : compute_graph_->GetDirectNode()) { - bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || - (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); - if (!is_variable) { - continue; - } - var_names_.insert(node->GetName()); - } - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { - std::string cache_path = instance_ptr->GetIncreBuildCachePath(); - GELOGD("Incre build path conf: %s", cache_path.c_str()); - string fake_file_path = cache_path + to_string(graph_id_) + kManifestSuffix; - if (CheckOutputPathValid(fake_file_path)) { - is_cache_path_valid_for_output = true; - } else { - GELOGW("Invalid cache path for output."); - } - std::string real_cache_path = RealPath(cache_path.c_str()); - if (real_cache_path.empty()) { - GELOGW("Invalid incre build cache path conf: %s", cache_path.c_str()); - return; - } - cache_path_ = real_cache_path + '/'; - GELOGD("Try to use incre build cache path: %s", cache_path_.c_str()); - } -} - -ModelCacheHelper::~ModelCacheHelper() { var_names_.clear(); } - -bool ModelCacheHelper::IsModelCacheHit() const { - CacheInfo cache_info; - if (GetCacheInfo(cache_info) != SUCCESS) { - GELOGI("Get cache info of graph id[%u] failed.", graph_id_); - return false; - } - // Check number of nodes and edges first. 
- if (cache_info.node_num != compute_graph_->GetDirectNodesSize()) { - GELOGI("Graph id[%u] cache miss: the node number of the graph does not match the cache info.", graph_id_); - return false; - } - size_t edge_num = 0; - for (const auto &node : compute_graph_->GetDirectNode()) { - for (const auto &anchor : node->GetAllInAnchors()) { - edge_num += anchor->GetPeerAnchors().size(); - } - } - if (cache_info.edge_num != edge_num) { - GELOGI("Graph id[%u] cache miss: the edge number of the graph does not match the cache info.", graph_id_); - return false; - } - size_t compute_graph_hash; - auto ret = GetComputeGraphHash(compute_graph_hash); - if (ret != SUCCESS || cache_info.graph_hash != compute_graph_hash) { - GELOGI("Graph id[%u] cache miss: the hash code of the graph does not match the cache info.", graph_id_); - return false; - } - if (!IsNodeHashSameAsCache(cache_info.nodes_hash)) { - GELOGI("Graph id[%u] cache miss: the hash code of node does not match the cache info.", graph_id_); - return false; - } - - string var_manager_cache = - to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kBeforeVarManagerSuffix; - Json var_manager_json; - if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); - return false; - } - if (!IsVarManagerSameAsCache(var_manager_json)) { - GELOGI("Graph id[%u] cache miss: the VarManager does not match the cache info.", graph_id_); - return false; - } - GELOGI("Graph id[%u] cache hit.", graph_id_); - return true; -} - -Status ModelCacheHelper::RefreshComputeGraph(const ComputeGraphPtr &compute_graph) { - if (compute_graph->IsValid()) { - compute_graph_ = compute_graph; - var_names_.clear(); - for (const auto &node : compute_graph_->GetDirectNode()) { - bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || - (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); - if (!is_variable) { - continue; - } - var_names_.insert(node->GetName()); - } - return SUCCESS; - } else { - GELOGW("Invalid compute graph."); - return FAILED; - } -} - -Status ModelCacheHelper::ClearCache(uint32_t graph_id) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return SUCCESS; - } - string manifest_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string manifest_file_path = RealPath(manifest_file.c_str()); - int ret; - if (!manifest_file_path.empty()) { - ret = remove(manifest_file_path.c_str()); - // If remove file failed, print the warning log - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", manifest_file_path.c_str()); - } - } - string before_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string before_var_manager_file_path = RealPath(before_var_manager_file.c_str()); - if (!before_var_manager_file_path.empty()) { - ret = remove(before_var_manager_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", before_var_manager_file_path.c_str()); - } - } - string after_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string after_var_manager_file_path = RealPath(after_var_manager_file.c_str()); - if (!after_var_manager_file_path.empty()) { - ret = remove(after_var_manager_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", after_var_manager_file_path.c_str()); - } - } - string om_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string om_file_path = RealPath(om_file.c_str()); - 
if (!om_file_path.empty()) { - ret = remove(om_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", om_file_path.c_str()); - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverVarManagerFromCache() const { - string var_manager_cache = - to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kAfterVarManagerSuffix; - Json var_manager_json; - if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); - return FAILED; - } - - Json mem_resource_json = move(var_manager_json[kMemResourceMap]); - auto ret = RecoverMemResource(mem_resource_json); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[MemResource]"); - return FAILED; - } - Json var_resource_json = move(var_manager_json[kVarResource]); - ret = RecoverAllocatedGraphId(var_resource_json[kAllocatedGraphId]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[AllocatedGraphId]"); - return FAILED; - } - ret = RecoverChangedGraphId(var_resource_json[kChangedGraphId]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[ChangedGraphId]"); - return FAILED; - } - ret = RecoverBroadcastInfo(var_resource_json[kVarBroadcastInfo]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[VarBroadcastInfo]"); - return FAILED; - } - ret = RecoverVarAddrAndTensorDesc(var_resource_json[kVarAddrMgrMap]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[VarAddrMgrMap & CurVarTensorDesc]"); - return FAILED; - } - ret = RecoverTransRoads(var_resource_json[kTransRoads]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[TransRoads]"); - return FAILED; - } - GELOGI("Recover VarManager from cache[%s] success.", cache_path_.c_str()); - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesNeedRecompile(ComputeGraphPtr &graph, vector &nodes) { - std::shared_ptr instance = ge::GELib::GetInstance(); - if (instance == nullptr || !instance->InitFlag()) { - GELOGW("RecompileNodes failed."); - return ge::GE_CLI_GE_NOT_INITIALIZED; - } - // Collect aicore ops for recompile - for (auto &node : graph->GetDirectNode()) { - if (node == nullptr) { - continue; - } - auto op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - // Get op kernel lib name - string kernel_lib_name = op_desc->GetOpKernelLibName(); - if (kernel_lib_name.empty()) { - // reset op kernel lib - (void)instance->DNNEngineManagerObj().GetDNNEngineName(node); - kernel_lib_name = op_desc->GetOpKernelLibName(); - if (kernel_lib_name.empty()) { - GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str()); - continue; - } - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecompileNodes(GeModelPtr &ge_model) { - std::shared_ptr instance = ge::GELib::GetInstance(); - if (instance == nullptr || !instance->InitFlag()) { - GELOGW("RecompileNodes failed."); - return ge::GE_CLI_GE_NOT_INITIALIZED; - } - // Get aicore ops kernel info store. 
- OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kTbeKernelInfoStoreName); - if (kernel_info == nullptr) { - GELOGW("Get %s ops kernel info store failed", kTbeKernelInfoStoreName); - return INTERNAL_ERROR; - } - - auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - vector node_vec; - auto ret = GetNodesNeedRecompile(compute_graph, node_vec); - GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Get nodes need recompiling failed"); - // Recompile aicore ops - ret = kernel_info->CompileOp(node_vec); - GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Recompile op failed"); - const TBEKernelStore &tbekernel_store = ge_model->GetTBEKernelStore(); - TBEKernelStore tbe_kernel_store; - for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { - auto node_op_desc = n->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); - if (tbe_kernel == nullptr) { - // Load tbe kernel from tbe_kernel_store to op if op was not recompiled - auto op_desc = n->GetOpDesc(); - tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc); - GELOGD("LoadOmModelFromCache: Load tbe kernel bin to op desc[%s].", op_desc->GetName().c_str()); - } - tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); - GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); - // Refresh tbe kernel in tbe_kernel_store - tbe_kernel_store.AddTBEKernel(tbe_kernel); - GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); - } - GE_CHK_BOOL_EXEC_WARN(tbe_kernel_store.Build(), return FAILED, "TBE Kernels store build failed!"); - ge_model->SetTBEKernelStore(tbe_kernel_store); - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesHash(map &hash_map) const { - vector nodes; - GraphUtils::TopologicalSortingByName(compute_graph_, nodes); - ModelSerializeImp model_serialize_imp; - std::hash node_hash; - for (const auto &node : nodes) { - if (node == nullptr) { - continue; - } - proto::OpDef op_def; - bool is_framework_op = (node->GetType() == FRAMEWORKOP); - int32_t framework_type = 0; - if (is_framework_op) { - AttrUtils::GetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type); - AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, 0); - } - bool ret = model_serialize_imp.SerializeNode(node, &op_def, is_framework_op); - op_def.set_id(0); // Id of op is not stable because of parallel parsing - // Clear weights attr in constant. - auto attr = op_def.mutable_attr(); - if (op_def.type() == CONSTANT || op_def.type() == CONSTANTOP) { - attr->erase(ATTR_NAME_WEIGHTS); - } - if (is_framework_op) { - AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type); - } - if (!ret) { - GELOGW("Fail to serialize node[%s].", node->GetName().c_str()); - return INTERNAL_ERROR; - } - string prototxt; - ret = google::protobuf::TextFormat::PrintToString(op_def, &prototxt); - if (!ret) { - GELOGW("Print OpDef to string failed."); - hash_map.clear(); - return INTERNAL_ERROR; - } - size_t hash_code = node_hash(prototxt); - hash_map[node->GetName()] = hash_code; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetComputeGraphHash(size_t &hash) const { - proto::GraphDef graph_proto; - ModelSerializeImp model_serialize_imp; - // The name of compute graph may be generated randomly, so replace it temporarily. 
- const string origin_name = compute_graph_->GetName(); - compute_graph_->SetName(kGraphName); - bool serialize_ret = model_serialize_imp.SerializeGraph(compute_graph_, &graph_proto); - graph_proto.clear_op(); - if (!serialize_ret) { - GELOGW("Serialize graph failed."); - hash = 0; - return INTERNAL_ERROR; - } - compute_graph_->SetName(origin_name); - // Generate proto text of GraphDef - string prototxt; - bool print_ret = google::protobuf::TextFormat::PrintToString(graph_proto, &prototxt); - if (!print_ret) { - GELOGW("Print GraphDef to string failed."); - hash = 0; - return INTERNAL_ERROR; - } - // Get the hash code of proto text - std::hash graph_hash; - hash = graph_hash(prototxt); - return SUCCESS; -} - -Status ModelCacheHelper::SaveJsonToFile(const string &file_name, const Json &json) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return PARAM_INVALID; - } - // Check whether the manifest exists, if not, create it. - string real_path = RealPath(cache_path_.c_str()); - if (real_path.empty()) { - GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str()); - return FAILED; - } - const string path = cache_path_ + file_name; - const int FILE_AUTHORITY = 0600; - int fd = mmOpen2(path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, FILE_AUTHORITY); - if (fd < 0) { - GELOGW("Fail to open the file:%s. errmsg:%s", path.c_str(), strerror(errno)); - return INTERNAL_ERROR; - } - if (mmClose(fd) != 0) { - GELOGW("Fail to close the file:%s. errmsg:%s", path.c_str(), strerror(errno)); - return INTERNAL_ERROR; - } - - // Write json into cache file - ofstream ofs; - ofs.open(path); - if (!ofs.is_open()) { - GELOGW("Fail to open the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - ofs << json << std::endl; - ofs.close(); - return SUCCESS; -} - -Status ModelCacheHelper::LoadJsonFromFile(const string &file_name, Json &json) const { - if (!json.is_null()) { - GELOGW("Input param json type should be null."); - return PARAM_INVALID; - } - string real_path = RealPath(cache_path_.c_str()); - if (real_path.empty()) { - GELOGW("File path is invalid. 
please check cache path: %s", cache_path_.c_str()); - return FAILED; - } - const string path = cache_path_ + file_name; - if (!CheckInputPathValid(path)) { - GELOGW("Invalid cache path for input:%s.", path.c_str()); - return FAILED; - } - string cache_real_path = RealPath(path.c_str()); - if (cache_real_path.empty()) { - GELOGI("File[%s] is not found.", path.c_str()); - return FAILED; - } - // Read json from cache file - ifstream ifs; - ifs.open(path); - if (!ifs.is_open()) { - GELOGW("Fail to open the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - try { - ifs >> json; - } catch (nlohmann::detail::parse_error e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::invalid_iterator e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::type_error e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::out_of_range e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::other_error e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } - - if (!json.is_object()) { - GELOGW("Fail to load the json file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveCacheInfoToCache() const { - // Generate cache json - // example: {"edgeNum":6,"nodeNum":7,"graphCache":134714827475991356} - Json cache_json; - try { - cache_json[kNodeNum] = compute_graph_->GetDirectNodesSize(); - size_t edge_num = 0; - for (const auto &node : compute_graph_->GetDirectNode()) { - for (const auto &anchor : node->GetAllInAnchors()) { - edge_num += anchor->GetPeerAnchors().size(); - } - } - cache_json[kEdgeNum] = edge_num; - size_t hash = 0; - auto ret = GetComputeGraphHash(hash); - if (ret != SUCCESS) { - GELOGW("Error occur when generate graph hash code."); - return ret; - } - cache_json[kGraphHash] = hash; - Json nodes_hash_json; - ret = GetNodesHashMapJson(nodes_hash_json); - if (ret != SUCCESS) { - GELOGW("Error occur when generate nodes hash code."); - return ret; - } - cache_json[kNodeHash] = nodes_hash_json; - } catch (const std::exception &e) { - GELOGW("Fail to generate cache info json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; - - auto ret = SaveJsonToFile(cache_manifest, cache_json); - if (ret != SUCCESS) { - GELOGW("Fail to save cache info to json file, path: %s.", cache_path_.c_str()); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetCacheInfo(CacheInfo &cache_info) const { - string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; - Json cache_json; - if (LoadJsonFromFile(cache_manifest, cache_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", cache_manifest.c_str()); - return INTERNAL_ERROR; - } - if (!cache_json.is_object()) { - GELOGW("Manifest should be a json object"); - return INTERNAL_ERROR; - } - try { - cache_info.node_num = cache_json[kNodeNum]; - cache_info.edge_num = cache_json[kEdgeNum]; - cache_info.graph_hash = cache_json[kGraphHash]; - Json nodes_hash_json = cache_json[kNodeHash]; - if (!(nodes_hash_json.is_null() || nodes_hash_json.is_array())) { - GELOGW("Nodes hash in cache should be null or array."); - return FAILED; - } - for (const auto &iter : nodes_hash_json) { - cache_info.nodes_hash[iter[kName].get()] = iter[kHash].get(); - } - } catch (const std::exception &e) { - GELOGW("Fail to get info from json file. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -bool ModelCacheHelper::IsAllocatedGraphIdSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare allocated graph id info between json and VarManager - std::map allocated_graph_id; - auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return false; - } - for (const auto &iter : allocated_graph_id) { - uint32_t graph_id = 0; - ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(iter.first, graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to find allocated graph id of var[%s].", iter.first.c_str()); - return false; - } - if (graph_id != iter.second) { - GELOGW("The allocated graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsNodeHashSameAsCache(const map &hash_map) const { - map cur_hash_map; - GetNodesHash(cur_hash_map); - if (hash_map.size() != cur_hash_map.size()) { - GELOGI("The number of hash code is different from cache info."); - return false; - } - for (const auto &iter : cur_hash_map) { - if (hash_map.count(iter.first) == 0) { - GELOGI("Node[%s] is not found in cache info.", iter.first.c_str()); - return false; - } - if (hash_map.at(iter.first) != iter.second) { - GELOGI("The hash code of node[%s] is different from cache info.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsMemResourceSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare var mem size info between json and VarManager - std::map var_mem_size; - auto ret = ParseMemResourceFromJson(json, var_mem_size); - if (ret != SUCCESS) { - GELOGW("Fail to parse MemResource from Json."); - return false; - } - for (const auto &iter : var_mem_size) { - int64_t mem_size = 
VarManager::Instance(session_id_)->GetVarMemSize(iter.first); - if (mem_size != iter.second) { - GELOGW("The var mem size of memory_type[%u] in cache is different from VarManager.", iter.first); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsChangedGraphIdSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable changed graph id info between json and VarManager - std::map changed_graph_id; - auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse ChangedGraphId from Json."); - return false; - } - for (const auto &iter : changed_graph_id) { - uint32_t graph_id = 0; - ret = VarManager::Instance(session_id_)->GetChangedGraphId(iter.first, graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to find changed graph id of var[%s].", iter.first.c_str()); - return false; - } - if (graph_id != iter.second) { - GELOGW("The changed graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsCurVarTensorDescSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable tensor desc info between json and VarManager - std::unordered_map cur_var_tensor_desc; - auto ret = ParseCurVarTensorDescMapFromJson(json, cur_var_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to parse CurVarTensorDesc from Json."); - return false; - } - for (const auto &iter : cur_var_tensor_desc) { - GeTensorDesc tensor_desc; - ret = VarManager::Instance(session_id_)->GetCurVarDesc(iter.first, tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str()); - return false; - } - uint32_t l_real_dim_cnt = 0; - uint32_t r_real_dim_cnt = 0; - TensorUtils::GetRealDimCnt(tensor_desc, l_real_dim_cnt); - TensorUtils::GetRealDimCnt(iter.second, r_real_dim_cnt); - if ((tensor_desc.GetDataType() != iter.second.GetDataType()) || - (tensor_desc.GetOriginDataType() != iter.second.GetOriginDataType()) || - (tensor_desc.GetFormat() != iter.second.GetFormat()) || - (tensor_desc.GetOriginFormat() != iter.second.GetOriginFormat()) || - (tensor_desc.GetShape().ToString() != iter.second.GetShape().ToString()) || - (tensor_desc.GetOriginShape().ToString() != iter.second.GetOriginShape().ToString()) || - (l_real_dim_cnt != r_real_dim_cnt)) { - GELOGW("The var tensor desc of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsVarAddrMgrMapSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable address info between json and VarManager - std::vector> var_addr_mgr_vector; - std::set var_offset_set; - auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); - if (ret != SUCCESS) { - GELOGW("Fail to parse VarAddrMgrMap from Json."); - return false; - } - for (const auto &iter : var_addr_mgr_vector) { - uint8_t *dev_ptr = nullptr; - rtMemType_t memory_type; - ret = VarManager::Instance(session_id_)->GetVarAddr(iter.first, iter.second.tensor_desc, &dev_ptr, memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str()); - return 
false; - } - // Compare memory type and logic address - if (iter.second.memory_type != memory_type || iter.second.address != dev_ptr) { - GELOGW("The VarAddrMgr of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsBroadcastInfoSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare broadcast info between json and VarManager - std::unordered_map var_broadcast_info; - auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); - if (ret != SUCCESS) { - GELOGW("Fail to parse BroadcastInfo from Json."); - return false; - } - for (const auto &iter : var_broadcast_info) { - VarBroadCastInfo broadcast_info; - if (VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, iter.first, broadcast_info) != SUCCESS) { - GELOGW("Fail to find broadcast info of var[%s].", iter.first.c_str()); - return false; - } - if (iter.second.var_name != broadcast_info.var_name || iter.second.idx != broadcast_info.idx || - iter.second.input_size != broadcast_info.input_size || - iter.second.input_offset != broadcast_info.input_offset || - iter.second.output_size != broadcast_info.output_size || - iter.second.output_offset != broadcast_info.output_offset) { - GELOGW("The BroadcastInfo of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsTransRoadsSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare trans road between json and VarManager - std::unordered_map> trans_roads; - auto ret = ParseTransRoadsFromJson(json, trans_roads); - if (ret != SUCCESS) { - GELOGW("Fail to parse TransRoads from Json."); - return false; - } - for (const auto &iter : trans_roads) { - VarTransRoad *trans_road; - trans_road = VarManager::Instance(session_id_)->GetTransRoad(iter.first); - if (trans_road == nullptr) { - GELOGW("Fail to find trans road of var[%s].", iter.first.c_str()); - return false; - } - if (trans_road->size() != iter.second.size()) { - GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - // Compare every trans node in trans road. 
- for (size_t idx = 0; idx < trans_road->size(); idx += 1) { - if (!(trans_road->at(idx).node_type == iter.second.at(idx).node_type && - trans_road->at(idx).input == iter.second.at(idx).input && - trans_road->at(idx).output == iter.second.at(idx).output)) { - GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - } - return true; -} - -bool ModelCacheHelper::IsVarManagerParamSameAsCache(Json &json) const { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return false; - } - try { - if (json[kSessionId].get() != session_id_) { - GELOGW("Check VarManager cache failed.[sessionId]"); - return false; - } - if (json[kDeviceId].get() != VarManager::Instance(session_id_)->DeviceId()) { - GELOGW("Check VarManager cache failed.[deviceId]"); - return false; - } - if (json[kJobId].get() != VarManager::Instance(session_id_)->JobId()) { - GELOGW("Check VarManager cache failed.[jobId]"); - return false; - } - if (json[kGraphMemMaxSize].get() != VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()) { - GELOGW("Check VarManager cache failed.[graphMemMaxSize]"); - return false; - } - if (json[kVarMemMaxSize].get() != VarManager::Instance(session_id_)->GetVarMemMaxSize()) { - GELOGW("Check VarManager cache failed.[varMemMaxSize]"); - return false; - } - if (json[kVarMemLogicBase].get() != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { - GELOGW("Check VarManager cache failed.[varMemLogicBase]"); - return false; - } - if (json[kUseMaxMemSize].get() != VarManager::Instance(session_id_)->GetUseMaxMemorySize()) { - GELOGW("Check VarManager cache failed.[useMaxMemSize]"); - return false; - } - } catch (const std::exception &e) { - GELOGW("Fail to check VarManager json. Error message: %s", e.what()); - return false; - } - return true; -} - -bool ModelCacheHelper::IsVarManagerSameAsCache(Json &json) const { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return false; - } - try { - if (!IsVarManagerParamSameAsCache(json)) { - GELOGW("Check VarManager cache failed.[Param]"); - return false; - } - Json mem_resource_json = move(json[kMemResourceMap]); - auto ret = IsMemResourceSameAsCache(mem_resource_json); - if (!ret) { - GELOGW("Check VarManager cache failed.[MemResource]"); - return false; - } - Json var_resource_json = move(json[kVarResource]); - ret = IsAllocatedGraphIdSameAsCache(var_resource_json[kAllocatedGraphId]); - if (!ret) { - GELOGW("Check VarManager cache failed.[AllocatedGraphId]"); - return false; - } - ret = IsChangedGraphIdSameAsCache(var_resource_json[kChangedGraphId]); - if (!ret) { - GELOGW("Check VarManager cache failed.[ChangedGraphId]"); - return false; - } - ret = IsBroadcastInfoSameAsCache(var_resource_json[kVarBroadcastInfo]); - if (!ret) { - GELOGW("Check VarManager cache failed.[VarBroadcastInfo]"); - return false; - } - ret = IsCurVarTensorDescSameAsCache(var_resource_json[kCurVarTensorDescMap]); - if (!ret) { - GELOGW("Check VarManager cache failed.[CurVarTensorDesc]"); - return false; - } - ret = IsVarAddrMgrMapSameAsCache(var_resource_json[kVarAddrMgrMap]); - if (!ret) { - GELOGW("Check VarManager cache failed.[VarAddrMgrMap]"); - return false; - } - ret = IsTransRoadsSameAsCache(var_resource_json[kTransRoads]); - if (!ret) { - GELOGW("Check VarManager cache failed.[TransRoads]"); - return false; - } - } catch (const std::exception &e) { - GELOGW("Fail to check VarManager json. 
Error message: %s", e.what()); - return false; - } - return true; -} - -Status ModelCacheHelper::RecoverMemResource(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::map var_mem_size; - auto ret = ParseMemResourceFromJson(json, var_mem_size); - if (ret != SUCCESS) { - GELOGW("Fail to parse MemResource from Json."); - return ret; - } - for (const auto &iter : var_mem_size) { - ret = VarManager::Instance(session_id_)->UpdateVarMemSize(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover var mem size."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverAllocatedGraphId(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::map allocated_graph_id; - auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return ret; - } - for (const auto &iter : allocated_graph_id) { - ret = VarManager::Instance(session_id_)->SetAllocatedGraphId(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover allocated graph id."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverChangedGraphId(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::map changed_graph_id; - auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return ret; - } - for (const auto &iter : changed_graph_id) { - ret = VarManager::Instance(session_id_)->SetChangedGraphId(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover changed graph id."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverVarAddrAndTensorDesc(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::vector> var_addr_mgr_vector; - std::set var_offset_set; - auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); - if (ret != SUCCESS) { - GELOGW("Fail to parse VarAddrMgrMap from Json."); - return ret; - } - for (const auto &iter : var_addr_mgr_vector) { - const VarAddrMgr &tensor_addr_mgr = iter.second; - const bool var_exist = VarManager::Instance(session_id_)->IsVarExist(iter.first, tensor_addr_mgr.tensor_desc); - // SaveVarVddr if var does not exist, the logic address will be recorded by VarManager - if (!var_exist) { - auto logic_address = reinterpret_cast(reinterpret_cast(tensor_addr_mgr.address)); - auto offset = (tensor_addr_mgr.offset); - // Check logic address and offset - if (logic_address - offset != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { - GELOGW("Check logic_address[%lu] and offset [%lu] of %s failed, var mem logic base is %lu, abandon", - logic_address, offset, iter.first.c_str(), VarManager::Instance(session_id_)->GetVarMemLogicBase()); - return PARAM_INVALID; - } - // Offset is needed by SaveVarVddr instead of logic address - ret = VarManager::Instance(session_id_)->SaveVarAddr(iter.first, tensor_addr_mgr.tensor_desc, - reinterpret_cast(reinterpret_cast(offset)), - tensor_addr_mgr.memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to recover 
VarAddr or TensorDesc of var[%s].", iter.first.c_str()); - return ret; - } - } - // SetVarAddr to update cur_var_tensor_desc_map_ - ret = VarManager::Instance(session_id_) - ->SetVarAddr(iter.first, tensor_addr_mgr.tensor_desc, tensor_addr_mgr.address, tensor_addr_mgr.memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarAddr or TensorDesc desc of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverBroadcastInfo(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map var_broadcast_info; - auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); - if (ret != SUCCESS) { - GELOGW("Fail to parse BroadcastInfo from Json."); - return ret; - } - for (const auto &iter : var_broadcast_info) { - VarBroadCastInfo broadcast_info; - ret = VarManager::Instance(session_id_)->SaveBroadCastInfo(graph_id_, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover broadcast info of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverTransRoads(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map> trans_roads; - auto ret = ParseTransRoadsFromJson(json, trans_roads); - if (ret != SUCCESS) { - GELOGW("Fail to parse TransRoads from Json."); - return ret; - } - for (const auto &iter : trans_roads) { - ret = VarManager::Instance(session_id_)->SetTransRoad(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to find trans road of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json) { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - try { - json[kDataType] = static_cast(ge_tensor_desc.GetDataType()); - json[kOriginDataType] = static_cast(ge_tensor_desc.GetOriginDataType()); - json[kLayout] = static_cast(ge_tensor_desc.GetFormat()); - json[kOriginLayout] = static_cast(ge_tensor_desc.GetOriginFormat()); - json[kShape] = ge_tensor_desc.GetShape().GetDims(); - json[kOriginShape] = ge_tensor_desc.GetOriginShape().GetDims(); - uint32_t real_dim_cnt = 0; - (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] - json[kRealDimCnt] = real_dim_cnt; - } catch (const std::exception &e) { - GELOGW("Fail to trans GeTensorDesc to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::JsonToTensorDesc(const Json &json, ge::GeTensorDesc &ge_tensor_desc) { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return PARAM_INVALID; - } - try { - ge_tensor_desc.SetDataType(static_cast(json[kDataType].get())); - ge_tensor_desc.SetOriginDataType(static_cast(json[kOriginDataType].get())); - ge_tensor_desc.SetFormat(static_cast(json[kLayout].get())); - ge_tensor_desc.SetOriginFormat(static_cast(json[kOriginLayout].get())); - GeShape shape(json[kShape].get>()); - ge_tensor_desc.SetShape(shape); - GeShape origin_shape(json[kOriginShape].get>()); - ge_tensor_desc.SetOriginShape(origin_shape); - auto real_dim_cnt = json[kRealDimCnt].get(); - (void)TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] - } catch (const std::exception &e) { - GELOGW("Fail to trans Json to GeTensorDesc. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesHashMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - map hash_map; - GetNodesHash(hash_map); - for (const auto &iter : hash_map) { - Json node_hash_json; - try { - node_hash_json[kName] = iter.first; - node_hash_json[kHash] = iter.second; - json.emplace_back(move(node_hash_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans node cache to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetMemResourceMap(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - const auto total_size = VarManager::Instance(session_id_)->GetVarMemMaxSize(); - const auto var_mem_size = VarManager::Instance(session_id_)->GetVarMemSize(RT_MEMORY_HBM); - Json mem_resource_json; - try { - mem_resource_json[kMemType] = RT_MEMORY_HBM; - mem_resource_json[kTotalSize] = total_size; - mem_resource_json[kVarMemSize] = var_mem_size; - json.emplace_back(move(mem_resource_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans MemResourceMap to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarAddrMgrMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map var_addr_mgr_map; - VarManager::Instance(session_id_)->GetAllVarAddrMgr(var_addr_mgr_map); - try { - for (const auto &iter : var_addr_mgr_map) { - Json var_addr_json; - string name; - GetVarNameFromVarKey(iter.first, iter.second.tensor_desc, name); - var_addr_json[kName] = name; - var_addr_json[kAddress] = static_cast(reinterpret_cast(iter.second.address)); - var_addr_json[kMemoryType] = iter.second.memory_type; - var_addr_json[kOffset] = iter.second.offset; - - // Copy tensor desc to json. - Json tensor_desc_json; - auto ret = TensorDescToJson(iter.second.tensor_desc, tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - var_addr_json[kTensorDesc] = move(tensor_desc_json); - - json.emplace_back(move(var_addr_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans VarAddrMgrMap to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetCurVarTensorDescMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - try { - for (const auto &name : var_names_) { - Json cur_tensor_desc_json; - GeTensorDesc tensor_desc; - auto ret = VarManager::Instance(session_id_)->GetCurVarDesc(name, tensor_desc); - if (ret != SUCCESS) { - GELOGI("Get variable[%s] current tensor desc failed. It will be skipped.", name.c_str()); - continue; - } - cur_tensor_desc_json[kName] = name; - - Json tensor_desc_json; - ret = TensorDescToJson(tensor_desc, tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - cur_tensor_desc_json[kTensorDesc] = move(tensor_desc_json); - json.emplace_back(move(cur_tensor_desc_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans CurVarTensorDescMap to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetTransRoadsJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - try { - for (const auto &name : var_names_) { - auto trans_road = VarManager::Instance(session_id_)->GetTransRoad(name); - if (trans_road == nullptr) { - continue; - } - // Json object, variable name and trans road - Json trans_road_map_json; - trans_road_map_json[kName] = name; - - Json trans_road_json; - Status ret; - // Add nodes' info to json - for (const auto &trans_node_info : *trans_road) { - Json trans_node_info_json; - trans_node_info_json[kNodeType] = trans_node_info.node_type; - Json input_tensor_desc_json; - ret = TensorDescToJson(trans_node_info.input, input_tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - trans_node_info_json[kInputTensorDesc] = move(input_tensor_desc_json); - Json output_tensor_desc_json; - ret = TensorDescToJson(trans_node_info.output, output_tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - trans_node_info_json[kOutputTensorDesc] = move(output_tensor_desc_json); - trans_road_json.emplace_back(move(trans_node_info_json)); - } - trans_road_map_json[kTransRoad] = move(trans_road_json); - json.emplace_back(move(trans_road_map_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans VarToTransRoad to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetChangedGraphIdJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - uint32_t changed_graph_id = 0; - Status ret = VarManager::Instance(session_id_)->GetChangedGraphId(name, changed_graph_id); - if (ret != SUCCESS) { - continue; - } - Json name_and_changed_graph_id; - try { - name_and_changed_graph_id[kName] = name; - name_and_changed_graph_id[kGraphId] = changed_graph_id; - json.emplace_back(move(name_and_changed_graph_id)); - } catch (const std::exception &e) { - GELOGW("Fail to trans ChangedGraphId to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetAllocatedGraphIdJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - uint32_t allocated_graph_id = 0; - Status ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(name, allocated_graph_id); - if (ret != SUCCESS) { - continue; - } - Json name_and_allocated_graph_id; - try { - name_and_allocated_graph_id[kName] = name; - name_and_allocated_graph_id[kGraphId] = allocated_graph_id; - json.emplace_back(move(name_and_allocated_graph_id)); - } catch (const std::exception &e) { - GELOGW("Fail to trans AllocatedGraphId to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetBroadcastInfoJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - VarBroadCastInfo var_broadcast_info; - Status ret = VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, name, var_broadcast_info); - if (ret != SUCCESS) { - continue; - } - Json var_broadcast_info_json; - try { - var_broadcast_info_json[kName] = name; - var_broadcast_info_json[kBroadcastName] = var_broadcast_info.broadcast_name; - var_broadcast_info_json[kIdx] = var_broadcast_info.idx; - var_broadcast_info_json[kInputOffset] = var_broadcast_info.input_offset; - var_broadcast_info_json[kInputSize] = var_broadcast_info.input_size; - var_broadcast_info_json[kOutputOffset] = var_broadcast_info.output_offset; - var_broadcast_info_json[kOutputSize] = var_broadcast_info.output_size; - json.emplace_back(move(var_broadcast_info_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans VarBroadcastInfo to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarResourceJson(Json &json) const { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - Json var_addr_mgr_map_json; - Status ret = GetVarAddrMgrMapJson(var_addr_mgr_map_json); - if (ret != SUCCESS) { - GELOGW("GetVarAddrMgrMapJson failed."); - return INTERNAL_ERROR; - } - - Json cur_var_tensor_desc_map_json; - ret = GetCurVarTensorDescMapJson(cur_var_tensor_desc_map_json); - if (ret != SUCCESS) { - GELOGW("GetCurVarTensorDescMapJson failed."); - return INTERNAL_ERROR; - } - - Json trans_roads_json; - ret = GetTransRoadsJson(trans_roads_json); - if (ret != SUCCESS) { - GELOGW("GetTransRoadsJson failed."); - return INTERNAL_ERROR; - } - - Json changed_graph_id_json; - ret = GetChangedGraphIdJson(changed_graph_id_json); - if (ret != SUCCESS) { - GELOGW("GetChangedGraphIdJson failed."); - return INTERNAL_ERROR; - } - - Json allocated_graph_id_json; - ret = GetAllocatedGraphIdJson(allocated_graph_id_json); - if (ret != SUCCESS) { - GELOGW("GetAllocatedGraphIdJson failed."); - return INTERNAL_ERROR; - } - - Json var_broadcast_info_json; - ret = GetBroadcastInfoJson(var_broadcast_info_json); - if (ret != SUCCESS) { - GELOGW("GetBroadcastInfoJson failed."); - return INTERNAL_ERROR; - } - - try { - json[kVarAddrMgrMap] = move(var_addr_mgr_map_json); - json[kCurVarTensorDescMap] = move(cur_var_tensor_desc_map_json); - json[kTransRoads] = move(trans_roads_json); - json[kChangedGraphId] = move(changed_graph_id_json); - json[kAllocatedGraphId] = move(allocated_graph_id_json); - json[kVarBroadcastInfo] = move(var_broadcast_info_json); - } catch (const exception &e) { - GELOGW("Fail to generate VarResource json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarManagerJson(Json &json) const { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - - Json mem_resource_map_json; - auto ret = GetMemResourceMap(mem_resource_map_json); - if (ret != SUCCESS) { - GELOGW("GetMemResourceMap failed."); - return INTERNAL_ERROR; - } - - Json var_resource_json; - ret = GetVarResourceJson(var_resource_json); - if (ret != SUCCESS) { - GELOGW("GetVarResourceJson failed."); - return INTERNAL_ERROR; - } - - try { - json[kSessionId] = session_id_; - json[kDeviceId] = VarManager::Instance(session_id_)->DeviceId(); - json[kJobId] = VarManager::Instance(session_id_)->JobId(); - json[kGraphMemMaxSize] = VarManager::Instance(session_id_)->GetGraphMemoryMaxSize(); - json[kVarMemMaxSize] = VarManager::Instance(session_id_)->GetVarMemMaxSize(); - json[kVarMemLogicBase] = VarManager::Instance(session_id_)->GetVarMemLogicBase(); - json[kUseMaxMemSize] = VarManager::Instance(session_id_)->GetUseMaxMemorySize(); - json[kMemResourceMap] = move(mem_resource_map_json); - json[kVarResource] = move(var_resource_json); - } catch (const exception &e) { - GELOGW("Fail to generate VarManager json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveVarManagerToCache(bool before_build) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return FAILED; - } - Json var_manager_json; - auto ret = GetVarManagerJson(var_manager_json); - if (ret != SUCCESS) { - GELOGW("Fail to generate VarManager json."); - return FAILED; - } - string var_manager_path = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + - (before_build ? kBeforeVarManagerSuffix : kAfterVarManagerSuffix); - ret = SaveJsonToFile(var_manager_path, var_manager_json); - if (ret != SUCCESS) { - GELOGW("Fail to save VarManager info to json file, path: %s.", cache_path_.c_str()); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveOmModelToCache(const GeModelPtr &ge_model) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return FAILED; - } - string om_path = RealPath(cache_path_.c_str()); - if (om_path.empty()) { - GELOGW("file path is invalid. please check path om: %s", cache_path_.c_str()); - return FAILED; - } - string cache_om_path = cache_path_; - cache_om_path += (to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix); - GELOGI("SaveOmModelToCache: start to save om model : %s", cache_om_path.c_str()); - ModelHelper model_helper; - SaveParam save_param; - ModelBufferData model; - Status ret = model_helper.SaveToOmModel(ge_model, save_param, cache_om_path, model); - if (ret != SUCCESS) { - GELOGW("SaveOmModelToCache: save mode failed. ret = %u", ret); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map &mem_resource) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - mem_resource.clear(); - for (const Json &mem_resource_json : json) { - try { - rtMemType_t mem_type = mem_resource_json[kMemType].get(); - uint64_t var_mem_size = mem_resource_json[kVarMemSize].get(); - mem_resource[mem_type] = var_mem_size; - } catch (const exception &e) { - GELOGW("Fail to trans Json to MemResource. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseVarAddrMgrMapFromJson( - const Json &json, std::vector> &var_addr_mgr_vector, - std::set &var_offset_set) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - var_addr_mgr_vector.clear(); - var_offset_set.clear(); - for (const Json &var_addr_json : json) { - VarAddrMgr var_addr_mgr; - try { - auto logic_address = var_addr_json[kAddress].get(); - auto address = reinterpret_cast(reinterpret_cast(logic_address)); - var_addr_mgr.address = address; - var_addr_mgr.offset = var_addr_json[kOffset].get(); - var_addr_mgr.memory_type = var_addr_json[kMemoryType].get(); - auto ret = JsonToTensorDesc(var_addr_json[kTensorDesc], var_addr_mgr.tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - var_addr_mgr_vector.emplace_back(var_addr_json[kName].get(), move(var_addr_mgr)); - var_offset_set.insert(logic_address); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarAddrMgr. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseCurVarTensorDescMapFromJson( - const Json &json, std::unordered_map &cur_var_tensor_desc_map) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - cur_var_tensor_desc_map.clear(); - for (const Json &tensor_desc_json : json) { - GeTensorDesc tensor_desc; - try { - auto ret = JsonToTensorDesc(tensor_desc_json[kTensorDesc], tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - cur_var_tensor_desc_map[tensor_desc_json[kName].get()] = move(tensor_desc); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseTransRoadsFromJson( - const Json &json, std::unordered_map> &trans_roads) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - trans_roads.clear(); - try { - for (const Json &name_trans_road_json : json) { - const Json &trans_road_json = name_trans_road_json[kTransRoad]; - if (!(trans_road_json.is_array() || trans_road_json.is_null())) { - GELOGW("%s json type should be null or object.", kTransRoad); - return PARAM_INVALID; - } - vector trans_road; - for (const Json &trans_node_json : trans_road_json) { - TransNodeInfo trans_node_info; - trans_node_info.node_type = trans_node_json[kNodeType]; - GeTensorDesc input_tensor_desc; - auto ret = JsonToTensorDesc(trans_node_json[kInputTensorDesc], input_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - trans_node_info.input = move(input_tensor_desc); - GeTensorDesc output_tensor_desc; - ret = JsonToTensorDesc(trans_node_json[kOutputTensorDesc], output_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - trans_node_info.output = move(output_tensor_desc); - trans_road.emplace_back(move(trans_node_info)); - } - trans_roads[name_trans_road_json[kName].get()] = move(trans_road); - } - } catch (const exception &e) { - GELOGW("Fail to trans Json to TransRoads. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseChangedGraphIdFromJson(const Json &json, - std::map &changed_graph_id) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - changed_graph_id.clear(); - for (const Json &name_graph_id_json : json) { - try { - changed_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to changed graph id. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseAllocatedGraphIdFromJson(const Json &json, - std::map &allocated_graph_id) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - allocated_graph_id.clear(); - for (const Json &name_graph_id_json : json) { - try { - allocated_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to allocated graph id. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseBroadcastInfoFromJson( - const Json &json, std::unordered_map &var_broadcast_info) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const Json &broadcast_info_json : json) { - VarBroadCastInfo broadcast_info; - try { - broadcast_info.var_name = broadcast_info_json[kName].get(); - broadcast_info.broadcast_name = broadcast_info_json[kBroadcastName].get(); - broadcast_info.idx = broadcast_info_json[kIdx].get(); - broadcast_info.input_offset = broadcast_info_json[kInputOffset].get(); - broadcast_info.input_size = broadcast_info_json[kInputSize].get(); - broadcast_info.output_offset = broadcast_info_json[kOutputOffset].get(); - broadcast_info.output_size = broadcast_info_json[kOutputSize].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarBroadCastInfo. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - var_broadcast_info[broadcast_info.var_name] = broadcast_info; - } - return SUCCESS; -} - -Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const { - string cache_om = cache_path_ + to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix; - if (!CheckInputPathValid(cache_om)) { - GELOGW("Invalid cache path for input:%s.", cache_om.c_str()); - return FAILED; - } - string om_path = RealPath(cache_om.c_str()); - if (om_path.empty()) { - GELOGW("file path is invalid. please check file om: %s", om_path.c_str()); - return FAILED; - } - GELOGI("load model data from file: %s", om_path.c_str()); - Status ret; - int32_t priority = 0; - ModelData model_data; - ret = ModelParserBase::LoadFromFile(om_path.c_str(), priority, model_data); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: Load model from file failed. ret = %u", ret); - return ret; - } - - ModelHelper model_helper; - ret = model_helper.LoadModel(model_data); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: Load model from data failed. ret = %u", ret); - return ret; - } - ge_model = model_helper.GetGeModel(); - ret = RecompileNodes(ge_model); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: recompile nodes failed. ret = %u", ret); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, - string &var_name) { - std::string::size_type underline_idx = var_key.rfind('_'); - if (underline_idx == std::string::npos) { - GELOGW("Invalid var key: underline not found"); - return FAILED; - } - std::string::size_type format_idx = - var_key.rfind(std::to_string(static_cast(tensor_desc.GetFormat())), underline_idx); - if (format_idx == std::string::npos) { - GELOGW("Invalid var key: format not found"); - return FAILED; - } - var_name = var_key.substr(0, format_idx); - return SUCCESS; -} -} // namespace ge diff --git a/ge/common/helper/model_cache_helper.h b/ge/common/helper/model_cache_helper.h deleted file mode 100755 index 398d6c03..00000000 --- a/ge/common/helper/model_cache_helper.h +++ /dev/null @@ -1,123 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ -#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ - -#include -#include -#include - -#include "ge/ge_api_error_codes.h" -#include "graph/compute_graph.h" -#include "graph/manager/graph_var_manager.h" -#include "model/ge_model.h" - -namespace ge { -using Json = nlohmann::json; - -struct CacheInfo { - size_t node_num; - size_t edge_num; - size_t graph_hash; - map nodes_hash; - CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} -}; - -class ModelCacheHelper { - public: - ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); - ~ModelCacheHelper(); - - Status SaveCacheInfoToCache () const; - Status SaveVarManagerToCache(bool before_build) const; - Status SaveOmModelToCache(const GeModelPtr &ge_model) const; - bool IsModelCacheHit() const; - Status RecoverVarManagerFromCache() const; - Status LoadOmModelFromCache(GeModelPtr &ge_model) const; - Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); - Status ClearCache(uint32_t graph_id) const; - - private: - Status GetComputeGraphHash(size_t &hash) const; - Status GetNodesHash(map &hash_map) const; - Status GetCacheInfo(CacheInfo &cache_info) const; - - Status RecoverMemResource(const Json &json) const; - Status RecoverAllocatedGraphId(const Json &json) const; - Status RecoverChangedGraphId(const Json &json) const; - Status RecoverVarAddrAndTensorDesc(const Json &json) const; - Status RecoverBroadcastInfo(const Json &json) const; - Status RecoverTransRoads(const Json &json) const; - static Status GetNodesNeedRecompile(ComputeGraphPtr &graph, vector &nodes); - static Status RecompileNodes(GeModelPtr &ge_model); - - bool IsNodeHashSameAsCache(const map &hash_map) const; - bool IsMemResourceSameAsCache(Json &json) const; - bool IsChangedGraphIdSameAsCache(Json &json) const; - bool IsAllocatedGraphIdSameAsCache(Json &json) const; - bool IsCurVarTensorDescSameAsCache(Json &json) const; - bool IsVarAddrMgrMapSameAsCache(Json &json) const; - bool IsBroadcastInfoSameAsCache(Json &json) const; - bool IsTransRoadsSameAsCache(Json &json) const; - bool IsVarManagerSameAsCache(Json &json) const; - bool IsVarManagerParamSameAsCache(Json &json) const; - - Status SaveJsonToFile(const string &file_name, const Json &json) const; - Status LoadJsonFromFile(const string &file_name, Json &json) const; - - Status GetNodesHashMapJson(Json &json) const; - Status GetMemResourceMap(Json &json) const; - Status GetVarAddrMgrMapJson(Json &json) const; - Status GetCurVarTensorDescMapJson(Json &json) const; - Status GetTransRoadsJson(Json &json) const; - Status GetChangedGraphIdJson(Json &json) const; - Status GetAllocatedGraphIdJson(Json &json) const; - Status GetBroadcastInfoJson(Json &json) const; - Status GetVarResourceJson(Json &json) const; - Status GetVarManagerJson(Json &json) const; - - static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); - static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); - static Status ParseMemResourceFromJson(const Json &json, map &mem_resource); - static Status 
ParseVarAddrMgrMapFromJson(const Json &json, - std::vector> &var_addr_mgr_vector, - std::set &var_offset_set); - static Status ParseCurVarTensorDescMapFromJson( - const Json &json, std::unordered_map &cur_var_tensor_desc_map); - static Status ParseTransRoadsFromJson(const Json &json, - std::unordered_map> &trans_roads); - static Status ParseChangedGraphIdFromJson(const Json &json, - std::map &changed_graph_id); - static Status ParseAllocatedGraphIdFromJson(const Json &json, - std::map &allocated_graph_id); - static Status ParseBroadcastInfoFromJson(const Json &json, - std::unordered_map &var_broadcast_info); - static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); - - uint64_t session_id_; - uint32_t graph_id_; - string cache_path_; - ComputeGraphPtr compute_graph_; - std::set var_names_; - bool is_cache_path_valid_for_output; - static map graph_id_run_times_; -}; - -using ModelCacheHelperPtr = std::shared_ptr; -} // namespace ge - -#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index 4e760a4a..2608b1e1 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -33,7 +33,7 @@ const uint32_t kStatiOmFileModelNum = 1; namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } +ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } Status ModelHelper::SaveModelPartition(std::shared_ptr &om_file_save_helper, ModelPartitionType type, const uint8_t *data, size_t size, size_t model_index) { @@ -108,8 +108,8 @@ Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { return SUCCESS; } -Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save_helper, - const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) { +Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + ge::Buffer &model_buffer, size_t model_index) { ModelPtr model_tmp = ge::MakeShared(ge_model->GetName(), ge_model->GetPlatformVersion()); if (model_tmp == nullptr) { GELOGE(FAILED, "[Creat][Model]Failed, Model %s Ptr", ge_model->GetName().c_str()); @@ -143,8 +143,8 @@ Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save return SUCCESS; } -Status ModelHelper::SaveModelWeights(std::shared_ptr &om_file_save_helper, - const GeModelPtr &ge_model, size_t model_index) { +Status ModelHelper::SaveModelWeights(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_index) { auto ge_model_weight = ge_model->GetWeight(); GELOGD("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); // weight is not necessary @@ -187,8 +187,8 @@ Status ModelHelper::SaveModelCustAICPU(std::shared_ptr &om_fil return SUCCESS; } -Status ModelHelper::SaveModelTaskDef(std::shared_ptr &om_file_save_helper, - const GeModelPtr &ge_model, ge::Buffer &task_buffer, size_t model_index) { +Status ModelHelper::SaveModelTaskDef(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + ge::Buffer &task_buffer, size_t model_index) { std::shared_ptr model_task_def = ge_model->GetModelTaskDefPtr(); if (model_task_def == nullptr) { GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Creat][ModelTaskDef]Failed, it is nullptr, " @@ -231,8 +231,8 @@ Status ModelHelper::SaveModelTaskDef(std::shared_ptr &om_file_ return SUCCESS; } -Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_save_helper, - 
const GeModelPtr &ge_model, size_t model_num) { +Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_num) { // Save target/version to model_header ModelFileHeader &model_header = om_file_save_helper->GetModelFileHeader(); model_header.platform_type = ge_model->GetPlatformType(); @@ -246,8 +246,10 @@ Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_s if (err != EOK) { GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Save][Model]Failed while allocating memory for platform_version %s, model %s, " - "errno %d", platform_version.c_str(), ge_model->GetName().c_str(), err); - REPORT_CALL_ERROR("E19999", "ModelHelper save model %s failed while " + "errno %d", + platform_version.c_str(), ge_model->GetName().c_str(), err); + REPORT_CALL_ERROR("E19999", + "ModelHelper save model %s failed while " "allocating memory for platform_version %s, errno %d", ge_model->GetName().c_str(), platform_version.c_str(), err); return ACL_ERROR_GE_MEMORY_ALLOCATION; @@ -271,9 +273,9 @@ Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_s return SUCCESS; } -Status ModelHelper::SaveAllModelPartiton(std::shared_ptr& om_file_save_helper, - const GeModelPtr &ge_model, ge::Buffer &model_buffer, - ge::Buffer &task_buffer, size_t model_index) { +Status ModelHelper::SaveAllModelPartiton(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, ge::Buffer &model_buffer, ge::Buffer &task_buffer, + size_t model_index) { if (SaveModelDef(om_file_save_helper, ge_model, model_buffer, model_index) != SUCCESS) { GELOGE(FAILED, "[Save][ModelDef]Failed, model %s, model index %zu", ge_model->GetName().c_str(), model_index); @@ -316,10 +318,8 @@ Status ModelHelper::SaveAllModelPartiton(std::shared_ptr& om_f return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, - const SaveParam &save_param, - const std::string &output_file, - ModelBufferData& model) { +Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, + const std::string &output_file, ModelBufferData &model) { if (output_file.empty()) { GELOGE(FAILED, "[Save][Model]GraphBuilder SaveModel received invalid file name prefix, " "model %s", ge_model->GetName().c_str()); @@ -367,13 +367,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRootModel( - const GeRootModelPtr &ge_root_model, - const SaveParam &save_param, - const std::string &output_file, - ModelBufferData& model, - bool is_unknown_shape) { - +Status ModelHelper::SaveToOmRootModel(const GeRootModelPtr &ge_root_model, const SaveParam &save_param, + const std::string &output_file, ModelBufferData &model, bool is_unknown_shape) { GE_CHECK_NOTNULL(ge_root_model); GE_IF_BOOL_EXEC(ge_root_model == nullptr, GELOGE(FAILED, "[Check][GERootModel]Ge_root_model is nullptr"); @@ -466,8 +461,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRoo return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file) { +Status ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file) { if (output_file.empty()) { GELOGE(FAILED, "[Save][Model]Received invalid file name prefix, output_file %s", output_file.c_str()); REPORT_INNER_ERROR("E19999", "Save 
model received invalid file name prefix, output_file %s", output_file.c_str()); @@ -545,7 +539,7 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin return (ret == SUCCESS ? SUCCESS : FAILED); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(const ge::ModelData &model_data) { +Status ModelHelper::LoadModel(const ge::ModelData &model_data) { if (model_data.model_data == nullptr || model_data.model_len == 0) { GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "[Load][Model]Model_data is nullptr or model_data_size is 0"); @@ -597,7 +591,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { +Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { if (model_data.model_data == nullptr || model_data.model_len == 0) { GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "[Load][RootModel] " "Model_data is nullptr or model data is empty."); @@ -783,7 +777,6 @@ Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper, GeModelPtr & return SUCCESS; } - Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { ModelPartition partition; if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) { @@ -814,7 +807,7 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cu return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { +Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { ModelPartition task_partition; if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { GELOGE(FAILED, "[Get][ModelTaskPartition]Failed, task_partition size:%u", task_partition.size); @@ -838,9 +831,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(Om return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper, - GeModelPtr &cur_model, - size_t mode_index) { +Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { ModelPartition task_partition; if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition, mode_index) != SUCCESS) { GELOGE(FAILED, "Get task model partition failed."); @@ -915,8 +906,8 @@ Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { return SUCCESS; } -Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, - GeModelPtr &cur_model, size_t mode_index) { +Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, + size_t mode_index) { // Load cust aicpu kernels ModelPartition partition_kernel_def; CustAICPUKernelStore kernel_store; @@ -933,7 +924,7 @@ Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { +GeModelPtr ModelHelper::GetGeModel() { if (model_ != nullptr) { return model_; } @@ -946,7 +937,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo return out_model; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeRootModelPtr ModelHelper::GetGeRootModel() { +GeRootModelPtr ModelHelper::GetGeRootModel() { if (root_model_ != 
nullptr) { return root_model_; } @@ -959,7 +950,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeRootModelPtr ModelHelper::Get return out_model; } - Status ModelHelper::ReleaseLocalModelData() noexcept { Status result = SUCCESS; if (model_addr_tmp_ != nullptr) { @@ -976,8 +966,7 @@ Status ModelHelper::ReleaseLocalModelData() noexcept { return result; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetBaseNameFromFileName( - const string &file_name, string &base_name) { +Status ModelHelper::GetBaseNameFromFileName(const string &file_name, string &base_name) { GELOGD("Get base_name from file, file_name:%s", file_name.c_str()); GE_CHK_BOOL_EXEC_WARN(!file_name.empty(), return FAILED, "File path may not valid, check params --output"); size_t start_position = 0; @@ -992,8 +981,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetBaseName return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNameFromMergedGraphName( - const string &graph_name, string &model_name) { +Status ModelHelper::GetModelNameFromMergedGraphName(const string &graph_name, string &model_name) { GELOGD("Get model_name from graph_name, graph_name:%s", graph_name.c_str()); // this can only be used after merged graph(graph name will be append with "_x", x is index); GE_CHK_BOOL_EXEC_WARN(!graph_name.empty(), return FAILED, "File path may not valid, check params --output"); @@ -1035,8 +1023,7 @@ Status ModelTool::GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, {"om", model_file, "invalid om file, can't be parsed"}); GELOGE(ACL_ERROR_GE_PARAM_INVALID, - "[Parse][ModelContent]Failed because of invalid om file %s, please check om param", - model_file); + "[Parse][ModelContent]Failed because of invalid om file %s, please check om param", model_file); return ret; } diff --git a/ge/common/helper/om_file_helper.cc b/ge/common/helper/om_file_helper.cc index cd13c5d8..a42316ff 100644 --- a/ge/common/helper/om_file_helper.cc +++ b/ge/common/helper/om_file_helper.cc @@ -18,10 +18,11 @@ #include #include -#include "common/math/math_util.h" + #include "common/auth/file_saver.h" -#include "framework/common/debug/log.h" +#include "common/math/math_util.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/util.h" @@ -32,7 +33,7 @@ const int32_t kOptionalNum = 2; } namespace ge { // For Load -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(const ge::ModelData &model) { +Status OmFileLoadHelper::Init(const ge::ModelData &model) { if (CheckModelValid(model) != SUCCESS) { return FAILED; } @@ -42,8 +43,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(c return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, - const uint32_t model_data_size) { +Status OmFileLoadHelper::Init(uint8_t *model_data, const uint32_t model_data_size) { Status status = LoadModelPartitionTable(model_data, model_data_size); if (status != SUCCESS) { return status; @@ -52,9 +52,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, - uint32_t model_data_size, - uint32_t model_num) { +Status 
OmFileLoadHelper::Init(uint8_t *model_data, uint32_t model_data_size, uint32_t model_num) { Status status = LoadModelPartitionTable(model_data, model_data_size, model_num); if (status != SUCCESS) { return status; @@ -64,8 +62,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u } // Use both -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, - ModelPartition &partition) { +Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, ModelPartition &partition) { if (!is_inited_) { GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!"); return PARAM_INVALID; @@ -90,9 +87,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, - ModelPartition &partition, - size_t model_index) { +Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, ModelPartition &partition, size_t model_index) { if (!is_inited_) { GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!"); return PARAM_INVALID; @@ -248,12 +243,11 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t m return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector - &OmFileSaveHelper::GetModelPartitions() const { +const std::vector &OmFileSaveHelper::GetModelPartitions() const { return context_.partition_datas_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable() { +ModelPartitionTable *OmFileSaveHelper::GetPartitionTable() { auto partition_size = static_cast(context_.partition_datas_.size()); // Build ModelPartitionTable, flex array context_.partition_table_.clear(); @@ -272,8 +266,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave return partition_table; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable( - size_t cur_ctx_index) { +ModelPartitionTable *OmFileSaveHelper::GetPartitionTable(size_t cur_ctx_index) { auto &cur_ctx = model_contexts_[cur_ctx_index]; auto partition_size = static_cast(cur_ctx.partition_datas_.size()); // Build ModelPartitionTable, flex array @@ -293,8 +286,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave return partition_table; } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { +Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); return FAILED; @@ -379,8 +371,8 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat #endif } -Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file, - ModelBufferData &model, bool is_offline) { +Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, + bool is_offline) { (void)save_param.cert_file; (void)save_param.ek_file; (void)save_param.encode_mode; @@ -409,8 +401,8 @@ Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char * model_header_.length += size_of_table + cur_model_data_len; 
model_partition_tabels.push_back(tmp_table); all_model_partitions.push_back(cur_ctx.partition_datas_); - GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu", - size_of_table, cur_model_data_len, ctx_index); + GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu", size_of_table, + cur_model_data_len, ctx_index); } Status ret; if (is_offline) { diff --git a/ge/common/kernel_store.h b/ge/common/kernel_store.h index b3f4a62e..e7b867a3 100755 --- a/ge/common/kernel_store.h +++ b/ge/common/kernel_store.h @@ -48,7 +48,7 @@ struct KernelStoreItemHead { uint32_t bin_len; }; -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY KernelStore { +class KernelStore { public: KernelStore() = default; virtual ~KernelStore() = default; diff --git a/ge/graph/common/local_context.cc b/ge/common/local_context.cc similarity index 71% rename from ge/graph/common/local_context.cc rename to ge/common/local_context.cc index d3e66861..e31f2342 100644 --- a/ge/graph/common/local_context.cc +++ b/ge/common/local_context.cc @@ -14,15 +14,14 @@ * limitations under the License. */ -#include "graph/common/local_context.h" +#include "common/local_context.h" -#include "common/ge_inner_error_codes.h" -#include "common/debug/ge_log.h" -#include "omg/omg_inner_types.h" +#include "framework/common/debug/ge_log.h" namespace ge { namespace { thread_local OmgContext *omg_context = nullptr; +thread_local OmeContext *ome_context = nullptr; } void SetLocalOmgContext(OmgContext &context) { @@ -37,4 +36,18 @@ OmgContext &GetLocalOmgContext() { return domi::GetContext(); } } + +void SetLocalOmeContext(OmeContext &context) { + ome_context = &context; +} + +OmeContext &GetLocalOmeContext() { + if (ome_context != nullptr) { + return *ome_context; + } + + GELOGW("ome_context is nullptr."); + static OmeContext context; + return context; +} } diff --git a/ge/common/local_context.h b/ge/common/local_context.h new file mode 100644 index 00000000..751c6692 --- /dev/null +++ b/ge/common/local_context.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_COMMON_LOCAL_CONTEXT_H_ +#define GE_GRAPH_COMMON_LOCAL_CONTEXT_H_ + +#include "framework/omg/omg_inner_types.h" + +namespace ge { +void SetLocalOmgContext(OmgContext &context); +OmgContext &GetLocalOmgContext(); + + +struct OmeContext { + bool need_multi_batch = false; + std::string dynamic_node_type; + std::vector data_nodes; + std::vector getnext_nosink_nodes; + std::vector dynamic_shape_dims; + std::vector>> user_input_dims; + std::vector> user_real_input_dims; +}; + +GE_FUNC_VISIBILITY +void SetLocalOmeContext(OmeContext &context); + +GE_FUNC_VISIBILITY +OmeContext &GetLocalOmeContext(); +} // namespace ge +#endif // GE_GRAPH_COMMON_LOCAL_CONTEXT_H_ diff --git a/ge/common/math/fp16_math.cc b/ge/common/math/fp16_math.cc index e465c953..c2dfeb61 100755 --- a/ge/common/math/fp16_math.cc +++ b/ge/common/math/fp16_math.cc @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "fp16_math.h" +#include "common/math/fp16_math.h" #include "external/register/register_types.h" namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sqrt(fp16_t fp) { +fp16_t sqrt(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -29,7 +29,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sqrt(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rsqrt(fp16_t fp) { +fp16_t rsqrt(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -40,7 +40,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rsqrt(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rcp(fp16_t fp) { +fp16_t rcp(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -51,7 +51,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rcp(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t exp(fp16_t fp) { +fp16_t exp(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -63,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t exp(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t pow2(fp16_t fp) { +fp16_t pow2(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -75,7 +75,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t pow2(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t pow10(fp16_t fp) { +fp16_t pow10(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -87,7 +87,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t pow10(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t ln(fp16_t fp) { +fp16_t ln(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -99,7 +99,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t ln(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t log2(fp16_t fp) { +fp16_t log2(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -111,7 +111,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t log2(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t log10(fp16_t fp) { +fp16_t log10(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -123,7 +123,7 @@ FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY fp16_t log10(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t cos(fp16_t fp) { +fp16_t cos(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -135,7 +135,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t cos(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sin(fp16_t fp) { +fp16_t sin(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -147,13 +147,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sin(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t abs(fp16_t fp) { +fp16_t abs(fp16_t fp) { fp16_t ret; ret.val = (fp.val & kFp16AbsMax); return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t max(fp16_t fp1, fp16_t fp2) { +fp16_t max(fp16_t fp1, fp16_t fp2) { if (fp1 >= fp2) { return fp1; } else { @@ -161,7 +161,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t max(fp16_t fp1, fp16_t f } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t min(fp16_t fp1, fp16_t fp2) { +fp16_t min(fp16_t fp1, fp16_t fp2) { if (fp1 <= fp2) { return fp1; } else { diff --git a/ge/model/ge_model.cc b/ge/common/model/ge_model.cc similarity index 97% rename from ge/model/ge_model.cc rename to ge/common/model/ge_model.cc index bcccc6f8..7fc58b6d 100755 --- a/ge/model/ge_model.cc +++ b/ge/common/model/ge_model.cc @@ -14,9 +14,9 @@ * limitations under the License. */ -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/attr_utils.h" diff --git a/ge/model/ge_model.h b/ge/common/model/ge_model.h similarity index 89% rename from ge/model/ge_model.h rename to ge/common/model/ge_model.h index 08db8cc3..0e791746 100755 --- a/ge/model/ge_model.h +++ b/ge/common/model/ge_model.h @@ -26,12 +26,12 @@ #include "framework/common/debug/log.h" #include "framework/common/fmk_error_codes.h" #include "graph/buffer.h" -#include "graph/graph.h" +#include "external/graph/graph.h" #include "proto/task.pb.h" namespace ge { const uint32_t INVALID_MODEL_ID = 0xFFFFFFFFUL; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder { +class GeModel : public AttrHolder { public: GeModel(); ~GeModel() = default; @@ -82,13 +82,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder private: void Init(); - ProtoAttrMapHelper attrs_; + ProtoAttrMapHelper attrs_; /*lint !e148*/ Graph graph_; - std::shared_ptr task_; - TBEKernelStore tbe_kernal_store_; - CustAICPUKernelStore cust_aicpu_kernal_store_; - Buffer weights_buffer_; + std::shared_ptr task_; /*lint !e148*/ + TBEKernelStore tbe_kernal_store_; /*lint !e148*/ + CustAICPUKernelStore cust_aicpu_kernal_store_; /*lint !e148*/ + Buffer weights_buffer_; /*lint !e148*/ std::string name_; uint32_t version_ = {0}; diff --git a/ge/model/ge_root_model.cc b/ge/common/model/ge_root_model.cc similarity index 95% rename from ge/model/ge_root_model.cc rename to ge/common/model/ge_root_model.cc index 68f868dd..3fe10991 100644 --- a/ge/model/ge_root_model.cc +++ b/ge/common/model/ge_root_model.cc @@ -14,8 +14,9 @@ * limitations under the License. 
*/ -#include "ge_root_model.h" +#include "common/model/ge_root_model.h" #include "graph/debug/ge_attr_define.h" + namespace ge { void GeRootModel::SetSubgraphInstanceNameToModel(string instance_name, GeModelPtr ge_model) { subgraph_instance_name_to_model_.insert(std::pair(instance_name, ge_model)); diff --git a/ge/model/ge_root_model.h b/ge/common/model/ge_root_model.h similarity index 98% rename from ge/model/ge_root_model.h rename to ge/common/model/ge_root_model.h index 9e8e116e..e9ba3da6 100755 --- a/ge/model/ge_root_model.h +++ b/ge/common/model/ge_root_model.h @@ -15,7 +15,7 @@ */ #include #include "graph/compute_graph.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #ifndef GE_MODEL_GE_ROOT_MODEL_H_ #define GE_MODEL_GE_ROOT_MODEL_H_ diff --git a/ge/common/model_parser/model_parser.cc b/ge/common/model_parser/model_parser.cc index ef9ab9e6..5d1869be 100644 --- a/ge/common/model_parser/model_parser.cc +++ b/ge/common/model_parser/model_parser.cc @@ -20,15 +20,13 @@ #include #include "securec.h" -#include "common/helper/model_helper.h" +#include "framework/common/helper/model_helper.h" namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::ModelParserBase() {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::~ModelParserBase() {} +ModelParserBase::ModelParserBase() {} +ModelParserBase::~ModelParserBase() {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFromFile(const char *model_path, - int32_t priority, - ge::ModelData &model_data) { +Status ModelParserBase::LoadFromFile(const char *model_path, int32_t priority, ge::ModelData &model_data) { std::string real_path = RealPath(model_path); if (real_path.empty()) { GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, "[Check][Param]Model file path %s is invalid", @@ -81,9 +79,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::ParseModelContent(const ge::ModelData &model, - uint8_t *&model_data, - uint32_t &model_len) { +Status ModelParserBase::ParseModelContent(const ge::ModelData &model, uint8_t *&model_data, uint32_t &model_len) { // Parameter validity check GE_CHECK_NOTNULL(model.model_data); diff --git a/ge/common/model_saver.cc b/ge/common/model_saver.cc index 24e837f7..56045030 100755 --- a/ge/common/model_saver.cc +++ b/ge/common/model_saver.cc @@ -29,8 +29,7 @@ namespace ge { const uint32_t kInteval = 2; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path, - const Json &model) { +Status ModelSaver::SaveJsonToFile(const char *file_path, const Json &model) { Status ret = SUCCESS; if (file_path == nullptr || SUCCESS != CheckPath(file_path)) { GELOGE(FAILED, "[Check][OutputFile]Failed, file %s", file_path); diff --git a/ge/graph/common/omg_util.cc b/ge/common/omg_util.cc similarity index 92% rename from ge/graph/common/omg_util.cc rename to ge/common/omg_util.cc index 52e6cb9c..31e4270a 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/common/omg_util.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" @@ -59,8 +59,8 @@ Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) { if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_STREAM_LABEL, label)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(), node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(), node->GetName().c_str(), + node->GetType().c_str()); return FAILED; } @@ -100,8 +100,8 @@ Status SetActiveLabelList(const ge::NodePtr &node, const std::vectorGetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ACTIVE_LABEL_LIST.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ACTIVE_LABEL_LIST.c_str(), node->GetName().c_str(), + node->GetType().c_str()); return FAILED; } @@ -163,8 +163,8 @@ Status SetOriginalNodeName(const ge::NodePtr &node, const std::string &orig_name if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_ORIG_NODE_NAME, orig_name)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(), node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(), node->GetName().c_str(), + node->GetType().c_str()); return FAILED; } @@ -207,8 +207,8 @@ Status SetNextIteration(const NodePtr &node, const NodePtr &next) { if (!AttrUtils::SetStr(op_desc, ATTR_NAME_NEXT_ITERATION, name)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), op_desc->GetName().c_str(), + op_desc->GetType().c_str()); return FAILED; } return SUCCESS; @@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { } /// -/// @brief Set Op _force_unknown_shape flag -/// @param [in] node -/// @param [in] force_unknown, set attribute if true -/// @param [in] group_index, condition group index of node. -/// @return -/// -void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { - if (!force_unknown) { - return; - } - - SetControlFlowGroup(node, group_index); -} - -/// /// @brief Set Op _control_flow_group flag /// @param [in] node /// @param [in] group, condition group index of node. 
@@ -305,8 +290,8 @@ void SetControlFlowGroup(const NodePtr &node, int64_t group) { if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), node->GetName().c_str(), + node->GetType().c_str()); } } } // namespace ge diff --git a/ge/graph/common/omg_util.h b/ge/common/omg_util.h similarity index 86% rename from ge/graph/common/omg_util.h rename to ge/common/omg_util.h index 148e4102..83057dfb 100644 --- a/ge/graph/common/omg_util.h +++ b/ge/common/omg_util.h @@ -22,16 +22,15 @@ #include #include -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/node.h" namespace ge { -namespace { -const int64_t kBufferPoolMemAlignSize = 512; -const uint32_t kBufferPoolNodeOutIndex = 0; -const uint32_t kEventReuseThreshold = 65500; -} // namespace +static constexpr int64_t kBufferPoolMemAlignSize = 512; +static constexpr uint32_t kBufferPoolNodeOutIndex = 0; +static constexpr uint32_t kEventReuseThreshold = 65500; + /// /// @brief get the Original Type of FrameworkOp /// @param [in] node @@ -126,15 +125,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); /// -/// @brief Set Op _force_unknown_shape flag -/// @param [in] node -/// @param [in] force_unknown, set attribute if true -/// @param [in] group_index, condition group index of node. -/// @return -/// -void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); - -/// /// @brief Set Op _control_flow_group flag /// @param [in] node /// @param [in] group, condition group index of node. 
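A second point in the omg_util.h hunk above: the buffer-pool constants move out of an unnamed namespace into static constexpr definitions at ge namespace scope. An unnamed namespace in a widely included header gives every translation unit its own uniquely named copy of those constants, a pattern style checkers commonly flag; static constexpr keeps internal linkage and guarantees the values can be used in constant expressions. A before/after illustration with one constant from the hunk (since C++17, inline constexpr would be another option, though not the one this patch picks):

// Removed form: unnamed namespace inside a shared header; each
// translation unit that includes it gets its own anonymous copy.
namespace {
const int64_t kBufferPoolMemAlignSize = 512;
}  // namespace

// Introduced form: internal linkage, usable in constant expressions,
// no per-TU anonymous namespace.
static constexpr int64_t kBufferPoolMemAlignSize = 512;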
diff --git a/ge/common/op/attr_value_util.cc b/ge/common/op/attr_value_util.cc index 4315a25d..fd5b842a 100644 --- a/ge/common/op/attr_value_util.cc +++ b/ge/common/op/attr_value_util.cc @@ -17,7 +17,7 @@ #include "framework/common/op/attr_value_util.h" #include "framework/common/debug/log.h" #include "framework/common/util.h" -#include "register/register_types.h" +#include "external/register/register_types.h" namespace ge { #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ @@ -77,37 +77,33 @@ DEFINE_SET_ATTR_VALUE_LIST(const std::string &, s); } \ } while (0); -#define DEFINE_ADD_ATTR_VALUE(KEY_TYPE, VALUE_TYPE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, OpDef *op_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ - auto attr = op_def->mutable_attr(); \ - ADD_TO_ATTR_MAP(map_key, value, attr) \ - } \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, \ - AttrDefMap *attr_map) { \ - ADD_TO_ATTR_MAP(map_key, value, attr_map) \ - } \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(KEY_TYPE map_key, VALUE_TYPE value, \ - ModelDef *model_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ - auto attr = model_def->mutable_attr(); \ - ADD_TO_ATTR_MAP(map_key, value, attr) \ +#define DEFINE_ADD_ATTR_VALUE(KEY_TYPE, VALUE_TYPE) \ + void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, OpDef *op_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ + auto attr = op_def->mutable_attr(); \ + ADD_TO_ATTR_MAP(map_key, value, attr) \ + } \ + void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, AttrDefMap *attr_map) { \ + ADD_TO_ATTR_MAP(map_key, value, attr_map) \ + } \ + void AddModelAttr(KEY_TYPE map_key, VALUE_TYPE value, ModelDef *model_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ + auto attr = model_def->mutable_attr(); \ + ADD_TO_ATTR_MAP(map_key, value, attr) \ } -#define DEFINE_ADD_ATTR_VALUE_LIST(KEY_TYPE, VALUE_TYPE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, \ - OpDef *op_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ - auto attr = op_def->mutable_attr(); \ - ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ - } \ - FMK_FUNC_DEV_VISIBILITY void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, AttrDefMap *attr_map) { \ - ADD_TO_ATTR_MAP_LIST(map_key, value, attr_map) \ - } \ - FMK_FUNC_DEV_VISIBILITY void AddModelAttrList(KEY_TYPE map_key, VALUE_TYPE value, ModelDef *model_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ - auto attr = model_def->mutable_attr(); \ - ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ +#define DEFINE_ADD_ATTR_VALUE_LIST(KEY_TYPE, VALUE_TYPE) \ + void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, OpDef *op_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ + auto attr = op_def->mutable_attr(); \ + ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ + } \ + void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, AttrDefMap *attr_map) { \ + ADD_TO_ATTR_MAP_LIST(map_key, value, attr_map)} FMK_FUNC_DEV_VISIBILITY void \ + AddModelAttrList(KEY_TYPE map_key, VALUE_TYPE value, ModelDef *model_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ + auto attr = model_def->mutable_attr(); \ + ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ } DEFINE_ADD_ATTR_VALUE(const std::string &, const std::string &); @@ -127,46 +123,42 @@ DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const bool); DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const int64_t); DEFINE_ADD_ATTR_VALUE_LIST(const 
std::string &, const std::string &); -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(const std::string &map_key, AttrDef &attr, - OpDef *op_def) { +void AddOpAttr(const std::string &map_key, AttrDef &attr, OpDef *op_def) { GE_CHECK_NOTNULL_JUST_RETURN(op_def); GE_CHECK_NOTNULL_JUST_RETURN(op_def->mutable_attr()); (void)op_def->mutable_attr()->insert(AttrDefPair(map_key, attr)); } -#define DEFINE_GET_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, \ - const AttrDefMap &attr) { \ - auto it = attr.find(map_key); \ - if (it != attr.end()) { \ - *value = it->second.FIELD(); \ - return true; \ - } \ - return false; \ +#define DEFINE_GET_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ + bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, const AttrDefMap &attr) { \ + auto it = attr.find(map_key); \ + if (it != attr.end()) { \ + *value = it->second.FIELD(); \ + return true; \ + } \ + return false; \ } -#define DEFINE_GET_ATTR_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE *&value, \ - AttrDefMap *attr) { \ - GE_RT_FALSE_CHECK_NOTNULL(attr); \ - auto it = attr->find(map_key); \ - if (it != attr->end()) { \ - value = it->second.mutable_##FIELD(); \ - return true; \ - } \ - return false; \ +#define DEFINE_GET_ATTR_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ + bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE *&value, AttrDefMap *attr) { \ + GE_RT_FALSE_CHECK_NOTNULL(attr); \ + auto it = attr->find(map_key); \ + if (it != attr->end()) { \ + value = it->second.mutable_##FIELD(); \ + return true; \ + } \ + return false; \ } -#define DEFINE_GET_ATTR_CONST_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue( \ - ARG_TYPE_KEY map_key, const ARG_TYPE_VALUE *&value, const AttrDefMap &attr) { \ - auto it = attr.find(map_key); \ - if (it == attr.end()) { \ - return false; \ - } \ - \ - value = &(it->second.FIELD()); \ - return true; \ +#define DEFINE_GET_ATTR_CONST_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ + bool GetAttrDefValue(ARG_TYPE_KEY map_key, const ARG_TYPE_VALUE *&value, const AttrDefMap &attr) { \ + auto it = attr.find(map_key); \ + if (it == attr.end()) { \ + return false; \ + } \ + \ + value = &(it->second.FIELD()); \ + return true; \ } #define DEFINE_GET_BYTES_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ @@ -216,16 +208,14 @@ DEFINE_GET_ATTR_CONST_POINT_REF(const std::string &, NamedAttrs, func); DEFINE_GET_BYTES_ATTR_VALUE(const std::string &, std::string *); -#define DEFINE_GET_OP_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetOpAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, \ - const OpDef *op_def) { \ - GE_RT_FALSE_CHECK_NOTNULL(op_def); \ - return GetAttrDefValue(map_key, value, op_def->attr()); \ - } \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetModelAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, \ - const ModelDef *model_def) { \ - GE_RT_FALSE_CHECK_NOTNULL(model_def); \ - return GetAttrDefValue(map_key, value, model_def->attr()); \ +#define DEFINE_GET_OP_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ + bool GetOpAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, const OpDef *op_def) { \ + GE_RT_FALSE_CHECK_NOTNULL(op_def); \ + return GetAttrDefValue(map_key, value, op_def->attr()); \ + } \ + 
bool GetModelAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, const ModelDef *model_def) { \ + GE_RT_FALSE_CHECK_NOTNULL(model_def); \ + return GetAttrDefValue(map_key, value, model_def->attr()); \ } DEFINE_GET_OP_ATTR(const std::string &, std::string *); @@ -238,8 +228,7 @@ DEFINE_GET_OP_ATTR(const std::string &, bool *); DEFINE_GET_OP_ATTR(const std::string &, AttrDef_ListValue *); #define DEFINE_GET_BT_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetBytesAttr(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, \ - const OpDef *op_def) { \ + bool GetBytesAttr(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, const OpDef *op_def) { \ GE_RT_FALSE_CHECK_NOTNULL(op_def); \ return GetBytesValue(key, value, op_def->attr()); \ } \ @@ -250,7 +239,7 @@ DEFINE_GET_OP_ATTR(const std::string &, AttrDef_ListValue *); DEFINE_GET_BT_ATTR(const std::string &, std::string *); -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool HasOpAttr(const OpDef *op_def, const std::string &attr_name) { +bool HasOpAttr(const OpDef *op_def, const std::string &attr_name) { if (op_def == nullptr) { return false; } @@ -263,8 +252,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool HasOpAttr(const OpDef *op_ return false; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(const std::string &map_key, const void *value, - size_t size, ModelDef *model_def) { +void AddModelAttr(const std::string &map_key, const void *value, size_t size, ModelDef *model_def) { if (model_def == nullptr) { return; } @@ -280,8 +268,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(const std::st } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpBytesAttr(const std::string &key, const void *value, - size_t size, OpDef *op_def) { +void AddOpBytesAttr(const std::string &key, const void *value, size_t size, OpDef *op_def) { if (op_def == nullptr) { return; } diff --git a/ge/common/op/ge_op_utils.cc b/ge/common/op/ge_op_utils.cc index 99b5733c..429ce909 100644 --- a/ge/common/op/ge_op_utils.cc +++ b/ge/common/op/ge_op_utils.cc @@ -115,8 +115,7 @@ const int NORMAL_TENSOR_SIZE = 4; #define AIPP_CONVERT_LIST_FLOAT(KEY, REQUIRED) AIPP_CONVERT_LIST_FORMAT(KEY, float, REQUIRED, GeAttrValue::FLOAT) -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -OpUtils::ConvertAippParams(const GeAttrValue::NAMED_ATTRS &aipp_attr, domi::AippOpParams *aipp_params) { +Status OpUtils::ConvertAippParams(const GeAttrValue::NAMED_ATTRS &aipp_attr, domi::AippOpParams *aipp_params) { GE_CHECK_NOTNULL(aipp_params); AIPP_CONVERT_FORMAT_EX(aipp_mode, domi::AippOpParams::AippMode, int32_t, GeAttrValue::INT); AIPP_CONVERT_INT(related_input_rank); @@ -178,8 +177,7 @@ OpUtils::ConvertAippParams(const GeAttrValue::NAMED_ATTRS &aipp_attr, domi::Aipp return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::TransferDim(const std::vector<int64_t> &dim, - std::vector<int64_t> &dim_vector) { +Status OpUtils::TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector) { size_t input_shape_size = dim.size(); std::list<int64_t> new_dim_list; for (auto dim_temp : dim) { @@ -301,9 +299,9 @@ Status OpUtils::SetOutputSliceDataByDataType(void *data, int64_t data_size, cons return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetOutputSliceData( - void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, std::vector<int64_t> &begin, - std::vector<int64_t> &output_dims, GeTensor *output, std::vector<int64_t> &stride) { +Status OpUtils::SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, + std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, GeTensor *output, + std::vector<int64_t> &stride) { if (data == nullptr || output == nullptr) { GELOGE(PARAM_INVALID, "[Check][Param]Input param is nullptr"); REPORT_INNER_ERROR("E19999", "Input param is nullptr"); @@ -352,9 +350,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetOutputSliceD return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataHWCK2KCHW(const void *input, int64_t h, - int64_t w, int64_t c, int64_t k, - void **output) { +void OpUtils::TransDataHWCK2KCHW(const void *input, int64_t h, int64_t w, int64_t c, int64_t k, void **output) { if (input == nullptr) { return; } @@ -386,9 +382,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataHWCK2KCH *output = buf; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataKCHW2HWCK(const void *input, int64_t k, - int64_t c, int64_t h, int64_t w, - void *output) { +void OpUtils::TransDataKCHW2HWCK(const void *input, int64_t k, int64_t c, int64_t h, int64_t w, void *output) { if ((input == nullptr) || (output == nullptr)) { GELOGD("%s[%d]: input param is nullptr.", __FILE__, __LINE__); return; @@ -417,31 +411,22 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataKCHW2HWC vector<ConstGeTensorPtr> OpUtils::GetWeights(const ge::Node &node) { return OpDescUtils::GetWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector<ConstGeTensorPtr> OpUtils::GetWeights(ge::ConstNodePtr node) { - return OpDescUtils::GetWeights(node); -} +vector<ConstGeTensorPtr> OpUtils::GetWeights(ge::ConstNodePtr node) { return OpDescUtils::GetWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector<GeTensorPtr> OpUtils::MutableWeights(const ge::Node &node) { - return OpDescUtils::MutableWeights(node); -} +vector<GeTensorPtr> OpUtils::MutableWeights(const ge::Node &node) { return OpDescUtils::MutableWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector<GeTensorPtr> OpUtils::MutableWeights(const ge::NodePtr node) { - return OpDescUtils::MutableWeights(node); -} +vector<GeTensorPtr> OpUtils::MutableWeights(const ge::NodePtr node) { return OpDescUtils::MutableWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetWeights(ge::Node &node, - const vector<GeTensorPtr> &weights) { +Status OpUtils::SetWeights(ge::Node &node, const vector<GeTensorPtr> &weights) { return OpDescUtils::SetWeights(node, weights); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetWeights(ge::NodePtr node, - const vector<GeTensorPtr> &weights) { +Status OpUtils::SetWeights(ge::NodePtr node, const vector<GeTensorPtr> &weights) { return OpDescUtils::SetWeights(node, weights); } // The caller guarantees that the input tensor is constant -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector<int64_t> &dims) { +Status OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector<int64_t> &dims) { if (tensor == nullptr) { GELOGE(PARAM_INVALID, "[Check][Param]Input tensor is nullptr"); REPORT_INNER_ERROR("E19999", "Input tensor is nullptr"); diff --git a/ge/common/profiling/ge_profiling.cc b/ge/common/profiling/ge_profiling.cc index d0343326..a5857b35 100644 --- a/ge/common/profiling/ge_profiling.cc +++ b/ge/common/profiling/ge_profiling.cc @@ -14,14 +14,17 @@ * limitations under the License.
*/ -#include "common/profiling/ge_profiling.h" +#include "framework/common/profiling/ge_profiling.h" #include "runtime/base.h" #include "common/profiling/profiling_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "graph/load/graph_loader.h" +#include "graph/ge_context.h" #include "init/gelib.h" #include "framework/common/ge_inner_error_codes.h" +#include "common/model/ge_model.h" +#include "framework/omg/omg_inner_types.h" namespace { const uint32_t kDeviceListIndex = 3; @@ -34,6 +37,7 @@ const std::string kProfilingStop = "prof_stop"; const std::string kProfModelSubscribe = "prof_model_subscribe"; const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; const std::string kRtSetDeviceRegName = "profiling"; +const std::string kProfilingModelId = "modelId"; const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = { {kProfCommandhandleInit, kProfilingInit}, @@ -42,6 +46,26 @@ const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = { {kProfCommandhandleFinalize, kProfilingFinalize}, {kProfCommandhandleModelSubscribe, kProfModelSubscribe}, {kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}}; + +const uint64_t kModelId = ge::INVALID_MODEL_ID; +const uint16_t kStepStart = 0; +const uint16_t kStepEnd = 1; + +ge::Status NeedUnsubscribe(ProfCommandHandleType type, bool is_subscribe, + uint32_t graph_id, std::vector<std::string> &prof_params) { + if (type == kProfCommandhandleModelUnsubscribe && is_subscribe) { + prof_params.clear(); + prof_params.emplace_back(kProfilingModelId); + uint32_t model_id = 0; + auto ret = ge::ProfilingManager::Instance().GetModelIdFromGraph(graph_id, model_id); + if (ret != ge::SUCCESS) { + GELOGE(ret, "graph_id:%u not found", graph_id); + return ret; + } + prof_params.emplace_back(std::to_string(model_id)); + } + return ge::SUCCESS; +} } // namespace bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<std::string> &prof_config_params) { @@ -190,6 +214,24 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le return ge::PARAM_INVALID; } } + auto &profiling_manager = ge::ProfilingManager::Instance(); + auto is_train = domi::GetContext().train_flag; + if (type == kProfCommandhandleModelSubscribe && is_train) { + profiling_manager.SetSubscribeInfo(prof_config_param->profSwitch, prof_config_param->modelId, true); + return ge::SUCCESS; + } + auto is_subscribe = profiling_manager.GetSubscribeInfo().is_subscribe; + // GraphId is actually stored in prof_config_param + auto graph_id = prof_config_param->modelId; + ge::Status ret = NeedUnsubscribe(type, is_subscribe, graph_id, prof_params); + if (ret != ge::SUCCESS) { + GELOGE(ret, "graph_id:%u not found", graph_id); + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"value", "parameter", "reason"}), + std::vector<std::string>({std::to_string(graph_id), + "GraphToModelMap", + "graph_id does not exist!"})); + return ge::FAILED; + } ge::GraphLoader graph_loader; ge::Command command; command.cmd_params.clear(); @@ -203,7 +245,7 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str()); } - ge::Status ret = graph_loader.CommandHandle(command); + ret = graph_loader.CommandHandle(command); if (ret != ge::SUCCESS) { GELOGE(ret, "[Handle][Command]Handle profiling command failed, command type %s, error_code %u", iter->second.c_str(), ret); @@ -216,6 +258,34 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le return ge::SUCCESS; } -GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) { - return ge::SUCCESS; +ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) { + static bool is_first_run = true; + int32_t device_id = 0; + rtError_t rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "[Get][LogicDeviceId]Failed, ret 0x%X", rt_ret); + REPORT_CALL_ERROR("E19999", "Get logic device id failed, ret 0x%X", rt_ret); + return ge::FAILED; + } + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.SetStepInfoIndex(index_id); + if (is_first_run && tag_id == kStepStart) { + GE_CHK_STATUS_RET_NOLOG(profiling_manager.ProfileStepInfo(index_id, kModelId, tag_id, stream, device_id)); + is_first_run = false; + return ge::SUCCESS; + } + if (!is_first_run && tag_id == kStepEnd) { + GE_CHK_STATUS_RET_NOLOG(profiling_manager.ProfileStepInfo(index_id, kModelId, tag_id, stream, device_id)); + is_first_run = true; + return ge::SUCCESS; + } + GELOGE(ge::FAILED, "Param tag_id:%u invalid when is_first_run is %d", tag_id, is_first_run); + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"value", "parameter", "reason"}), + std::vector<std::string>({std::to_string(tag_id), "tag_id", + "tag_id must be 0 on the first run and 1 on the second run"})); + return ge::FAILED; +} + +ge::Status ProfGetDeviceFormGraphId(uint32_t graph_id, uint32_t &device_id) { + return ge::ProfilingManager::Instance().GetDeviceIdFromGraph(graph_id, device_id); } diff --git a/ge/common/profiling/ge_runner_profiling.cc b/ge/common/profiling/ge_runner_profiling.cc index 067aafe3..f74ce384 100644 --- a/ge/common/profiling/ge_runner_profiling.cc +++ b/ge/common/profiling/ge_runner_profiling.cc @@ -14,7 +14,7 @@ * limitations under the License.
*/ -#include "common/profiling/ge_runner_profiling.h" +#include "framework/common/profiling/ge_runner_profiling.h" #include "init/gelib.h" bool IsInitialize() { diff --git a/ge/common/profiling/profiling_manager.cc b/ge/common/profiling/profiling_manager.cc index 61210de6..e8f41cc4 100644 --- a/ge/common/profiling/profiling_manager.cc +++ b/ge/common/profiling/profiling_manager.cc @@ -21,7 +21,7 @@ #include "framework/common/string_util.h" #include "graph/ge_context.h" #include "graph/utils/type_utils.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "runtime/base.h" #include "graph/load/model_manager/davinci_model.h" #include "mmpa/mmpa_api.h" @@ -66,19 +66,23 @@ const std::string kIdx = "idx"; namespace ge { ProfilingManager::ProfilingManager() - : is_load_profiling_(false), is_execute_profiling_(false), is_training_trace_(false), subscribe_count_(0) { - prof_cb_.msprofCtrlCallback = nullptr; - prof_cb_.msprofReporterCallback = nullptr; + : is_load_profiling_(false), + is_execute_profiling_(false), + is_training_trace_(false), + subscribe_count_(0), + prof_cb_({nullptr, nullptr}), + index_id_(UINT64_MAX), + subscribe_info_({false, 0, 0}) { } ProfilingManager::~ProfilingManager() {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager &ProfilingManager::Instance() { +ProfilingManager &ProfilingManager::Instance() { static ProfilingManager profiling_manager; return profiling_manager; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::Init(const Options &options) { +ge::Status ProfilingManager::Init(const Options &options) { #ifdef DAVINCI_SUPPORT_PROFILING vector().swap(device_id_); subscribe_count_ = 0; @@ -217,7 +221,7 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) { return ge::SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() { +void ProfilingManager::StopProfiling() { #ifdef DAVINCI_SUPPORT_PROFILING uint64_t module = GetProfilingModule(); // The following if case will not be executed in normal case, inc case of ProfStopProfiling is abnormal @@ -255,8 +259,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingOpInputOutInfo( - const TaskDescInfo &task, Json &task_json) { +void ProfilingManager::ProfilingOpInputOutInfo(const TaskDescInfo &task, Json &task_json) { #ifdef DAVINCI_SUPPORT_PROFILING for (size_t i = 0; i < task.input_format.size(); i++) { Json tmp_input; @@ -282,8 +285,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( - uint32_t model_id, const std::vector &task_desc_info, const int32_t &device_id) { +void ProfilingManager::ProfilingTaskDescInfo(uint32_t model_id, const std::vector &task_desc_info, + const int32_t &device_id) { #ifdef DAVINCI_SUPPORT_PROFILING for (const auto &task : task_desc_info) { Json task_info; @@ -320,8 +323,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfileStepInfo( - uint64_t index_id, uint64_t model_id, uint16_t tag_id, rtStream_t stream, int32_t device_id) { +Status ProfilingManager::ProfileStepInfo(uint64_t index_id, uint64_t model_id, uint16_t tag_id, rtStream_t stream, + int32_t device_id) { #ifdef 
DAVINCI_SUPPORT_PROFILING if (!is_load_profiling_ && subscribe_count_ == 0) { GELOGD("Profiling is not turned on, no need to profile step info."); @@ -381,8 +384,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::Profil return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportData( - const int32_t &device_id, const string &data, const string &tag_name) { +void ProfilingManager::ReportData(const int32_t &device_id, const string &data, const string &tag_name) { #ifdef DAVINCI_SUPPORT_PROFILING ReporterData reporter_data{}; int ret = -1; @@ -422,8 +424,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportDa #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( - uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info) { +void ProfilingManager::ReportProfilingData(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info) { #ifdef DAVINCI_SUPPORT_PROFILING int32_t logic_device_id = 0; rtError_t rt_ret = rtGetDevice(&logic_device_id); @@ -439,7 +440,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetProfilingModule() { +uint64_t ProfilingManager::GetProfilingModule() { uint64_t module = PROF_MODEL_EXECUTE_MASK | PROF_RUNTIME_API_MASK | PROF_RUNTIME_TRACE_MASK | @@ -481,8 +482,7 @@ void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, uin #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfModelSubscribe( - uint64_t module, void *model) { +Status ProfilingManager::ProfModelSubscribe(uint64_t module, void *model) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard<std::mutex> lock(mutex_); uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; @@ -522,8 +522,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfModelUnsubscribe( - void *model) { +Status ProfilingManager::ProfModelUnsubscribe(void *model) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard<std::mutex> lock(mutex_); if (subscribe_count_ == 0) { @@ -564,7 +563,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfInit(uint64_t module) { +Status ProfilingManager::ProfInit(uint64_t module) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard<std::mutex> lock(mutex_); uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; @@ -598,16 +597,19 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFinalize() { +Status ProfilingManager::ProfFinalize() { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard<std::mutex> lock(mutex_); is_load_profiling_ = false; is_training_trace_ = false; is_execute_profiling_ = false; + index_id_ = UINT64_MAX; // profiling plugin uninit PluginUnInit(); + CleanSubscribeInfo(); + int32_t dev_num = -1; rtError_t rt_ret = rtProfilerStop(PROF_MODEL_LOAD_MASK, dev_num, nullptr); if (rt_ret != RT_ERROR_NONE) { @@ -630,6 +632,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFi } device_id_module_map_.clear(); device_id_.clear(); + device_id_map_.clear(); + model_id_map_.clear(); GELOGI("Prof finalize success."); #endif return SUCCESS; } @@
-688,8 +692,8 @@ Status ProfilingManager::ProfParseDeviceId(const std::map<std::string, std::string> &config_para, -Status ProfilingManager::ProfParseParam(const std::map<std::string, std::string> &config_para, - int32_t &device_num, vector<int32_t> &device_list) { +Status ProfilingManager::ProfParseParam(const std::map<std::string, std::string> &config_para, int32_t &device_num, + vector<int32_t> &device_list) { #ifdef DAVINCI_SUPPORT_PROFILING // device num auto iter = config_para.find(kConfigNumsdev); @@ -738,8 +742,7 @@ Status ProfilingManager::ProfParseParam(const std::map<std::string, std::string> return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfStartProfiling( - uint64_t module, const std::map<std::string, std::string> &config_para) { +Status ProfilingManager::ProfStartProfiling(uint64_t module, const std::map<std::string, std::string> &config_para) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard<std::mutex> lock(mutex_); uint64_t training_trace_mask = module & PROF_TRAINING_TRACE_MASK; @@ -794,8 +797,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfStopProfiling(uint64_t module, - const std::map<std::string, std::string> &config_para) { +Status ProfilingManager::ProfStopProfiling(uint64_t module, const std::map<std::string, std::string> &config_para) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard<std::mutex> lock(mutex_); int32_t device_num = 0; @@ -846,8 +848,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::UpdateDeviceIdModuleMap(string prof_type, - uint64_t module, const vector<int32_t> &device_list) { +void ProfilingManager::UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector<int32_t> &device_list) { #ifdef DAVINCI_SUPPORT_PROFILING if (prof_type == kProfStart) { for (uint32_t i = 0; i < device_list.size(); i++) { @@ -877,7 +878,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::UpdateDe #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ProfilingManager::ProfilingModelExecuteOn() const { +bool ProfilingManager::ProfilingModelExecuteOn() const { int32_t logic_device_id = 0; rtError_t rt_ret = rtGetDevice(&logic_device_id); if (rt_ret != RT_ERROR_NONE) { @@ -895,7 +896,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ProfilingManager::Profilin return execute_model_prof_on; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::PluginInit() { +Status ProfilingManager::PluginInit() { if (prof_cb_.msprofReporterCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "[Check][Param]MsprofReporterCallback callback is nullptr"); REPORT_INNER_ERROR("E19999", "MsprofReporterCallback callback is nullptr"); @@ -924,7 +925,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::Plugin return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit() const { +void ProfilingManager::PluginUnInit() const { #ifdef DAVINCI_SUPPORT_PROFILING if (prof_cb_.msprofReporterCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "[Check][Param]MsprofReporterCallback callback is nullptr"); @@ -941,8 +942,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUn #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::CallMsprofReport( - ReporterData &reporter_data) const { +Status ProfilingManager::CallMsprofReport(ReporterData &reporter_data) const { if (prof_cb_.msprofReporterCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "[Check][Param]MsprofReporterCallback callback is nullptr"); REPORT_INNER_ERROR("E19999",
"MsprofReporterCallback callback is nullptr"); @@ -998,14 +998,12 @@ void ProfilingManager::GetOpOutputInfo(const OpDescPtr &op, TaskDescInfo &task_d task_desc_info.output_data_type = output_data_type.empty() ? data_type_default : output_data_type; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInputOutputInfo( - const OpDescPtr &op, TaskDescInfo &task_desc_info) const { +void ProfilingManager::GetOpInputOutputInfo(const OpDescPtr &op, TaskDescInfo &task_desc_info) const { GetOpInputInfo(op, task_desc_info); GetOpOutputInfo(op, task_desc_info); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpPoint( - std::string &fp_point, std::string &bp_point) { +void ProfilingManager::GetFpBpPoint(std::string &fp_point, std::string &bp_point) { // Env or options mode, fp_point_/bp_point_ have initiliazed on profiling init if (!fp_point_.empty() && !bp_point_.empty()) { fp_point = fp_point_; @@ -1016,7 +1014,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP } // ProfApi mode and training trace is set // Parse options first - char env_profiling_options[MSPROF_OPTIONS_DEF_LEN_MAX] = { 0x00 }; + char env_profiling_options[MSPROF_OPTIONS_DEF_LEN_MAX] = {0x00}; bool is_profiling_valid = false; std::string profiling_options; if (ge::GetContext().GetOption(OPTION_EXEC_PROFILING_OPTIONS, profiling_options) == SUCCESS && @@ -1055,4 +1053,40 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP return; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::GetDeviceIdFromGraph( + uint32_t graph_id, uint32_t &device_id) { + auto iter = device_id_map_.find(graph_id); + if (iter != device_id_map_.end()) { + device_id = iter->second; + return SUCCESS; + } + REPORT_CALL_ERROR("E19999", "graph_id:%u does not exist!", graph_id); + GELOGE(PARAM_INVALID, "[Check][GraphId]graph_id:%u does not exist!", graph_id); + return FAILED; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::SetSubscribeInfo( + uint64_t prof_switch, uint32_t model_id, bool is_subscribe) { + subscribe_info_.is_subscribe = is_subscribe; + subscribe_info_.prof_switch = prof_switch; + subscribe_info_.graph_id = model_id; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::CleanSubscribeInfo() { + subscribe_info_.is_subscribe = false; + subscribe_info_.prof_switch = 0; + subscribe_info_.graph_id = 0; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::GetModelIdFromGraph( + uint32_t graph_id, uint32_t &model_id) { + auto iter = model_id_map_.find(graph_id); + if (iter != model_id_map_.end()) { + model_id = iter->second; + return SUCCESS; + } + REPORT_CALL_ERROR("E19999", "graph_id:%u does not exist!", graph_id); + GELOGE(PARAM_INVALID, "[Check][GraphId]graph_id:%u does not exist!", graph_id); + return FAILED; +} } // namespace ge diff --git a/ge/common/profiling/profiling_manager.h b/ge/common/profiling/profiling_manager.h index 049a4df4..86371d51 100755 --- a/ge/common/profiling/profiling_manager.h +++ b/ge/common/profiling/profiling_manager.h @@ -62,12 +62,18 @@ struct DeviceSubsInfo { uint32_t subscribe_count; }; +struct ProfSubscribeInfo { + bool is_subscribe; + uint64_t prof_switch; + uint32_t graph_id; +}; + struct MsprofCallback { MsprofCtrlCallback msprofCtrlCallback; MsprofReporterCallback msprofReporterCallback; }; -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { +class ProfilingManager { 
public: ProfilingManager(); virtual ~ProfilingManager(); @@ -101,6 +107,16 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { void GetOpInputOutputInfo(const OpDescPtr &op, TaskDescInfo &task_desc_info) const; void ReportData(const int32_t &device_id, const std::string &data, const std::string &tag_name); Status ProfileStepInfo(uint64_t index_id, uint64_t model_id, uint16_t tag_id, rtStream_t stream, int32_t device_id); + void SetStepInfoIndex(uint64_t index_id) { index_id_ = index_id; } + uint64_t GetStepInfoIndex() const { return index_id_; } + void SetGraphIdToDeviceMap(uint32_t graph_id, uint32_t device_id) { device_id_map_[graph_id] = device_id; } + Status GetDeviceIdFromGraph(uint32_t graph_id, uint32_t &device_id); + void SetSubscribeInfo(uint64_t prof_switch, uint32_t model_id, bool is_subscribe); + const ProfSubscribeInfo &GetSubscribeInfo() const { return subscribe_info_; } + void CleanSubscribeInfo(); + void SetGraphIdToModelMap(uint32_t graph_id, uint32_t model_id) { model_id_map_[graph_id] = model_id; } + Status GetModelIdFromGraph(uint32_t graph_id, uint32_t &model_id); + private: Status InitFromOptions(const Options &options, MsprofGeOptions &prof_conf); Status ParseOptions(const std::string &options); @@ -127,6 +143,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { std::string fp_point_; std::string bp_point_; uint32_t reporter_max_len_ = 0; + uint64_t index_id_; + std::map<uint32_t, uint32_t> device_id_map_; // key: graph_id, value: device_id + std::map<uint32_t, uint32_t> model_id_map_; // key: graph_id, value: model_id + ProfSubscribeInfo subscribe_info_; }; } // namespace ge #endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_ diff --git a/ge/common/properties_manager.cc b/ge/common/properties_manager.cc index e1f4c66e..aeabb008 100644 --- a/ge/common/properties_manager.cc +++ b/ge/common/properties_manager.cc @@ -21,7 +21,7 @@ #include <fstream> #include "common/ge/ge_util.h" -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/ge_types.h" @@ -35,13 +35,13 @@ PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {} PropertiesManager::~PropertiesManager() {} // singleton -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PropertiesManager &PropertiesManager::Instance() { +PropertiesManager &PropertiesManager::Instance() { static PropertiesManager instance; return instance; } // Initialize property configuration -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::Init(const std::string &file_path) { +bool PropertiesManager::Init(const std::string &file_path) { std::lock_guard<std::mutex> lock(mutex_); if (is_inited_) { GELOGW("Already inited, will be initialized again"); @@ -139,8 +139,7 @@ std::string PropertiesManager::Trim(const std::string &str) { } // Get property value, if not found, return "" -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetPropertyValue( - const std::string &map_key) { +std::string PropertiesManager::GetPropertyValue(const std::string &map_key) { std::lock_guard<std::mutex> lock(mutex_); auto iter = properties_map_.find(map_key); if (properties_map_.end() != iter) { @@ -151,21 +150,19 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager:: } // Set property value -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyValue(const std::string &map_key, - const std::string &value) { +void
PropertiesManager::SetPropertyValue(const std::string &map_key, const std::string &value) { std::lock_guard<std::mutex> lock(mutex_); properties_map_[map_key] = value; } // return properties_map_ -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> -PropertiesManager::GetPropertyMap() { +std::map<std::string, std::string> PropertiesManager::GetPropertyMap() { std::lock_guard<std::mutex> lock(mutex_); return properties_map_; } // Set separator -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyDelimiter(const std::string &de) { +void PropertiesManager::SetPropertyDelimiter(const std::string &de) { std::lock_guard<std::mutex> lock(mutex_); delimiter = de; } diff --git a/ge/common/properties_manager.h b/ge/common/properties_manager.h index 7079eecb..f3f9d7c9 100644 --- a/ge/common/properties_manager.h +++ b/ge/common/properties_manager.h @@ -25,7 +25,7 @@ #include "common/dump/dump_properties.h" #include "graph/op_desc.h" -#include "common/ge_compiler_options.h" +#include "framework/common/ge_compiler_options.h" namespace ge { // Configuration property management diff --git a/ge/common/proto/ge_ir.proto b/ge/common/proto/ge_ir.proto deleted file mode 100644 index c0ef3071..00000000 --- a/ge/common/proto/ge_ir.proto +++ /dev/null @@ -1,193 +0,0 @@ -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. - DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ - DT_VARIANT = 26; // variant type - DT_BF16 = 27; // bf16 type - DT_INT4 = 28; // int4 type -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph
type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs -{ - string name = 1; - map<string, AttrDef> attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor = 15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map<string, AttrDef> attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. op_name:index - - map<string, AttrDef> attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id = 22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map<string, AttrDef> attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto version - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition; graph[0] represents the main graph in modeldef - - map<string, AttrDef> attr = 11; // Extended field -} - diff --git a/ge/common/proto/insert_op.proto b/ge/common/proto/insert_op.proto deleted file mode 100644 index 7d708865..00000000 --- a/ge/common/proto/insert_op.proto +++ /dev/null @@ -1,140 +0,0 @@ -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPP mode, distinguishes static AIPP from dynamic AIPP - AippMode aipp_mode = 1; - - // related_input_rank is required; integer; value range: >= 0, <= the number of input Data ops; default 0. - //
It identifies which input of the model is processed by AIPP; e.g. if the model has two inputs and AIPP is needed on the 2nd input, set related_input_rank to 1. - uint32 related_input_rank = 2; - - // related_input_name is optional and the top name of data node which inserts aipp - string related_input_name = 6; - - // input_edge_idx is optional; integer; value range: >= 0. - // It applies different AIPP processing to different output edges of the Data op; if unset, AIPP is applied by default to all output edges of the input specified by related_input_rank. - // Configured values must be <= the number of output edges of the Data op. - repeated uint32 input_edge_idx = 3; - - // [Begin] dynamic AIPP parameters, ignored when static AIPP is configured - uint32 max_src_image_size = 4; - - // Whether rotation is supported; not supported by default. Enabling rotation incurs extra space and performance cost. - bool support_rotation = 5; - - // [End] dynamic AIPP parameters - - - // [Begin] static AIPP parameters, ignored when dynamic AIPP is configured - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - float padding_value = 72; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] static AIPP parameters - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; // dynamic batch - resolution = 1; // dynamic resolution, reserved for extension - } - - MultiShapeMode mode = 1; // op mode - uint32 related_input_rank = 2; // which input the new op is inserted at - - - repeated uint32 batch_list = 11; // batch_list values; the number of batch_list entries must be between 2 and 8 -} diff --git a/ge/common/proto/om.proto b/ge/common/proto/om.proto deleted file mode 100644 index e15e5f80..00000000 --- a/ge/common/proto/om.proto +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map<string, AttrDef> attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map<string, AttrDef> attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; 
bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4; // optional, [default = 1e-5] - bool use_global_stats = 5; // optional, by default true, testing mode - float moving_average_fraction = 6; // optional, [default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams { - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // By default, we will use NPU.
- CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load direction 0: col load 1: row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map<string, AttrDef> attr = 2; -} - diff --git a/ge/common/proto/op_mapping.proto b/ge/common/proto/op_mapping.proto deleted file mode 100644 index d626eb49..00000000 --- a/ge/common/proto/op_mapping.proto +++ /dev/null @@ -1,75 +0,0 @@ -syntax = "proto3"; -package toolkit.aicpu.dump; - -message Shape { - repeated uint64 dim = 1; -} - -message Output { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - string original_name = 5; - int32 original_output_index = 6; - int32 original_output_data_type = 7; - int32 original_output_format = 8; - uint64 size = 9; - Shape origin_shape = 10; -} - -message Input { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - uint64 size = 5; - Shape origin_shape = 6; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - uint64 address = 2; - uint64 size = 3; -} - -message Op { - string op_name = 1; - string op_type = 2; -} - -message Task { - uint32 task_id = 1; - uint32 stream_id = 2; - Op op = 3; - repeated Output output = 4; - bool end_graph = 5; - repeated Input input = 6; - repeated OpBuffer buffer = 7; -} - -message OpMappingInfo { - string dump_path = 1; - oneof model_name_param { - string model_name = 2; - } - oneof model_id_param { - uint32 model_id = 3; - } - oneof step_id { - uint64 step_id_addr = 4; - } - oneof iterations_per_loop { - uint64 iterations_per_loop_addr = 5; - } - oneof loop_cond { - uint64 loop_cond_addr = 6; - } - uint32 flag = 7; // 0x01 load, 0x00 unload - repeated Task task = 8; - string dump_step = 9; -} \ No newline at end of file diff --git a/ge/common/proto/task.proto b/ge/common/proto/task.proto deleted file mode 100644 index 0da5631e..00000000 --- a/ge/common/proto/task.proto +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
- * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map<string, string> attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; - KernelDefWithHandle kernel_with_handle = 40; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelDefWithHandle { - KernelContext context = 1; - - uint64 handle = 10; - string dev_func = 11; - uint32 block_dim = 12; - uint32 args_size = 13; - bytes args = 14; - bytes sm_desc = 15; - string original_kernel_key = 16; - string node_info = 17; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 
op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/common/proto/tensorflow/attr_value.proto b/ge/common/proto/tensorflow/attr_value.proto deleted file mode 100644 index 438d7163..00000000 --- a/ge/common/proto/tensorflow/attr_value.proto +++ /dev/null @@ -1,70 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensor.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled. -message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - repeated NameAttrList func = 9; // "list(attr)" - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. - NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. 
-message NameAttrList { - string name = 1; - map<string, AttrValue> attr = 2; -} diff --git a/ge/common/proto/tensorflow/function.proto b/ge/common/proto/tensorflow/function.proto deleted file mode 100644 index 44681e32..00000000 --- a/ge/common/proto/tensorflow/function.proto +++ /dev/null @@ -1,108 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "node_def.proto"; -import "op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have binding for every attr defined in the signature. -// * device spec, etc. -message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map<string, AttrValue> attr = 5; - - // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. - reserved 2; - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list. - // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function.
Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above. - - // By convention, "op" in node_def is resolved by consulting with a - // user-defined library first. If not resolved, "func" is assumed to - // be a builtin op. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map<string, string> ret = 4; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must follow the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g', which is a -// function taking N + M inputs and produces N outputs. -// -// I.e. if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the -// loss function). dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/ge/common/proto/tensorflow/graph.proto b/ge/common/proto/tensorflow/graph.proto deleted file mode 100644 index 73bfc6ee..00000000 --- a/ge/common/proto/tensorflow/graph.proto +++ /dev/null @@ -1,64 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "node_def.proto"; -import "function.proto"; -import "versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. - VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3 [deprecated = true]; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... }} - // map<string, FunctionDefLibrary> named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime.
- // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready. The caller may want to use Tuple() mechanism to - // ensure all inputs are ready in the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. The - // consumer may want to use Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. - FunctionDefLibrary library = 2; -}; diff --git a/ge/common/proto/tensorflow/graph_library.proto b/ge/common/proto/tensorflow/graph_library.proto deleted file mode 100644 index 7bca0838..00000000 --- a/ge/common/proto/tensorflow/graph_library.proto +++ /dev/null @@ -1,22 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; - -import "graph.proto"; - -message GeGraphDef { - string name = 1; - GraphDef graph = 2; -} - -message GraphDefLibrary { - repeated GeGraphDef graph_def = 1; -}; \ No newline at end of file diff --git a/ge/common/proto/tensorflow/node_def.proto b/ge/common/proto/tensorflow/node_def.proto deleted file mode 100644 index 50cf5cac..00000000 --- a/ge/common/proto/tensorflow/node_def.proto +++ /dev/null @@ -1,71 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. - // Op names starting with an underscore are reserved for internal use. - string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed. 
- // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= PARTIAL_SPEC - // - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) - // * "/job:worker/device:GPU:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. - string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // Add some examples here showing best practices. - map<string, AttrValue> attr = 5; -}; diff --git a/ge/common/proto/tensorflow/op_def.proto b/ge/common/proto/tensorflow/op_def.proto deleted file mode 100644 index 7f0e8ce2..00000000 --- a/ge/common/proto/tensorflow/op_def.proto +++ /dev/null @@ -1,172 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "types.proto"; - -// Defines an operation. A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of an OpDef. -// LINT.IfChange -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human-readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. - // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)".
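(To make the legal ArgDef combinations above concrete, here is a minimal validation sketch under stated assumptions: protoc-generated classes, and DT_INVALID as the unset default of the `type` field. It is illustrative, not GraphEngine's actual validator.)

```
// Sketch only: check that an ArgDef uses exactly one of the legal
// type configurations documented above.
#include "proto/tensorflow/op_def.pb.h"  // hypothetical generated header path

bool IsValidArgDef(const domi::tensorflow::OpDef::ArgDef &arg) {
  const bool has_type = arg.type() != domi::tensorflow::DT_INVALID;
  const bool has_type_attr = !arg.type_attr().empty();
  const bool has_number_attr = !arg.number_attr().empty();
  const bool has_type_list = !arg.type_list_attr().empty();
  if (has_type_list) {
    // Sequence of possibly heterogeneous tensors: none of the other
    // type fields may be set.
    return !has_type && !has_type_attr && !has_number_attr;
  }
  // Single tensor, or (with number_attr) a same-typed sequence:
  // exactly one of "type" / "type_attr" must be set.
  return has_type != has_type_attr;
}
```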
- DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. - string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. - message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. - // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. - string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // Ops are marked as stateful if their behavior depends on some state beyond - // their input tensors (e.g. variable reading op) or if they have - // a side-effect (e.g. 
printing or asserting ops). Equivalently, stateless ops - // must always produce the same output for the same input and have - // no side-effects. - // - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. - - // By default, all inputs to an Op must be initialized Tensors. Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. - int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/ge/common/proto/tensorflow/resource_handle.proto b/ge/common/proto/tensorflow/resource_handle.proto deleted file mode 100644 index 91c46c9a..00000000 --- a/ge/common/proto/tensorflow/resource_handle.proto +++ /dev/null @@ -1,37 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandle"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandleProto { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. - string maybe_type_name = 5; -}; diff --git a/ge/common/proto/tensorflow/tensor.proto b/ge/common/proto/tensorflow/tensor.proto deleted file mode 100644 index 48eeb6c4..00000000 --- a/ge/common/proto/tensorflow/tensor.proto +++ /dev/null @@ -1,102 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). 
- * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "resource_handle.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a tensor. -message TensorProto { - DataType dtype = 1; - - // Shape of the tensor. - TensorShapeProto tensor_shape = 2; - - // Only one of the representations below is set, one of "tensor_contents" and - // the "xxx_val" attributes. We are not using oneof because as oneofs cannot - // contain repeated fields it would require another extra set of messages. - - // Version number. - // - // In version 0, if the "repeated xxx" representations contain only one - // element, that element is repeated to fill the shape. This makes it easy - // to represent a constant Tensor with a single value. - int32 version_number = 3; - - // Serialized raw tensor content from either Tensor::AsProtoTensorContent or - // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation - // can be used for all tensor types. The purpose of this representation is to - // reduce serialization overhead during RPC call by avoiding serialization of - // many repeated small items. - bytes tensor_content = 4; - - // Type specific representations that make it easy to create tensor protos in - // all languages. Only the representation corresponding to "dtype" can - // be set. The values hold the flattened representation of the tensor in - // row major order. - - // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll - // have some pointless zero padding for each value here. - repeated int32 half_val = 13 [packed = true]; - - // DT_FLOAT. - repeated float float_val = 5 [packed = true]; - - // DT_DOUBLE. - repeated double double_val = 6 [packed = true]; - - // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7 [packed = true]; - - // DT_STRING - repeated bytes string_val = 8; - - // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - repeated float scomplex_val = 9 [packed = true]; - - // DT_INT64 - repeated int64 int64_val = 10 [packed = true]; - - // DT_BOOL - repeated bool bool_val = 11 [packed = true]; - - // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real - // and imaginary parts of i-th double precision complex. - repeated double dcomplex_val = 12 [packed = true]; - - // DT_RESOURCE - repeated ResourceHandleProto resource_handle_val = 14; - - // DT_VARIANT - repeated VariantTensorDataProto variant_val = 15; - - // DT_UINT32 - repeated uint32 uint32_val = 16 [packed = true]; - - // DT_UINT64 - repeated uint64 uint64_val = 17 [packed = true]; -}; - -// Protocol buffer representing the serialization format of DT_VARIANT tensors. -message VariantTensorDataProto { - // Name of the type of objects being serialized. - string type_name = 1; - // Portions of the object that are not Tensors. - bytes metadata = 2; - // Tensors contained within objects being serialized. 
- repeated TensorProto tensors = 3; -} diff --git a/ge/common/proto/tensorflow/tensor_shape.proto b/ge/common/proto/tensorflow/tensor_shape.proto deleted file mode 100644 index 3a6d8c5a..00000000 --- a/ge/common/proto/tensorflow/tensor_shape.proto +++ /dev/null @@ -1,53 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -// Protocol buffer representing the shape of tensors. - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorShapeProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -package domi.tensorflow; - -// Dimensions of a tensor. -message TensorShapeProto { - // One dimension of the tensor. - message Dim { - // Size of the tensor in that dimension. - // This value must be >= -1, but values of -1 are reserved for "unknown" - // shapes (values of -1 mean "unknown" dimension). Certain wrappers - // that work with TensorShapeProto may fail at runtime when deserializing - // a TensorShapeProto containing a dim value of -1. - int64 size = 1; - - // Optional name of the tensor dimension. - string name = 2; - }; - - // Dimensions of the tensor, such as {"input", 30}, {"output", 40} - // for a 30 x 40 2D tensor. If an entry has size -1, this - // corresponds to a dimension of unknown size. The names are - // optional. - // - // The order of entries in "dim" matters: It indicates the layout of the - // values in the tensor in-memory representation. - // - // The first entry in "dim" is the outermost dimension used to layout the - // values, the last entry is the innermost dimension. This matches the - // in-memory layout of RowMajor Eigen tensors. - // - // If "dim.size()" > 0, "unknown_rank" must be false. - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/ge/common/proto/tensorflow/types.proto b/ge/common/proto/tensorflow/types.proto deleted file mode 100644 index f40e49cb..00000000 --- a/ge/common/proto/tensorflow/types.proto +++ /dev/null @@ -1,82 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. 
- DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. - DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - DT_VARIANT = 21; // Arbitrary C++ data types - DT_UINT32 = 22; - DT_UINT64 = 23; - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/c_api.h, -// https://www.tensorflow.org/code/tensorflow/go/tensor.go, -// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, -// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, -// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/ge/common/proto/tensorflow/versions.proto b/ge/common/proto/tensorflow/versions.proto deleted file mode 100644 index 4e81548f..00000000 --- a/ge/common/proto/tensorflow/versions.proto +++ /dev/null @@ -1,39 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). 
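(The admission rule spelled out in the VersionDef comment above reduces to three checks. A minimal sketch follows, assuming the protoc-generated VersionDef class and that `consumer` and `min_producer` are supplied by the consuming side, as the comment notes; the include path is hypothetical.)

```
// Sketch only: may this consumer read data carrying this VersionDef?
// Rule from the comment above: producer >= min_producer,
// consumer >= min_consumer, and consumer not in bad_consumers.
#include <algorithm>
#include <cstdint>
#include "proto/tensorflow/versions.pb.h"  // hypothetical generated header path

bool CanConsume(const domi::tensorflow::VersionDef &v,
                int32_t consumer, int32_t min_producer) {
  if (v.producer() < min_producer || consumer < v.min_consumer()) {
    return false;
  }
  const auto &bad = v.bad_consumers();
  return std::find(bad.begin(), bad.end(), consumer) == bad.end();
}
```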
- repeated int32 bad_consumers = 3; -}; diff --git a/ge/common/tbe_kernel_store.h b/ge/common/tbe_kernel_store.h index 6304af50..1492bdd9 100755 --- a/ge/common/tbe_kernel_store.h +++ b/ge/common/tbe_kernel_store.h @@ -21,7 +21,7 @@ namespace ge { -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEKernelStore : public KernelStore { +class TBEKernelStore : public KernelStore { public: TBEKernelStore(); ~TBEKernelStore() {} diff --git a/ge/common/thread_pool.cc b/ge/common/thread_pool.cc index dead0127..56f8ee60 100644 --- a/ge/common/thread_pool.cc +++ b/ge/common/thread_pool.cc @@ -23,10 +23,10 @@ #include #include -#include "register/register_types.h" +#include "external/register/register_types.h" namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { +ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { idle_thrd_num_ = size < 1 ? 1 : size; for (uint32_t i = 0; i < idle_thrd_num_; ++i) { @@ -34,7 +34,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() { +ThreadPool::~ThreadPool() { is_stoped_.store(true); { std::unique_lock<std::mutex> lock{m_lock_}; diff --git a/ge/common/thread_pool.h b/ge/common/thread_pool.h index e173618f..777a3c9b 100755 --- a/ge/common/thread_pool.h +++ b/ge/common/thread_pool.h @@ -31,13 +31,13 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "external/ge/ge_api_error_codes.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "common/ge/ge_util.h" namespace ge { using ThreadTask = std::function<void()>; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { +class ThreadPool { public: explicit ThreadPool(uint32_t size = 4); ~ThreadPool(); diff --git a/ge/graph/common/transop_util.cc b/ge/common/transop_util.cc similarity index 98% rename from ge/graph/common/transop_util.cc rename to ge/common/transop_util.cc index 62b4c4e4..914e80aa 100755 --- a/ge/graph/common/transop_util.cc +++ b/ge/common/transop_util.cc @@ -14,7 +14,7 @@ * limitations under the License.
*/ -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "common/types.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/common/transop_util.h b/ge/common/transop_util.h similarity index 95% rename from ge/graph/common/transop_util.h rename to ge/common/transop_util.h index 883ae41b..57e4adad 100644 --- a/ge/graph/common/transop_util.h +++ b/ge/common/transop_util.h @@ -23,7 +23,7 @@ #include "graph/node.h" namespace ge { -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY TransOpUtil { +class TransOpUtil { public: static bool IsTransOp(const NodePtr &node); diff --git a/ge/common/types.cc b/ge/common/types.cc index 98ae7737..b1127483 100644 --- a/ge/common/types.cc +++ b/ge/common/types.cc @@ -15,7 +15,7 @@ */ #include "framework/common/types.h" -#include "graph/types.h" +#include "external/graph/types.h" namespace ge { // dump diff --git a/ge/common/util.cc b/ge/common/util.cc index 448efc0f..6d77dbc8 100644 --- a/ge/common/util.cc +++ b/ge/common/util.cc @@ -70,7 +70,7 @@ static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Messag return proto->ParseFromCodedStream(&coded_stream); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { +bool ReadProtoFromArray(const void *data, int size, Message *proto) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false, "incorrect parameter. proto is nullptr || data is nullptr || size is 0"); @@ -112,8 +112,7 @@ long GetFileLength(const std::string &input_file) { * @return false fail * @return true success */ -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, - int &length) { +bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr"); @@ -141,8 +140,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, - std::vector<char> &buffer) { +bool ReadBytesFromBinaryFile(const char *file_name, std::vector<char> &buffer) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter.
file path is null"); std::string real_path = RealPath(file_name); @@ -177,7 +175,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co * @return -1 fail * @return 0 success */ -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::string &directory_path) { +int CreateDirectory(const std::string &directory_path) { GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); auto dir_path_len = directory_path.length(); if (dir_path_len >= MMPA_MAX_PATH) { @@ -219,7 +217,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: return 0; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() { +std::string CurrentTimeInStr() { std::time_t now = std::time(nullptr); std::tm *ptm = std::localtime(&now); if (ptm == nullptr) { @@ -235,8 +233,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() return std::string(buffer); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, - google::protobuf::Message *message) { +bool ReadProtoFromText(const char *file, google::protobuf::Message *message) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false, "incorrect parameter. nullptr == file || nullptr == message"); @@ -266,8 +263,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, - google::protobuf::Message *message) { +bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false, "incorrect parameter. 
data is nullptr || message is nullptr"); std::string str(data, static_cast<size_t>(size)); @@ -281,7 +277,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { +uint64_t GetCurrentTimestamp() { mmTimeval tv{}; int ret = mmGetTimeOfDay(&tv, nullptr); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed, ret:%d, errmsg:%s", ret, strerror(errno)); @@ -289,7 +285,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() return static_cast<uint64_t>(total_use_time); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { +uint32_t GetCurrentSecondTimestap() { mmTimeval tv{}; int ret = mmGetTimeOfDay(&tv, nullptr); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed, ret:%d, errmsg:%s", ret, strerror(errno)); @@ -297,7 +293,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimest return static_cast<uint32_t>(total_use_time); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int64_t a, int64_t b) { +bool CheckInt64MulOverflow(int64_t a, int64_t b) { if (a > 0) { if (b > 0) { if (a > (INT64_MAX / b)) { @@ -322,7 +318,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6 return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { +std::string RealPath(const char *path) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, @@ -340,15 +336,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char return res; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path, - const std::string &atc_param) { +void PathValidErrReport(const std::string &file_path, const std::string &atc_param, const std::string &reason) { + if (!atc_param.empty()) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({atc_param, file_path, reason})); + } else { + REPORT_INNER_ERROR("E19999", "Path[%s] invalid, reason:%s", file_path.c_str(), reason.c_str()); + } +} + +bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param) { // The specified path is empty std::map<std::string, std::string> args_map; if (file_path.empty()) { - if (atc_param != "") { - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); + if (!atc_param.empty()) { + REPORT_INPUT_ERROR("E10004", std::vector<std::string>({"parameter"}), std::vector<std::string>({atc_param})); } else { - REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid"); + REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid."); } GELOGW("Input parameter %s is empty.", file_path.c_str()); return false; @@ -356,13 +360,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string real_path = RealPath(file_path.c_str()); // Unable to get absolute path (does not exist or does not have permission to access) if (real_path.empty()) { - if (atc_param != "") { - std::string reason = "realpath error, errmsg:" + std::string(strerror(errno)); - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, file_path, reason}); - } else { - REPORT_INNER_ERROR("E19999", "Path[%s]'s realpath is empty,
errmsg[%s]", file_path.c_str(), strerror(errno)); - } + std::string reason = "realpath error, errmsg:" + std::string(strerror(errno)); + PathValidErrReport(file_path, atc_param, reason); GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } @@ -378,23 +377,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( !ValidateStr(real_path, mode), - if (atc_param != "") { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, real_path, kPathValidReason}); - } else { - REPORT_INNER_ERROR("E19999", "Path[%s] has invalid char, %s", file_path.c_str(), kPathValidReason); - } + PathValidErrReport(file_path, atc_param, kPathValidReason); return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); // The absolute path points to a file that is not readable if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { - if (atc_param != "") { - std::string reason = "cat not access, errmsg:" + std::string(strerror(errno)); - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, file_path, reason}); - } else { - REPORT_INNER_ERROR("E19999", "Path[%s] can't acccess, errmsg:%s", file_path.c_str(), strerror(errno)); - } + PathValidErrReport(file_path, atc_param, "cat not access, errmsg:" + std::string(strerror(errno))); GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } @@ -402,14 +390,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string &file_path, - const std::string &atc_param) { +bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param) { // The specified path is empty if (file_path.empty()) { - if (atc_param != "") { - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); + if (!atc_param.empty()) { + REPORT_INPUT_ERROR("E10004", std::vector<std::string>({"parameter"}), std::vector<std::string>({atc_param})); } else { - REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid"); + REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid."); } ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); GELOGW("Input parameter's value is empty."); @@ -417,17 +404,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY
bool CheckOutputPathValid(const GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( !ValidateStr(file_path, mode), - if (atc_param != "") { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, file_path, kPathValidReason}); - } else { - REPORT_INNER_ERROR("E19999", "Path[%s] has invalid char, %s", file_path.c_str(), kPathValidReason); - } + PathValidErrReport(file_path, atc_param, kPathValidReason); return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); std::string real_path = RealPath(file_path.c_str()); @@ -454,13 +429,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const if (!real_path.empty()) { // File is not readable or writable if (mmAccess2(real_path.c_str(), M_W_OK | M_F_OK) != EN_OK) { - if (atc_param != "") { - std::string reason = "cat not access, errmsg:" + std::string(strerror(errno)); - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, file_path, reason}); - } else { - REPORT_INNER_ERROR("E19999", "Path[%s] can't acccess, errmsg:%s", file_path.c_str(), strerror(errno)); - } + PathValidErrReport(file_path, atc_param, "cat not access, errmsg:" + std::string(strerror(errno))); GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno)); return false; } @@ -479,12 +448,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string prefix_path = std::string(file_path).substr(0, static_cast(path_split_pos)); // Determine whether the specified path is valid by creating the path if (CreateDirectory(prefix_path) != 0) { - if (atc_param != "") { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, file_path, "Can not create directory"}); - } else { - REPORT_INNER_ERROR("E19999", "Path[%s] Can not create directory", file_path.c_str()); - } + PathValidErrReport(file_path, atc_param, "Can not create directory"); GELOGW("Can not create directory[%s].", file_path.c_str()); return false; } @@ -582,7 +546,7 @@ FMK_FUNC_HOST_VISIBILITY bool IsValidFile(const char *file_path) { return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status CheckPath(const char *path, size_t length) { +Status CheckPath(const char *path, size_t length) { if (path == nullptr) { GELOGE(PARAM_INVALID, "[Check][Param]Config path is invalid"); REPORT_CALL_ERROR("E19999", "Config path is invalid"); diff --git a/ge/engine_manager/dnnengine_manager.cc b/ge/engine_manager/dnnengine_manager.cc index e89fc847..36f11828 100644 --- a/ge/engine_manager/dnnengine_manager.cc +++ b/ge/engine_manager/dnnengine_manager.cc @@ -16,13 +16,12 @@ #include "engine_manager/dnnengine_manager.h" -#include #include #include #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/ge/ge_util.h" #include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" @@ -240,6 +239,10 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { op_desc->SetOpEngineName(it.engine); op_desc->SetOpKernelLibName(kernel_name); // set attrs for taking information when load txt to graph object + if (it.flagAsync) { + GELOGD("Set aicpu blocking op:%s attribute(is_blocking_op):true", op_desc->GetName().c_str()); + (void)AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + } (void) AttrUtils::SetStr(op_desc, ATTR_NAME_ENGINE_NAME_FOR_LX, it.engine); (void) 
AttrUtils::SetStr(op_desc, ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX, kernel_name); GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s to op_desc %s", kernel_name.c_str(), diff --git a/ge/engine_manager/dnnengine_manager.h b/ge/engine_manager/dnnengine_manager.h index c3ae5b95..42da3596 100755 --- a/ge/engine_manager/dnnengine_manager.h +++ b/ge/engine_manager/dnnengine_manager.h @@ -26,9 +26,9 @@ #include "nlohmann/json.hpp" #include "common/ge/plugin_manager.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_types.h" -#include "engine/dnnengine.h" +#include "framework/engine/dnnengine.h" #include "graph/op_desc.h" #include "graph/node.h" diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt old mode 100644 new mode 100755 index f1267c1e..54cb7639 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -1,13 +1,9 @@ set(SRC_LIST "ge_executor.cc" "../common/profiling/profiling_manager.cc" - "../common/ge/plugin_manager.cc" - "../common/ge/op_tiling_manager.cc" - "../common/dump/dump_properties.cc" - "../common/dump/exception_dumper.cc" - "../common/dump/dump_manager.cc" "../common/dump/dump_op.cc" "../common/dump/opdebug_register.cc" + "../common/dump/exception_dumper.cc" "../common/profiling/ge_profiling.cc" "../graph/load/graph_loader.cc" "../graph/execute/graph_execute.cc" @@ -22,8 +18,6 @@ set(SRC_LIST "../graph/manager/rdma_pool_allocator.cc" "../graph/manager/host_mem_allocator.cc" "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "../model/ge_model.cc" - "../model/ge_root_model.cc" "../graph/load/model_manager/davinci_model.cc" "../graph/load/model_manager/model_manager.cc" "../graph/load/model_manager/tbe_handle_store.cc" @@ -37,6 +31,7 @@ set(SRC_LIST "../graph/load/model_manager/task_info/task_info.cc" "../graph/load/model_manager/task_info/event_record_task_info.cc" "../graph/load/model_manager/task_info/event_wait_task_info.cc" + "../graph/load/model_manager/task_info/ffts_task_info.cc" "../graph/load/model_manager/task_info/fusion_start_task_info.cc" "../graph/load/model_manager/task_info/fusion_stop_task_info.cc" "../graph/load/model_manager/task_info/kernel_ex_task_info.cc" @@ -54,7 +49,6 @@ set(SRC_LIST "../graph/load/model_manager/task_info/model_exit_task_info.cc" "../graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" "../graph/load/model_manager/task_info/super_kernel/super_kernel.cc" - "../graph/common/local_context.cc" "../opskernel_manager/ops_kernel_builder_manager.cc" "../single_op/single_op_manager.cc" "../single_op/single_op_model.cc" @@ -101,7 +95,6 @@ set(SRC_LIST "../hybrid/node_executor/task_context.cc" "../hybrid/hybrid_davinci_model.cc" "../ge_local_engine/engine/host_cpu_engine.cc" - "../graph/common/omg_util.cc" "../graph/manager/host_mem_manager.cc" "../graph/build/memory/var_mem_assign_util.cc" "../host_kernels/transpose_kernel.cc" @@ -143,10 +136,6 @@ set(SRC_LIST "../host_kernels/transdata_kernel.cc" "../host_kernels/unpack_kernel.cc" "../graph/passes/pass_utils.cc" - "../graph/common/bcast.cc" - "../common/fp16_t.cc" - "../common/formats/format_transfers/format_transfer_transpose.cc" - "../common/formats/utils/formats_trans_utils.cc" ) ######## libge_executor.a ######## @@ -181,20 +170,22 @@ target_include_directories(ge_executor SYSTEM PRIVATE ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} 
${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### - ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/cce + $<$>:${GE_DEPEND_DIR}/inc> + $<$>:$> + $<$>:$> #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain> ) target_link_libraries(ge_executor PRIVATE $ + $<$>:$> + $<$>:$> + $<$>:$> json ascend_protobuf_static c_sec @@ -232,15 +223,12 @@ target_include_directories(ge_executor_shared PRIVATE ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### - ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/cce + $<$>:${GE_DEPEND_DIR}/inc> #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc + $<$:${GE_CODE_DIR}/third_party/fwkacllib/inc> ) target_link_options(ge_executor_shared PRIVATE @@ -250,6 +238,11 @@ target_link_options(ge_executor_shared PRIVATE target_link_libraries(ge_executor_shared PRIVATE $ + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> + $<$>:$> -Wl,--no-as-needed ge_common runtime diff --git a/ge/executor/ge_executor.cc b/ge/executor/ge_executor.cc index 049d012f..76cde2b9 100755 --- a/ge/executor/ge_executor.cc +++ b/ge/executor/ge_executor.cc @@ -14,19 +14,20 @@ * limitations under the License. */ -#include "executor/ge_executor.h" +#include "framework/executor/ge_executor.h" #include #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/ge/ge_util.h" -#include "common/helper/model_helper.h" +#include "framework/common/helper/model_helper.h" #include "common/profiling/profiling_manager.h" #include "common/dump/dump_manager.h" #include "graph/execute/graph_execute.h" #include "graph/load/graph_loader.h" #include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_mem_manager.h" +#include "graph/manager/graph_var_manager.h" #include "single_op/single_op_manager.h" #include "graph/load/model_manager/davinci_model.h" #include "opskernel_manager/ops_kernel_builder_manager.h" @@ -125,34 +126,41 @@ void SetDynamicInputDataFlag(const ge::RunModelData &input_data, const std::vect bool IsDynamicBatchSizeMatchModel(uint64_t batch_size, const vector> &batch_info) { if (batch_info.empty()) { - GELOGE(ge::FAILED, "Dynamic batch info is empty."); + REPORT_INNER_ERROR("E19999", "param Dynamic batch info is empty, check invalid."); + GELOGE(ge::FAILED, "[Check][Param] Dynamic batch info is empty."); return false; } for (auto batch : batch_info) { if (batch.size() != kDynamicBatchSizeVecSize) { - GELOGE(ge::FAILED, "Dynamic batch param num is %zu, current batch size is %zu.", kDynamicBatchSizeVecSize, - batch.size()); + REPORT_INNER_ERROR("E19999", "Dynamic batch param num is %zu, current batch size is %zu.", + kDynamicBatchSizeVecSize, batch.size()); + GELOGE(ge::FAILED, "[Check][Param] Dynamic batch param num is %zu, current batch size is %zu.", + kDynamicBatchSizeVecSize, batch.size()); return false; } if (batch[0] == static_cast(batch_size)) { return true; } } - GELOGE(ge::FAILED, "Dynamic batch %lu can not match the gear of model.", batch_size); + REPORT_INNER_ERROR("E19999", "Dynamic batch %lu can not match the gear of model.", batch_size); + GELOGE(ge::FAILED, "[Check][Param] Dynamic batch %lu can not match the gear of model.", batch_size); return false; } bool 
IsDynamicImageSizeMatchModel(uint64_t image_height, uint64_t image_width, const vector> &batch_info) { if (batch_info.empty()) { - GELOGE(ge::FAILED, "Dynamic batch info is empty."); + REPORT_INNER_ERROR("E19999", "ParamDynamic batch info is empty. check invalid"); + GELOGE(ge::FAILED, "[Check][Param] Dynamic batch info is empty."); return false; } for (auto resolution : batch_info) { if (resolution.size() != kDynamicImageSizeVecSize) { - GELOGE(ge::FAILED, "Dynamic resolution param num is %zu, current resolution size is %zu.", + REPORT_INNER_ERROR("E19999", "Dynamic resolution param num is %zu, current resolution size is %zu.", + kDynamicImageSizeVecSize, resolution.size()); + GELOGE(ge::FAILED, "[Check][Param] Dynamic resolution param num is %zu, current resolution size is %zu.", kDynamicImageSizeVecSize, resolution.size()); return false; } @@ -160,22 +168,28 @@ bool IsDynamicImageSizeMatchModel(uint64_t image_height, uint64_t image_width, return true; } } - - GELOGE(ge::FAILED, "Dynamic resolution (%lu,%lu) can not match the gear of model.", image_height, image_width); + REPORT_INNER_ERROR("E19999", "Dynamic resolution (%lu,%lu) can not match the gear of model.", + image_height, image_width); + GELOGE(ge::FAILED, "[Check][Param]Dynamic resolution (%lu,%lu) can not match the gear of model.", + image_height, image_width); return false; } bool IsDynmaicDimsSizeMatchModel(const vector cur_dynamic_dims, const vector> &batch_info) { if (batch_info.empty()) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Dynamic batch info is empty."); + REPORT_INNER_ERROR("E19999", "param batch_info is empty, check invalid"); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Param] Dynamic batch info is empty."); return false; } bool find_match = false; for (auto resolution : batch_info) { if (cur_dynamic_dims.size() != resolution.size()) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Cur dynamic dims param num is %zu, current resolution size is %zu.", + REPORT_INNER_ERROR("E19999", "Cur dynamic dims param num is %zu, current resolution size is %zu.", + cur_dynamic_dims.size(), resolution.size()); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, + "[Check][Param] Cur dynamic dims param num is %zu, current resolution size is %zu.", cur_dynamic_dims.size(), resolution.size()); return false; } @@ -192,7 +206,7 @@ bool IsDynmaicDimsSizeMatchModel(const vector cur_dynamic_dims, } } if (!find_match) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "choose dynamic dims can not match the gear of model."); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Param] choose dynamic dims can not match the gear of model."); } return find_match; } @@ -241,7 +255,7 @@ Status GeExecutor::Initialize() { Status init_hostcpu_engine_status = HostCpuEngine::GetInstance().Initialize(); if (init_hostcpu_engine_status != SUCCESS) { - GELOGE(init_hostcpu_engine_status, "Failed to initialize HostCpuEngine"); + GELOGE(init_hostcpu_engine_status, "[initialize][HostCpuEngine] failed"); return init_hostcpu_engine_status; } @@ -251,12 +265,12 @@ Status GeExecutor::Initialize() { mem_type.push_back(RT_MEMORY_P2P_DDR); auto ret = MemManager::Instance().Initialize(mem_type); if (ret != SUCCESS) { - GELOGE(ret, "Memory Manager init failed."); + GELOGE(ret, "[Initialize][MemManager] failed."); return ret; } GE_CHK_STATUS_RET(OpsKernelBuilderManager::Instance().Initialize({}, false), - "Failed to initialize OpsKernelBuilders."); + "[Initialize][OpsKernelBuilderManager] failed."); // Start profiling Options profiling_options; @@ -292,13 +306,18 @@ Status GeExecutor::Finalize() { Status 
GeExecutor::SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size) { if (dynamic_input_addr == nullptr) { - GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, "Dynamic input addr is nullptr!"); + REPORT_INNER_ERROR("E19999", "param dynamic_input_addr is nullptr, check invalid, model id:%u", model_id); + GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, + "[Check][Param] Dynamic input addr is nullptr, model id:%u", model_id); return ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID; } uint64_t size = sizeof(uint32_t); if (length < size) { - GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, size); + REPORT_INNER_ERROR("E19999", "Dynamic input size [%lu] is less than [%lu], check invalid, model id:%u", + length, size, model_id); + GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, + "[Check][Param] Dynamic input size [%lu] is less than [%lu], model id:%u", length, size, model_id); return ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID; } if (length >= sizeof(uint64_t)) { @@ -311,24 +330,28 @@ Status GeExecutor::SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_ad int32_t dynamic_type = static_cast<int32_t>(FIXED); Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(ret, "Get dynamic input info failed."); + REPORT_CALL_ERROR("E19999", "get dynamic batch info failed, model id:%u", model_id); + GELOGE(ret, "[Get][DynamicBatchInfo] failed, model id:%u.", model_id); return ret; } if (!IsDynamicBatchSizeMatchModel(batch_size, batch_info)) { - GELOGE(ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID, "The current dynamic input does not match the gear of the model."); + GELOGE(ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID, + "[Check][Param] The current dynamic input does not match the gear of the model(id:%u).", model_id); return ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID; } ret = GraphExecutor::SetDynamicSize(model_id, batch_num, static_cast<int32_t>(DYNAMIC_BATCH)); if (ret != SUCCESS) { - GELOGE(ret, "Set dynamic size failed"); + REPORT_CALL_ERROR("E19999", "set dynamic size failed, model id:%u, dynamic_type:1", model_id); + GELOGE(ret, "[Set][DynamicSize] failed, model id:%u, dynamic_type:1", model_id); return ret; } // memcpy dynamic_batch_size from host to device rtError_t rt_ret = rtMemcpy(dynamic_input_addr, length, &batch_size, size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "memcpy dynamic batch input data failed! ret: 0x%X", rt_ret); + REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%lu, ret:0x%X", length, rt_ret); + GELOGE(rt_ret, "[Call][RtMemcpy] memcpy dynamic batch input data failed! 
size:%lu ret:0x%X", length, rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; @@ -337,14 +360,19 @@ Status GeExecutor::SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_ad Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, uint64_t image_width) { if (dynamic_input_addr == nullptr) { - GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, "Dynamic input addr is nullptr!"); + REPORT_INNER_ERROR("E19999", "param dynamic_input_addr is nullptr, check invalid, model id:%u", model_id); + GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, + "[Check][Param] Dynamic input addr is nullptr, model id:%u", model_id); return ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID; } uint64_t dynamic_input_size = kDynamicImageSizeInputSize * sizeof(uint32_t); if (length < dynamic_input_size) { + REPORT_INNER_ERROR("E19999", "Dynamic input size [%lu] is less than [%lu], check invalid, model id:%u", + length, dynamic_input_size, model_id); GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, - "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); + "[Check][Param] Dynamic input size [%lu] is less than [%lu], model id:%u", + length, dynamic_input_size, model_id); return ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID; } uint64_t size = sizeof(uint32_t); @@ -357,18 +385,22 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad int32_t dynamic_type = static_cast<int32_t>(FIXED); Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(ret, "Get dynamic input info failed."); + REPORT_CALL_ERROR("E19999", "Get dynamic input info failed, model id:%u.", model_id); + GELOGE(ret, "[Get][DynamicBatchInfo] failed, model id:%u.", model_id); return ret; } if (!IsDynamicImageSizeMatchModel(image_height, image_width, batch_info)) { - GELOGE(ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID, "The current dynamic input does not match the gear of the model."); + GELOGE(ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID, + "[Check][Param] The current dynamic input does not match the gear of the model, " + "image_height:%lu, image_width:%lu.", image_height, image_width); return ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID; } ret = GraphExecutor::SetDynamicSize(model_id, batch_num, static_cast<int32_t>(DYNAMIC_IMAGE)); if (ret != SUCCESS) { - GELOGE(ret, "Set dynamic size failed"); + REPORT_CALL_ERROR("E19999", "Set dynamic size failed, model id:%u", model_id); + GELOGE(ret, "[Set][DynamicSize] failed, model id:%u", model_id); return ret; } @@ -376,7 +408,9 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad rt_ret = rtMemcpy(dynamic_input_addr, size, &image_height, size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "memcpy dynamic resolution input data failed! ret: 0x%X", rt_ret); + REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed! size:%lu, ret:0x%X, model id:%u", size, rt_ret, model_id); + GELOGE(rt_ret, "[Call][RtMemcpy] memcpy dynamic resolution input data failed! 
size:%lu, ret:0x%X, model id:%u", + size, rt_ret, model_id); return RT_ERROR_TO_GE_STATUS(rt_ret); } @@ -385,7 +419,10 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + size), remain_size, &image_width, size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "memcpy dynamic resolution input data failed!"); + REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed! size:%lu, ret:0x%X, model id:%u", + remain_size, rt_ret, model_id); + GELOGE(rt_ret, "[Call][RtMemcpy] memcpy dynamic resolution input data failed! size:%lu, ret:0x%X, model id:%u", + remain_size, rt_ret, model_id); return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; @@ -394,40 +431,48 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, uint64_t length, const vector<uint64_t> &dynamic_dims) { if (dynamic_input_addr == nullptr) { - GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, "Dynamic input addr is nullptr!"); + REPORT_INNER_ERROR("E19999", "Param dynamic_input_addr is nullptr, check invalid, model id:%u", model_id); + GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, + "[Check][Param] Dynamic input addr is nullptr, model id:%u", model_id); return ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID; } vector<uint64_t> cur_dynamic_dims; Status ret = GetCurDynamicDims(model_id, dynamic_dims, cur_dynamic_dims); if (ret != SUCCESS) { - GELOGE(ret, "Set cur gear dynamic dims failed"); + GELOGE(ret, "[Get][CurDynamicDims] failed, model id:%u", model_id); return ret; } std::vector<std::vector<int64_t>> batch_info; int32_t dynamic_type = static_cast<int32_t>(FIXED); ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(ret, "Get dynamic input info failed."); + REPORT_CALL_ERROR("E19999", "Get dynamic input info failed, model id:%u.", model_id); + GELOGE(ret, "[Get][DynamicBatchInfo] failed, model id:%u.", model_id); return ret; } if (!IsDynmaicDimsSizeMatchModel(cur_dynamic_dims, batch_info)) { - GELOGE(ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID, "The current dynamic input does not match the gear of the model."); + GELOGE(ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID, + "[Check][Param] The current dynamic input does not match the gear of the model, id:%u.", model_id); return ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID; } ret = GraphExecutor::SetDynamicSize(model_id, cur_dynamic_dims, static_cast<int32_t>(DYNAMIC_DIMS)); if (ret != SUCCESS) { - GELOGE(ret, "Set dynamic size failed"); + REPORT_CALL_ERROR("E19999", "Set dynamic size failed, model id:%u", model_id); + GELOGE(ret, "[Set][DynamicSize] failed, model id:%u", model_id); return ret; } size_t dynamic_dim_num = cur_dynamic_dims.size(); uint64_t dynamic_input_size = static_cast<uint64_t>(dynamic_dim_num * sizeof(uint32_t)); if (length < dynamic_input_size) { + REPORT_INNER_ERROR("E19999", "input dynamic size [%lu] is less than [%lu], model id:%u", + length, dynamic_input_size, model_id); GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, - "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); + "[Check][Param] Dynamic input size [%lu] is less than [%lu], model id:%u", + length, dynamic_input_size, model_id); return ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID; } uint64_t size = sizeof(uint32_t); @@ -440,7 +485,9 @@ Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, u rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) 
+ size * i), length - size * i, &cur_dynamic_dims[i], size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "memcpy dynamic resolution input data failed!"); + REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%lu, ret:0x%X", (length - size * i), rt_ret); + GELOGE(rt_ret, "[Call][RtMemcpy] memcpy dynamic resolution input data failed! size:%lu, ret:0x%X", + length - size * i, rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); } } @@ -454,14 +501,14 @@ Status GeExecutor::GetCurDynamicDims(uint32_t model_id, const vector<uint64_t> & vector<ge::TensorDesc> output_desc; auto ret = GetModelDescInfo(model_id, input_desc, output_desc); if (ret != ge::SUCCESS) { - GELOGE(ret, "GetModelDescInfo failed."); + GELOGE(ret, "[Get][ModelDescInfo] failed, model id:%u.", model_id); return ret; } vector<string> user_designate_shape_order; vector<int64_t> all_data_dims; ret = GetUserDesignateShapeOrder(model_id, user_designate_shape_order); if (ret != ge::SUCCESS) { - GELOGE(ret, "GetUserDesignateShapeOrder failed."); + GELOGE(ret, "[Call][GetUserDesignateShapeOrder] failed, model id:%u.", model_id); return ret; } for (auto &data_name : user_designate_shape_order) { @@ -475,8 +522,10 @@ Status GeExecutor::GetCurDynamicDims(uint32_t model_id, const vector<uint64_t> & } } if (dynamic_dims.size() != all_data_dims.size()){ + REPORT_INNER_ERROR("E19999", "Dynamic input size [%lu] is not equal to all data dims size [%lu]!", + dynamic_dims.size(), all_data_dims.size()); GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, - "Dynamic input size [%lu] is not equal with all data dims size [%lu]!", + "[Check][Param] Dynamic input size [%lu] is not equal to all data dims size [%lu]!", dynamic_dims.size(), all_data_dims.size()); return ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID; } @@ -484,8 +533,10 @@ Status GeExecutor::GetCurDynamicDims(uint32_t model_id, const vector<uint64_t> & if (all_data_dims[i] < 0) { cur_dynamic_dims.push_back(dynamic_dims[i]); } else if (static_cast<uint64_t>(all_data_dims[i]) != dynamic_dims[i]) { + REPORT_INNER_ERROR("E19999", "Static dims should be same, index:%zu value:%lu should be %ld", + i, dynamic_dims[i], all_data_dims[i]); GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, - "Static dims should be same, index: %zu value: %lu should be %ld", + "[Check][Param] Static dims should be same, index:%zu value:%lu should be %ld", i, dynamic_dims[i], all_data_dims[i]); return ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID; } @@ -496,12 +547,14 @@ Status GeExecutor::GetCurDynamicDims(uint32_t model_id, const vector<uint64_t> & Status GeExecutor::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) { GELOGI("Begin to get current shape"); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized, model id:%u", model_id); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized, model id:%u", model_id); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetCurShape(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(ret, "Get current shape failed"); + REPORT_CALL_ERROR("E19999", "Get Cur Shape failed, model id:%u", model_id); + GELOGE(ret, "[Get][CurShape] failed, model id:%u", model_id); return ret; } return SUCCESS; @@ -512,11 +565,14 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add const kAippDynamicPara &aippParms) { GELOGI("Enter to SetDynamicAippData."); if (dynamic_input_addr == nullptr) { - 
GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, "Dynamic aipp input addr is nullptr!"); + REPORT_INNER_ERROR("E19999", "Param dynamic_input_addr is nullptr, check invalid, model id:%u", model_id); + GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, + "[Check][Param] Dynamic aipp input addr is nullptr, model id:%u", model_id); return ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID; } if (aippBatchPara.empty()) { - GELOGE(ACL_ERROR_GE_AIPP_BATCH_EMPTY, "aippBatchPara is empty."); + REPORT_INNER_ERROR("E19999", "Param aippBatchPara is empty, check invalid, model id:%u", model_id); + GELOGE(ACL_ERROR_GE_AIPP_BATCH_EMPTY, "[Check][Param] aippBatchPara is empty, model id:%u", model_id); return ACL_ERROR_GE_AIPP_BATCH_EMPTY; } uint64_t batch_num = aippBatchPara.size(); @@ -527,14 +583,18 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add "batch num is %lu, struct_len is %lu", model_id, length, batch_num, struct_len); if (struct_len > length) { + REPORT_INNER_ERROR("E19999", "input dynamic aipp param len:%lu is larger than aipp_data size:%lu", + struct_len, length); GELOGE(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, - "input dynamic aipp param len [%lu] is larger than aipp_data size [%lu]", struct_len, length); + "[Check][Param] input dynamic aipp param len [%lu] is larger than aipp_data size [%lu]", + struct_len, length); return ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID; } // Memcpy real kAippDynamicBatchPara from host to device rtError_t rt_ret = rtMemcpy(dynamic_input_addr, length, &aippParms, real_aippParms_size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "memcpy real_aippParms_size failed! ret: 0x%X", rt_ret); + REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%lu, ret:0x%X", length, rt_ret); + GELOGE(rt_ret, "[Call][RtMemcpy] memcpy aippParms failed! size:%lu, ret:0x%X", length, rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); } uint64_t remain_len = length - real_aippParms_size; @@ -545,7 +605,8 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add (remain_len - i * sizeof(kAippDynamicBatchPara)), &(aippBatchPara[i]), sizeof(kAippDynamicBatchPara), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "memcpy kAippDynamicBatchPara input data failed! ret: 0x%X", rt_ret); + REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, ret:0x%X", rt_ret); + GELOGE(rt_ret, "[Call][RtMemcpy] memcpy kAippDynamicBatchPara input data failed! ret:0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); } } @@ -555,12 +616,14 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add Status GeExecutor::UnloadModel(uint32_t model_id) { GELOGD("unload model %u begin.", model_id); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphLoader::DestroyAicpuSessionForInfer(model_id); if (ret != SUCCESS) { - GELOGE(ret, "[GraphLoader] DestroyAicpuSessionForInfer failed. model id: %u", model_id); + REPORT_CALL_ERROR("E19999", "Destroy Aicpu Session For Infer failed, model id:%u", model_id); + GELOGE(ret, "[Destroy][AicpuSession] For Infer failed. 
model id:%u", model_id); return ret; } @@ -578,7 +641,8 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { } ret = GraphLoader::UnloadModel(model_id); if (ret != SUCCESS) { - GELOGE(ret, "[GraphLoader] DestroyAicpuSessionForInfer failed. model id: %u", model_id); + REPORT_CALL_ERROR("E19999", "unload model failed, model id:%u", model_id); + GELOGE(ret, "[Unload][Model] failed. model id:%u", model_id); return ret; } return SUCCESS; @@ -588,7 +652,8 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector &input_desc, std::vector &output_desc, bool new_model_desc) { if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized, model id:%u", model_id); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized, model id:%u", model_id); return ACL_ERROR_GE_EXEC_NOT_INIT; } @@ -600,20 +665,26 @@ Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector> &batch_info, int32_t &dynamic_type) { if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(ret, "GetDynamicBatchInfo failed."); + REPORT_CALL_ERROR("E19999", "Get Dynamic BatchInfo failed, model id:%u.", model_id); + GELOGE(ret, "[Get][DynamicBatchInfo] failed, model id:%u.", model_id); return ret; } return SUCCESS; @@ -657,13 +730,15 @@ Status GeExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info) { GELOGI("Begin to get combined dynamic dims info."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetCombinedDynamicDims(model_id, batch_info); if (ret != SUCCESS) { - GELOGE(ret, "GetCombinedDynamicDims failed."); + REPORT_CALL_ERROR("E19999", "Get Combined DynamicDims failed, model id:%u.", model_id); + GELOGE(ret, "[Get][CombinedDynamicDims] failed, model id:%u.", model_id); return ret; } @@ -680,13 +755,15 @@ Status GeExecutor::GetCombinedDynamicDims(uint32_t model_id, vector &user_designate_shape_order) { if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetUserDesignateShapeOrder(model_id, user_designate_shape_order); if (ret != SUCCESS) { - GELOGE(ret, "GetUserDesignateShapeOrder failed."); + REPORT_CALL_ERROR("E19999", "GetUserDesignateShapeOrder failed, model id:%u.", model_id); + GELOGE(ret, "[Call][GetUserDesignateShapeOrder] failed, model id:%u.", model_id); return ret; } @@ -704,7 +781,8 @@ Status GeExecutor::GetUserDesignateShapeOrder(uint32_t model_id, vector Status GeExecutor::GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) { GELOGI("Begin to GetAIPPInfo."); if (!isInit_) { - 
GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "not inited yet!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor not inited yet!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetAippInfo(model_id, index, aipp_info); @@ -719,7 +797,8 @@ Status GeExecutor::GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigInfo Status GeExecutor::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) { GELOGI("Begin to get aipp type."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "not inited yet!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not inited yet!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetAippType(model_id, index, type, aipp_index); @@ -741,8 +820,10 @@ Status GeExecutor::GetOpAttr(uint32_t model_id, const std::string &op_name, cons } Status ret = GraphExecutor::GetOpAttr(model_id, op_name, attr_name, attr_value); if (ret != SUCCESS) { - GELOGE(ret, "[Get][OpAttr]Get op:%s attr:%s failed.", op_name.c_str(), attr_name.c_str()); - REPORT_CALL_ERROR("E19999", "Get op:%s attr:%s failed.", op_name.c_str(), attr_name.c_str()); + GELOGE(ret, "[Get][OpAttr]Get op:%s attr:%s failed, model id:%u.", + op_name.c_str(), attr_name.c_str(), model_id); + REPORT_CALL_ERROR("E19999", "Get op:%s attr:%s failed, model id:%u", + op_name.c_str(), attr_name.c_str(), model_id); return ret; } return SUCCESS; @@ -750,12 +831,14 @@ Status GeExecutor::GetOpAttr(uint32_t model_id, const std::string &op_name, cons Status GeExecutor::GetModelAttr(uint32_t model_id, std::vector &dynamic_output_shape_info) { if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "not inited yet!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not inited yet!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not inited yet!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetModelAttr(model_id, dynamic_output_shape_info); if (ret != SUCCESS) { - GELOGE(ret, "Get dynamic batch output shape info failed."); + REPORT_CALL_ERROR("E19999", "Get Model Attr failed, model id:%u.", model_id); + GELOGE(ret, "[Get][ModelAttr] failed, model id:%u.", model_id); return ret; } return SUCCESS; @@ -764,7 +847,8 @@ Status GeExecutor::GetModelAttr(uint32_t model_id, std::vector &dyn Status GeExecutor::CommandHandle(const Command &command) { Status ret = GraphLoader::CommandHandle(command); if (ret != SUCCESS) { - GELOGE(ACL_ERROR_GE_COMMAND_HANDLE, "CommandHandle: Command Handle failed."); + REPORT_CALL_ERROR("E19999", "call CommandHandle failed, ret:%u", ret); + GELOGE(ACL_ERROR_GE_COMMAND_HANDLE, "[Call][CommandHandle] failed, ret:%u", ret); return ACL_ERROR_GE_COMMAND_HANDLE; } return SUCCESS; @@ -773,7 +857,8 @@ Status GeExecutor::CommandHandle(const Command &command) { Status GeExecutor::GetMaxUsedMemory(uint32_t model_id, uint32_t &max_size) { GELOGI("Get max used memory begin."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } @@ -793,14 +878,15 @@ Status GeExecutor::GetMaxUsedMemory(uint32_t model_id, uint32_t &max_size) { Status GeExecutor::LoadDataFromFile(const std::string &path, ModelData &model_data) { GELOGI("Load 
data from file begin."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } string filePath = RealPath(path.c_str()); if (filePath.empty()) { GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, - "File path is invalid. please check your text file '%s'.", path.c_str()); + "[Call][RealPath] File path is invalid. please check your text file '%s'.", path.c_str()); return ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID; } GELOGI("load modelData from file: %s.", path.c_str()); @@ -829,7 +915,8 @@ Status GeExecutor::LoadDataFromFile(const std::string &path, ModelData &model_da Status GeExecutor::LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "not inited yet!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not inited yet!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } @@ -850,7 +937,8 @@ Status GeExecutor::LoadModelWithQ(uint32_t &model_id, const ModelData &model_dat const std::vector &output_queue_ids) { GELOGI("Load model with queue begin."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } return GraphLoader::LoadModelWithQ(model_id, model_data, input_queue_ids, output_queue_ids); @@ -889,7 +977,8 @@ Status GeExecutor::ExecModel(uint32_t model_id, void *stream, const ge::RunModel const std::vector &input_desc, ge::RunModelData &run_output_data, std::vector &output_desc, bool async_mode) { if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } @@ -904,7 +993,8 @@ Status GeExecutor::ExecModel(uint32_t model_id, void *stream, const ge::RunModel int32_t dynamic_type = static_cast(FIXED); Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(ret, "Get dynamic input info failed."); + REPORT_CALL_ERROR("E19999", "get dynamic batch info failed, model id:%u.", model_id); + GELOGE(ret, "[Get][DynamicBatchInfo] failed, model id:%u.", model_id); return ret; } if (!batch_info.empty()) { @@ -926,14 +1016,16 @@ Status GeExecutor::ExecModel(uint32_t model_id, void *stream, const ge::RunModel Status GeExecutor::GetMemAndWeightSize(const std::string &path, size_t &mem_size, size_t &weight_size) { GELOGI("Get memory and weight size from file begin."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } ModelData model; Status ret = ge::GraphLoader::LoadDataFromFile(path, 0, model); if ((ret != SUCCESS) || (model.model_data == nullptr)) { - GELOGE(ret, "Load data from file failed. 
ret = %d", ret); + REPORT_CALL_ERROR("E19999", "load data from file failed, ret = %d", ret); + GELOGE(ret, "[Load][Data] from file failed. ret = %d", ret); return ret; } @@ -958,12 +1050,14 @@ Status GeExecutor::GetMemAndWeightSize(const void *model_data, size_t model_size size_t &weight_size) { GELOGI("Get memory and weight size from data begin."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } if (model_data == nullptr) { - GELOGE(ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID, "invalid model data!"); + REPORT_INNER_ERROR("E19999", "param model_data is nullptr, check invalid!"); + GELOGE(ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID, "[Check][Param] invalid model data!"); return ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID; } @@ -997,7 +1091,8 @@ Status GeExecutor::LoadDynamicSingleOpV2(const std::string &model_name, const ge Status GeExecutor::ExecuteAsync(SingleOp *executor, const std::vector &inputs, std::vector &outputs) { if (executor == nullptr) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "param is NULL"); + REPORT_INNER_ERROR("E19999", "Param executor is nullptr, check invalid"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] param executor is nullptr"); return ACL_ERROR_GE_EXEC_NOT_INIT; } @@ -1021,7 +1116,8 @@ Status GeExecutor::GetDeviceIdByModelId(uint32_t model_id, uint32_t &device_id) GE_CHECK_NOTNULL(model_manager); auto davinci_model = model_manager->GetModel(model_id); if (davinci_model == nullptr) { - GELOGE(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, "Model id: %d is invaild or model is not loaded.", model_id); + GELOGE(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, + "[Get][Model] failed, Model id:%u is invaild or model is not loaded.", model_id); return ACL_ERROR_GE_EXEC_MODEL_ID_INVALID; } @@ -1034,7 +1130,7 @@ Status GeExecutor::GetBatchInfoSize(uint32_t model_id, size_t &shape_count) { int32_t dynamic_type = static_cast(FIXED); Status ret = GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(ret, "Calc batch info size failed. ret = %d", ret); + GELOGE(ret, "[Get][DynamicBatchInfo] failed. 
ret = %d, model id:%u", ret, model_id); return ret; } if (batch_info.empty()) { @@ -1048,13 +1144,15 @@ Status GeExecutor::GetBatchInfoSize(uint32_t model_id, size_t &shape_count) { Status GeExecutor::GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info) { GELOGI("Begin to GetOrigInputInfo."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "not inited yet!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetOrigInputInfo(model_id, index, orig_input_info); if (ret != SUCCESS) { - GELOGE(ret, "GetOrigInputInfo failed."); + REPORT_CALL_ERROR("E19999", "Get Orig Input Info failed, model id:%u.", model_id); + GELOGE(ret, "[Get][OrigInputInfo] failed, model id:%u.", model_id); return ret; } @@ -1067,13 +1165,15 @@ Status GeExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &output_dims) { GELOGI("Begin to GetAllAippInputOutputDims."); if (!isInit_) { - GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "not inited yet!"); + REPORT_INNER_ERROR("E19999", "GeExecutor has not been initialized!"); + GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "[Check][Param] GeExecutor has not been initialized!"); return ACL_ERROR_GE_EXEC_NOT_INIT; } Status ret = GraphExecutor::GetAllAippInputOutputDims(model_id, index, input_dims, output_dims); if (ret != SUCCESS) { - GELOGE(ret, "GetAllAippInputOutputDims failed."); + REPORT_CALL_ERROR("E19999", "Get All Aipp Input Output Dims failed, model id:%u.", model_id); + GELOGE(ret, "[Get][AllAippInputOutputDims] failed, model id:%u.", model_id); return ret; } @@ -1085,7 +1185,10 @@ Status GeExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_ GELOGI("Begin to GetOpDescInfo."); Status ret = GraphExecutor::GetOpDescInfo(device_id, stream_id, task_id, op_desc_info); if (ret != SUCCESS) { - GELOGE(ret, "GetOpDescInfo failed."); + REPORT_CALL_ERROR("E19999", "get opdesc info failed, device_id:%u, stream_id:%u, task_id:%u.", + device_id, stream_id, task_id); + GELOGE(ret, "[Get][OpDescInfo] failed, device_id:%u, stream_id:%u, task_id:%u.", + device_id, stream_id, task_id); return ret; } GELOGI("GetOpDescInfo succ."); @@ -1096,7 +1199,7 @@ Status GeExecutor::SetDump(const DumpConfig &dump_config) { GELOGI("Start to set dump config"); auto ret = DumpManager::GetInstance().SetDumpConf(dump_config); if (ret != SUCCESS) { - GELOGE(ret, "Set dump conf failed"); + GELOGE(ret, "[Set][DumpConf] failed, ret:%d", ret); return ret; } GELOGI("Set dump config successfully"); diff --git a/ge/executor/module.mk b/ge/executor/module.mk index 7a7e2b51..430efa75 100644 --- a/ge/executor/module.mk +++ b/ge/executor/module.mk @@ -63,7 +63,7 @@ local_ge_executor_src_files := \ ../single_op/task/aicpu_task_builder.cc \ ../single_op/task/aicpu_kernel_task_builder.cc \ ../hybrid/node_executor/aicpu/aicpu_ext_info.cc \ - ../graph/common/local_context.cc \ + ../common/local_context.cc \ ../hybrid/common/tensor_value.cc \ ../hybrid/common/npu_memory_allocator.cc \ ../hybrid/executor/rt_callback_manager.cc \ @@ -102,7 +102,7 @@ local_ge_executor_src_files := \ ../hybrid/node_executor/task_context.cc \ ../hybrid/hybrid_davinci_model.cc \ ../ge_local_engine/engine/host_cpu_engine.cc \ - ../graph/common/omg_util.cc \ + ../common/omg_util.cc \ ../graph/manager/host_mem_manager.cc \ ../graph/build/memory/var_mem_assign_util.cc \ ../host_kernels/transpose_kernel.cc \ @@ 
-144,7 +144,7 @@ local_ge_executor_src_files := \ ../host_kernels/transdata_kernel.cc \ ../host_kernels/unpack_kernel.cc \ ../graph/passes/pass_utils.cc \ - ../graph/common/bcast.cc \ + ../common/bcast.cc \ ../common/fp16_t.cc \ ../common/formats/format_transfers/format_transfer_transpose.cc \ ../common/formats/utils/formats_trans_utils.cc \ diff --git a/ge/executor/proto/dump_task.proto b/ge/executor/proto/dump_task.proto deleted file mode 100644 index a2411ddb..00000000 --- a/ge/executor/proto/dump_task.proto +++ /dev/null @@ -1,113 +0,0 @@ -syntax = "proto3"; -package toolkit.dump; - -enum OutputDataType { - DT_UNDEFINED = 0; - DT_FLOAT = 1; - DT_FLOAT16 = 2; - DT_INT8 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_UINT16 = 6; - DT_INT32 = 7; - DT_INT64 = 8; - DT_UINT32 = 9; - DT_UINT64 = 10; - DT_BOOL = 11; - DT_DOUBLE = 12; - DT_STRING = 13; - DT_DUAL_SUB_INT8 = 14; - DT_DUAL_SUB_UINT8 = 15; - DT_COMPLEX64 = 16; - DT_COMPLEX128 = 17; - DT_QINT8 = 18; - DT_QINT16 = 19; - DT_QINT32 = 20; - DT_QUINT8 = 21; - DT_QUINT16 = 22; - DT_RESOURCE = 23; - DT_STRING_REF = 24; - DT_DUAL = 25; - DT_VARIANT = 26; -} - -enum OutputFormat { - FORMAT_NCHW = 0; - FORMAT_NHWC = 1; - FORMAT_ND = 2; - FORMAT_NC1HWC0 = 3; - FORMAT_FRACTAL_Z = 4; - FORMAT_NC1C0HWPAD = 5; - FORMAT_NHWC1C0 = 6; - FORMAT_FSR_NCHW = 7; - FORMAT_FRACTAL_DECONV = 8; - FORMAT_C1HWNC0 = 9; - FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; - FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; - FORMAT_NC1HWC0_C04 = 12; - FORMAT_FRACTAL_Z_C04 = 13; - FORMAT_CHWN = 14; - FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; - FORMAT_HWCN = 16; - FORMAT_NC1KHKWHWC0 = 17; - FORMAT_BN_WEIGHT = 18; - FORMAT_FILTER_HWCK = 19; - FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; - FORMAT_HASHTABLE_LOOKUP_KEYS = 21; - FORMAT_HASHTABLE_LOOKUP_VALUE = 22; - FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; - FORMAT_HASHTABLE_LOOKUP_HITS=24; - FORMAT_C1HWNCoC0 = 25; - FORMAT_MD = 26; - FORMAT_NDHWC = 27; - FORMAT_FRACTAL_ZZ = 28; - FORMAT_FRACTAL_NZ = 29; - FORMAT_RESERVED = 30; -} - -message OriginalOp { - string name = 1; - uint32 output_index = 2; - OutputDataType data_type = 3; - OutputFormat format = 4; -} - -message Shape { - repeated uint64 dim = 1; -} - -message OpOutput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - OriginalOp original_op = 4; // the original op corresponding to the output - bytes data = 5; - uint64 size = 6; -} - -message OpInput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - bytes data = 4; - uint64 size = 5; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - bytes data = 2; - uint64 size = 3; -} - -message DumpData{ - string version = 1; - uint64 dump_time = 2; - repeated OpOutput output = 3; - repeated OpInput input = 4; - repeated OpBuffer buffer = 5; - string op_name = 6; -} diff --git a/ge/executor/proto/ge_ir.proto b/ge/executor/proto/ge_ir.proto deleted file mode 100644 index c0ef3071..00000000 --- a/ge/executor/proto/ge_ir.proto +++ /dev/null @@ -1,193 +0,0 @@ -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. 
- DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ - DT_VARIANT = 26; // variant type - DT_BF16 = 27; // bf16 type - DT_INT4 = 28; // int4 type -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. 
-message NamedAttrs -{ - string name = 1; - map<string, AttrDef> attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map<string, AttrDef> attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. op_name:index - - map<string, AttrDef> attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map<string, AttrDef> attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef - - map<string, AttrDef> attr = 11; // Extended field -} - diff --git a/ge/executor/proto/insert_op.proto b/ge/executor/proto/insert_op.proto deleted file mode 100644 index 7d708865..00000000 --- a/ge/executor/proto/insert_op.proto +++ /dev/null @@ -1,140 +0,0 @@ -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPP mode, distinguishes static AIPP from dynamic AIPP - AippMode aipp_mode = 1; - - // related_input_rank is required, of integer type; range >= 0 and <= the number of input Data ops, default 0 - // It identifies which model input gets AIPP processing; e.g. to apply AIPP to the 2nd input, set related_input_rank to 1 - uint32 related_input_rank = 2; - - // related_input_name is optional and the top name of data node which inserts aipp - string related_input_name = 6; - - // input_edge_idx is optional, of integer type; range >= 0 - //
Configuring it applies different AIPP processing to different outputs of the Data op; if not set, AIPP is applied by default to all outputs of the model input specified by related_input_rank - // The configured value must be <= the number of output edges of the Data op - repeated uint32 input_edge_idx = 3; - - // [Begin] Dynamic AIPP parameters, ignored when static AIPP is configured - uint32 max_src_image_size = 4; - - // Whether rotation is supported; not supported by default, and enabling rotation costs extra space and performance - bool support_rotation = 5; - - // [End] dynamic AIPP parameters - - - // [Begin] Static AIPP parameters, ignored when dynamic AIPP is configured - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - float padding_value = 72; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] static AIPP parameters - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; // dynamic batch - resolution = 1; // dynamic resolution, reserved for extension - } - - MultiShapeMode mode = 1; // operator mode - uint32 related_input_rank = 2; // which input the new op is inserted on - - - repeated uint32 batch_list = 11; // batch_list values; the number of entries must be between 2 and 8 -} diff --git a/ge/executor/proto/om.proto b/ge/executor/proto/om.proto deleted file mode 100644 index e15e5f80..00000000 --- a/ge/executor/proto/om.proto +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - 
bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams { - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. 
- CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/ge/executor/proto/op_mapping.proto b/ge/executor/proto/op_mapping.proto deleted file mode 100644 index d626eb49..00000000 --- a/ge/executor/proto/op_mapping.proto +++ /dev/null @@ -1,75 +0,0 @@ -syntax = "proto3"; -package toolkit.aicpu.dump; - -message Shape { - repeated uint64 dim = 1; -} - -message Output { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - string original_name = 5; - int32 original_output_index = 6; - int32 original_output_data_type = 7; - int32 original_output_format = 8; - uint64 size = 9; - Shape origin_shape = 10; -} - -message Input { - int32 data_type =1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - uint64 size = 5; - Shape origin_shape = 6; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - uint64 address = 2; - uint64 size = 3; -} - -message Op { - string op_name = 1; - string op_type = 2; -} - -message Task { - uint32 task_id = 1; - uint32 stream_id = 2; - Op op = 3; - repeated Output output = 4; - bool end_graph = 5; - repeated Input input = 6; - repeated OpBuffer buffer = 7; -} - -message OpMappingInfo { - string dump_path = 1; - oneof model_name_param { - string model_name = 2; - } - oneof model_id_param { - uint32 model_id = 3; - } - oneof step_id { - uint64 step_id_addr = 4; - } - oneof iterations_per_loop { - uint64 iterations_per_loop_addr = 5; - } - oneof loop_cond { - uint64 loop_cond_addr = 6; - } - uint32 flag = 7; // 0x01 load, 0x00 unload - repeated Task task = 8; - string dump_step = 9; -} \ No newline at end of file diff --git a/ge/executor/proto/task.proto b/ge/executor/proto/task.proto deleted file mode 100644 index 0da5631e..00000000 --- a/ge/executor/proto/task.proto +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. 
- * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; - KernelDefWithHandle kernel_with_handle = 40; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelDefWithHandle { - KernelContext context = 1; - - uint64 handle = 10; - string dev_func = 11; - uint32 block_dim = 12; - uint32 args_size = 13; - bytes args = 14; - bytes sm_desc = 15; - string original_kernel_key = 16; - string node_info = 17; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 
op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index a56eaadf..3fd8be1a 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -80,7 +80,7 @@ ANALYZER_SRC_FILES:= \ OMG_HOST_SRC_FILES := \ model/ge_model.cc \ model/ge_root_model.cc \ - graph/common/transop_util.cc \ + common/transop_util.cc \ graph/passes/pass_manager.cc \ graph/passes/resource_pair_add_control_pass.cc \ graph/passes/resource_pair_remove_control_pass.cc \ @@ -115,9 +115,9 @@ OMG_HOST_SRC_FILES := \ graph/passes/mark_graph_unknown_status_pass.cc \ graph/passes/mark_node_unknown_shape_pass.cc \ graph/passes/mark_agnostic_pass.cc \ - graph/common/omg_util.cc \ - graph/common/bcast.cc \ - graph/common/local_context.cc \ + common/omg_util.cc \ + common/bcast.cc \ + common/local_context.cc \ graph/passes/dimension_compute_pass.cc \ graph/passes/dimension_adjust_pass.cc \ graph/passes/get_original_format_pass.cc \ diff --git a/ge/ge_local_engine/CMakeLists.txt b/ge/ge_local_engine/CMakeLists.txt index 3675d333..01a10eaa 100755 --- a/ge/ge_local_engine/CMakeLists.txt +++ b/ge/ge_local_engine/CMakeLists.txt @@ -41,8 +41,6 @@ target_include_directories(ge_local_engine PRIVATE ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### @@ -91,8 +89,6 @@ target_include_directories(atc_ge_local_engine PRIVATE ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### @@ -146,8 +142,6 @@ target_include_directories(ge_local_opskernel_builder PRIVATE ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### @@ -197,8 +191,6 @@ target_include_directories(atc_ge_local_opskernel_builder PRIVATE ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### @@ -254,8 +246,6 @@ target_include_directories(ge_local_opskernel_builder_static PRIVATE ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### diff --git a/ge/ge_local_engine/engine/ge_local_engine.cc b/ge/ge_local_engine/engine/ge_local_engine.cc index ac3e5473..910bb924 100755 --- a/ge/ge_local_engine/engine/ge_local_engine.cc +++ b/ge/ge_local_engine/engine/ge_local_engine.cc @@ -19,10 +19,10 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" -#include "common/constant/constant.h" +#include "framework/common/ge_inner_error_codes.h" +#include 
"ge_local_engine/common/constant/constant.h" #include "common/ge/ge_util.h" -#include "ops_kernel_store/ge_local_ops_kernel_info.h" +#include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h" namespace ge { namespace ge_local { diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc index cd68ae15..d9b67736 100755 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "host_cpu_engine.h" -#include "graph/common/omg_util.h" +#include "ge_local_engine/engine/host_cpu_engine.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_adapter.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/type_utils.h" #include "register/op_kernel_registry.h" #include "register/host_cpu_context.h" #include "common/ge/ge_util.h" #include "common/ge/plugin_manager.h" -#include "graph/utils/type_utils.h" #include "common/fp16_t.h" #include "common/math/math_util.h" @@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) { } Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr &op_kernel) { - std::string op_type; - auto status = GetOriginalType(node, op_type); - GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status); - + const std::string op_type = NodeUtils::GetNodeType(node); auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); if (kernel == nullptr) { GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); diff --git a/ge/ge_local_engine/engine/host_cpu_engine.h b/ge/ge_local_engine/engine/host_cpu_engine.h index d13fcae1..5d6fa664 100644 --- a/ge/ge_local_engine/engine/host_cpu_engine.h +++ b/ge/ge_local_engine/engine/host_cpu_engine.h @@ -33,7 +33,7 @@ #include #include "framework/common/ge_inner_error_codes.h" #include "graph/node.h" -#include "graph/operator.h" +#include "external/graph/operator.h" #include "external/../register/register.h" namespace ge { diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc index 5842fe29..33aa407d 100644 --- a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc +++ b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -#include "ge_local_ops_kernel_builder.h" +#include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.h" #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/node_utils.h" diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc index 504c3f2f..d775309d 100755 --- a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc +++ b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc @@ -16,14 +16,14 @@ #include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h" #include -#include "common/constant/constant.h" +#include "ge_local_engine/common/constant/constant.h" #include "common/ge/ge_util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include "op/op_factory.h" +#include "ge_local_engine/ops_kernel_store/op/op_factory.h" #include "proto/task.pb.h" namespace ge { diff --git a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc index ee601a99..dc4abfb8 100755 --- a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc +++ b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc @@ -16,7 +16,7 @@ #include "ge_local_engine/ops_kernel_store/op/ge_deleted_op.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "ge_local_engine/ops_kernel_store/op/op_factory.h" namespace ge { diff --git a/ge/ge_local_engine/ops_kernel_store/op/no_op.cc b/ge/ge_local_engine/ops_kernel_store/op/no_op.cc index c2104693..45d9da47 100755 --- a/ge/ge_local_engine/ops_kernel_store/op/no_op.cc +++ b/ge/ge_local_engine/ops_kernel_store/op/no_op.cc @@ -16,7 +16,7 @@ #include "ge_local_engine/ops_kernel_store/op/no_op.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "ge_local_engine/ops_kernel_store/op/op_factory.h" namespace ge { diff --git a/ge/ge_local_engine/ops_kernel_store/op/op.cc b/ge/ge_local_engine/ops_kernel_store/op/op.cc index 11229b2c..c2ef0091 100644 --- a/ge/ge_local_engine/ops_kernel_store/op/op.cc +++ b/ge/ge_local_engine/ops_kernel_store/op/op.cc @@ -16,7 +16,7 @@ #include "ge_local_engine/ops_kernel_store/op/op.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/op_desc.h" #include "graph/utils/anchor_utils.h" #include "graph/utils/tensor_utils.h" diff --git a/ge/ge_local_engine/ops_kernel_store/op/op.h b/ge/ge_local_engine/ops_kernel_store/op/op.h index b75a8bed..004723e1 100644 --- a/ge/ge_local_engine/ops_kernel_store/op/op.h +++ b/ge/ge_local_engine/ops_kernel_store/op/op.h @@ -20,7 +20,7 @@ #include #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/node.h" namespace ge { diff --git a/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc b/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc index 2e56b7bb..18f3b7b9 100644 --- a/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc +++ 
b/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc @@ -16,7 +16,7 @@ #include "ge_local_engine/ops_kernel_store/op/op_factory.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/op_desc.h" namespace ge { diff --git a/ge/ge_local_engine/proto/task.proto b/ge/ge_local_engine/proto/task.proto deleted file mode 100644 index 0da5631e..00000000 --- a/ge/ge_local_engine/proto/task.proto +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; - KernelDefWithHandle kernel_with_handle = 40; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelDefWithHandle { - KernelContext context = 1; - - uint64 handle = 10; - string dev_func = 11; - uint32 block_dim = 12; - uint32 args_size = 13; - bytes args = 14; - bytes sm_desc = 15; - string original_kernel_key = 16; - string node_info = 17; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; 
-} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/ge_opt_info/ge_opt_info.cc b/ge/ge_opt_info/ge_opt_info.cc new file mode 100644 index 00000000..8c1b84ab --- /dev/null +++ b/ge/ge_opt_info/ge_opt_info.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_opt_info/ge_opt_info.h" + +#include +#include +#include "graph/ge_local_context.h" +#include "ge/ge_api_types.h" +#include "common/debug/ge_log.h" +#include "opt_info.h" + +namespace ge { +Status GeOptInfo::SetOptInfo() { + std::string soc_ver; + graphStatus ret = GetThreadLocalContext().GetOption(SOC_VERSION, soc_ver); + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get soc version failed."); + GELOGE(FAILED, "[Get][SocVersion]Get soc version failed."); + return FAILED; + } + GELOGD("Soc version:%s.", soc_ver.c_str()); + std::map opt_info; + // the first arg does not work at present. 
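+  // A sketch of the assumed gelc usage ("Ascend310" is only an illustrative soc string;
+  // the exact gelc signature is not shown in this patch):
+  //   std::map<std::string, std::string> info;
+  //   if (gelc::GetOptInfo(gelc::kOffline, "Ascend310", info) == gelc::SUCCESS) {
+  //     // info now holds optimization key/value pairs for that soc version,
+  //     // which are merged into the thread-local graph options below.
+  //   }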
+ if (gelc::GetOptInfo(gelc::kOffline, soc_ver, opt_info) != gelc::SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get optional information failed, is_offline:%d, soc version:%s", + gelc::kOffline, soc_ver.c_str()); + GELOGE(FAILED, "[Get][OptInfo]Get optional information failed, is_offline:%d, soc version:%s", + gelc::kOffline, soc_ver.c_str()); + return FAILED; + } + // do nothing if get empty information + if (opt_info.empty()) { + GELOGI("Optional information is empty."); + return SUCCESS; + } + std::map graph_options = GetThreadLocalContext().GetAllGraphOptions(); + for (const auto &itr : opt_info) { + graph_options.emplace(itr.first, itr.second); + GELOGI("Get optional information success, key:%s, value:%s.", itr.first.c_str(), itr.second.c_str()); + } + GetThreadLocalContext().SetGraphOption(graph_options); + return SUCCESS; +} +} // namespace ge diff --git a/ge/ge_opt_info/ge_opt_info.h b/ge/ge_opt_info/ge_opt_info.h new file mode 100644 index 00000000..5cc1063a --- /dev/null +++ b/ge/ge_opt_info/ge_opt_info.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_OPT_INFO_GE_OPT_INFO_H_ +#define GE_OPT_INFO_GE_OPT_INFO_H_ + +#include "ge/ge_api_error_codes.h" +#include "register/register_types.h" + +namespace ge { +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeOptInfo { + public: + GeOptInfo() = default; + ~GeOptInfo() = default; + static Status SetOptInfo(); +}; +} // namespace ge + +#endif // GE_OPT_INFO_GE_OPT_INFO_H_ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 8ca8572c..d6462542 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -43,10 +43,10 @@ LIBGE_LOCAL_SRC_FILES := \ graph/build/stream_allocator.cc \ graph/build/stream_graph_optimizer.cc \ graph/build/task_generator.cc \ - graph/common/bcast.cc \ - graph/common/local_context.cc \ - graph/common/omg_util.cc \ - graph/common/transop_util.cc \ + common/bcast.cc \ + common/local_context.cc \ + common/omg_util.cc \ + common/transop_util.cc \ graph/execute/graph_execute.cc \ graph/label/case_label_maker.cc \ graph/label/if_label_maker.cc \ diff --git a/ge/ge_runtime/CMakeLists.txt b/ge/ge_runtime/CMakeLists.txt index 40113285..ffea784b 100644 --- a/ge/ge_runtime/CMakeLists.txt +++ b/ge/ge_runtime/CMakeLists.txt @@ -35,21 +35,13 @@ target_compile_definitions(ge_runtime PRIVATE target_include_directories(ge_runtime PRIVATE ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR} ${GE_CODE_DIR}/ge ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/graph ${GE_CODE_DIR}/inc/external ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${GE_CODE_DIR}/inc/framework/ge_runtime - ${GE_CODE_DIR}/inc/cce ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${METADEF_DIR} ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external/graph ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/ge ) diff --git a/ge/ge_runtime/model_runner.cc b/ge/ge_runtime/model_runner.cc index 9961ab4e..9338aae2 100644 --- 
a/ge/ge_runtime/model_runner.cc +++ b/ge/ge_runtime/model_runner.cc @@ -14,12 +14,12 @@ * limitations under the License. */ -#include "ge_runtime/model_runner.h" -#include "./runtime_model.h" +#include "framework/ge_runtime/model_runner.h" +#include "ge_runtime/runtime_model.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" -#include "ge_runtime/davinci_model.h" +#include "framework/ge_runtime/davinci_model.h" #include "graph/op_desc.h" namespace ge { diff --git a/ge/ge_runtime/output.cc b/ge/ge_runtime/output.cc index 90c33bb4..0be053d5 100644 --- a/ge/ge_runtime/output.cc +++ b/ge/ge_runtime/output.cc @@ -15,8 +15,8 @@ */ #include "ge_runtime/output.h" -#include "common/ge_inner_error_codes.h" -#include "common/util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" namespace ge { diff --git a/ge/ge_runtime/output.h b/ge/ge_runtime/output.h index 1f7f91ee..61fcf94e 100644 --- a/ge/ge_runtime/output.h +++ b/ge/ge_runtime/output.h @@ -19,8 +19,8 @@ #include #include -#include "ge_runtime/davinci_model.h" -#include "common/ge_types.h" +#include "framework/ge_runtime/davinci_model.h" +#include "framework/common/ge_types.h" namespace ge { namespace model_runner { diff --git a/ge/ge_runtime/runtime_model.cc b/ge/ge_runtime/runtime_model.cc index a19fbcaf..490ac25b 100644 --- a/ge/ge_runtime/runtime_model.cc +++ b/ge/ge_runtime/runtime_model.cc @@ -16,16 +16,15 @@ #include "ge_runtime/runtime_model.h" #include -#include "./model_context.h" -#include "./task/task.h" -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "common/util.h" -#include "common/math/math_util.h" +#include "ge_runtime/model_context.h" +#include "ge_runtime/task/task.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/op/op_parser_util.h" -#include "graph/types.h" -#include "task/task_factory.h" +#include "external/graph/types.h" +#include "ge_runtime/task/task_factory.h" #include "ge/common/math/math_util.h" namespace ge { diff --git a/ge/ge_runtime/runtime_model.h b/ge/ge_runtime/runtime_model.h index d0c466d4..429a143f 100644 --- a/ge/ge_runtime/runtime_model.h +++ b/ge/ge_runtime/runtime_model.h @@ -20,8 +20,8 @@ #include #include #include -#include "ge_runtime/davinci_model.h" -#include "common/ge_types.h" +#include "framework/ge_runtime/davinci_model.h" +#include "framework/common/ge_types.h" #include "runtime/base.h" #include "runtime/rt_model.h" diff --git a/ge/ge_runtime/task/hccl_task.cc b/ge/ge_runtime/task/hccl_task.cc index df4bf4c8..bfe0d0f3 100644 --- a/ge/ge_runtime/task/hccl_task.cc +++ b/ge/ge_runtime/task/hccl_task.cc @@ -16,6 +16,7 @@ #include "ge_runtime/task/hccl_task.h" #include +#include "framework/common/util.h" #include "ge_runtime/task/task_factory.h" #include "common/opskernel/ops_kernel_info_store.h" #include "common/opskernel/ge_task_info.h" diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index e9046026..a3b70971 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -85,7 +85,7 @@ bool LabelGotoTask::Distribute() { return false; } - rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); + rt_ret 
= rtLabelListCpy(reinterpret_cast(label_list.data()), label_list.size(), label_info_, label_info_size); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); return false; diff --git a/ge/ge_runtime/task/task.h b/ge/ge_runtime/task/task.h index c255fd22..8170f3ca 100644 --- a/ge/ge_runtime/task/task.h +++ b/ge/ge_runtime/task/task.h @@ -23,7 +23,7 @@ #include #include "runtime/rt_model.h" #include "ge_runtime/model_context.h" -#include "ge_runtime/task_info.h" +#include "framework/ge_runtime/task_info.h" #include "external/runtime/rt_error_codes.h" namespace ge { diff --git a/ge/ge_runtime/task/task_factory.h b/ge/ge_runtime/task/task_factory.h index 670d1fef..f19b7419 100644 --- a/ge/ge_runtime/task/task_factory.h +++ b/ge/ge_runtime/task/task_factory.h @@ -21,9 +21,9 @@ #include #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" -#include "ge_runtime/task_info.h" +#include "framework/ge_runtime/task_info.h" namespace ge { namespace model_runner { diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 575afb35..7c5cb330 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -14,29 +14,30 @@ * limitations under the License. */ -#include "generator/ge_generator.h" +#include "framework/generator/ge_generator.h" #include #include "common/ge/ge_util.h" #include "common/ge/plugin_manager.h" -#include "common/helper/model_helper.h" -#include "common/helper/om_file_helper.h" -#include "common/util.h" +#include "framework/common/helper/model_helper.h" +#include "framework/common/helper/om_file_helper.h" +#include "framework/common/util.h" #include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" -#include "ge/ge_api.h" +#include "external/ge/ge_api.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/manager/graph_manager.h" +#include "graph/manager/graph_var_manager.h" #include "graph/manager/util/rt_context_util.h" #include "graph/operator_factory_impl.h" #include "graph/opsproto_manager.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "analyzer/analyzer.h" using std::map; @@ -205,6 +206,7 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const } (void)AttrUtils::SetBool(data_op, "_is_single_op", true); + (void)AttrUtils::SetBool(data_op, ATTR_NAME_IS_ORIGINAL_INPUT, true); GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, REPORT_CALL_ERROR("E19999", "AddInputDesc failed for node:%s", data_op->GetName().c_str()); @@ -674,6 +676,12 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr GELOGD("Current ctx is null."); ctx = nullptr; } + std::function callback = [&]() { + if (ctx != nullptr) { + (void)rtCtxSetCurrent(ctx); + } + }; + GE_MAKE_GUARD(restore, callback); GeRootModelPtr ge_root_model = nullptr; GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); @@ -712,11 +720,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr } return ret; } - - if (ctx != nullptr) { - (void)rtCtxSetCurrent(ctx); - } - return SUCCESS; } @@ -806,7 +809,7 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector return SUCCESS; } -Status 
GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { +Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) { GE_CHECK_NOTNULL(op_desc); if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); @@ -830,7 +833,11 @@ Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { } node_op.BreakConnect(); } - auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); + auto comp_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(comp_graph); + auto node = comp_graph->FindNode(op_desc->GetName()); + GE_CHECK_NOTNULL(node); + auto op = OpDescUtils::CreateOperatorFromNode(node); auto ret = op_desc->CallInferFormatFunc(op); if (ret != GRAPH_SUCCESS) { REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", @@ -877,7 +884,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in Graph graph; GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); - GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc)); + GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph)); // 2. check engine type when compile online if (model_file_name == kFileNameSuffix) { @@ -1151,7 +1158,6 @@ Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", "build graph failed, graph id:%u, ret:%d", graph_id, ret); GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "[Build][Graph] fail, graph id: %u", graph_id); - ret = GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; } RtContextUtil::GetInstance().DestroyRtContexts(session_id); diff --git a/ge/generator/generator_api.cc b/ge/generator/generator_api.cc index b64a9eb3..56a35130 100644 --- a/ge/generator/generator_api.cc +++ b/ge/generator/generator_api.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "generator/generator_api.h" +#include "framework/generator/generator_api.h" #include "common/ge/ge_util.h" -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" -#include "generator/ge_generator.h" +#include "framework/generator/ge_generator.h" #include "graph/ge_attr_value.h" #include "graph/ge_tensor.h" #include "graph/op_desc.h" diff --git a/ge/graph/build/graph_builder.cc b/ge/graph/build/graph_builder.cc index 8b172e63..e1398d1f 100644 --- a/ge/graph/build/graph_builder.cc +++ b/ge/graph/build/graph_builder.cc @@ -17,18 +17,18 @@ #include "graph/build/graph_builder.h" #include "graph/build/memory/graph_mem_assigner.h" #include "common/ge/ge_util.h" -#include "common/helper/model_helper.h" +#include "framework/common/helper/model_helper.h" #include "graph/build/logical_stream_allocator.h" #include "graph/build/run_context.h" #include "graph/build/stream_graph_optimizer.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/ge_context.h" #include "graph/manager/graph_var_manager.h" #include "graph/passes/mark_same_addr_pass.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "graph/ge_context.h" #include "opskernel_manager/ops_kernel_builder_manager.h" #include "graph/utils/op_desc_utils.h" diff --git a/ge/graph/build/graph_builder.h b/ge/graph/build/graph_builder.h index fb9ab6bd..6ed14dae 100644 --- a/ge/graph/build/graph_builder.h +++ b/ge/graph/build/graph_builder.h @@ -22,23 +22,23 @@ #include #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/debug/memory_dumper.h" #include "common/properties_manager.h" -#include "common/string_util.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/string_util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/build/model_builder.h" #include "graph/build/task_generator.h" #include "graph/compute_graph.h" -#include "graph/graph.h" +#include "external/graph/graph.h" #include "graph/manager/graph_manager_utils.h" #include "graph/model.h" #include "graph/node.h" #include "graph/partition/graph_partition.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" namespace ge { class GraphBuilder { diff --git a/ge/graph/build/label_allocator.cc b/ge/graph/build/label_allocator.cc index 32bdd0a3..f2329769 100644 --- a/ge/graph/build/label_allocator.cc +++ b/ge/graph/build/label_allocator.cc @@ -14,11 +14,11 @@ * limitations under the License. 
*/ -#include "label_allocator.h" +#include "graph/build/label_allocator.h" #include "framework/common/types.h" -#include "common/util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/label/label_maker.h" @@ -80,12 +80,16 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::setGetParentNode(); if (func_node == nullptr) { - REPORT_INNER_ERROR("E19999", "Parent node not set in node:%s(%s), graph:%s", - func_node->GetName().c_str(), func_node->GetType().c_str(), graph->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "Parent node not set, graph:%s", graph->GetName().c_str()); GELOGE(INTERNAL_ERROR, "[Get][Node] Parent functional node not set: %s.", graph->GetName().c_str()); return false; } + if (func_node->GetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { + GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str()); + return true; + } + ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph(); if (owner_graph == nullptr) { REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s", diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index c74cdf7a..3d6ca74a 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -22,7 +22,7 @@ #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" using std::map; using std::set; @@ -474,6 +474,11 @@ Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vectorGetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); + if (op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID)) { + op_desc->SetStreamId(kInvalidStream); + GELOGI("Ffts node %s of type %s reassign to invalid stream.", node->GetName().c_str(), node->GetType().c_str()); + continue; + } int64_t stream_id = op_desc->GetStreamId(); if (ops_without_label.find(op_desc) != ops_without_label.end()) { if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) { diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc index 9b81eae3..7d0db676 100755 --- a/ge/graph/build/memory/block_mem_assigner.cc +++ b/ge/graph/build/memory/block_mem_assigner.cc @@ -24,7 +24,7 @@ #include "graph/buffer.h" #include "graph/ge_attr_value.h" #include "graph/ge_context.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "graph/node.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" @@ -34,9 +34,9 @@ #include "graph/debug/ge_attr_define.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/optimize/common/params.h" -#include "omg/omg_inner_types.h" +#include "framework/omg/omg_inner_types.h" #include "runtime/mem.h" using std::map; diff --git a/ge/graph/build/memory/block_mem_assigner.h b/ge/graph/build/memory/block_mem_assigner.h index 231cce09..651daed5 100755 --- a/ge/graph/build/memory/block_mem_assigner.h +++ b/ge/graph/build/memory/block_mem_assigner.h @@ -24,9 +24,9 @@ #include #include #include -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "common/util.h" +#include 
"framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/build/memory/mem_assigner.h" #include "graph/compute_graph.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/build/memory/buffer_pool_mem_assigner.cc b/ge/graph/build/memory/buffer_pool_mem_assigner.cc index d66fe038..ca197e02 100644 --- a/ge/graph/build/memory/buffer_pool_mem_assigner.cc +++ b/ge/graph/build/memory/buffer_pool_mem_assigner.cc @@ -15,7 +15,7 @@ */ #include "graph/build/memory/buffer_pool_mem_assigner.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/tensor_utils.h" #include "framework/common/util.h" #include "graph/compute_graph.h" diff --git a/ge/graph/build/memory/graph_mem_assigner.cc b/ge/graph/build/memory/graph_mem_assigner.cc index e086940a..542b6215 100755 --- a/ge/graph/build/memory/graph_mem_assigner.cc +++ b/ge/graph/build/memory/graph_mem_assigner.cc @@ -24,7 +24,7 @@ #include "graph/build/memory/hybrid_mem_assigner.h" #include "graph/build/memory/var_mem_assign_util.h" #include "graph/build/memory/block_mem_assigner.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" #include "graph/manager/graph_var_manager.h" @@ -275,7 +275,7 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map({"size", "item", "maxsize"}), std::vector({std::to_string(total_mem_offset), "featuremap", std::to_string(VarManager::Instance(session_id)->GetGraphMemoryMaxSize())})); - return ge::FAILED; + return ACL_ERROR_GE_MEMORY_ALLOCATION; } return SUCCESS; } diff --git a/ge/graph/build/memory/hybrid_mem_assigner.h b/ge/graph/build/memory/hybrid_mem_assigner.h index 2bdfd5c5..33bb152b 100755 --- a/ge/graph/build/memory/hybrid_mem_assigner.h +++ b/ge/graph/build/memory/hybrid_mem_assigner.h @@ -22,8 +22,8 @@ #include "graph/build/memory/block_mem_assigner.h" #include "graph/compute_graph.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" namespace ge { using BlockMemAssignerPtr = std::shared_ptr; diff --git a/ge/graph/build/memory/mem_assigner.h b/ge/graph/build/memory/mem_assigner.h index 7d0252d9..d607b989 100755 --- a/ge/graph/build/memory/mem_assigner.h +++ b/ge/graph/build/memory/mem_assigner.h @@ -17,8 +17,8 @@ #ifndef GE_GRAPH_BUILD_MEMORY_MEM_ASSIGNER_H_ #define GE_GRAPH_BUILD_MEMORY_MEM_ASSIGNER_H_ -#include "common/ge_inner_error_codes.h" -#include "memory/memory_assigner.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/memory/memory_assigner.h" namespace ge { static const int64_t kInvalidOffset = -1; diff --git a/ge/graph/build/memory/memory_assigner.cc b/ge/graph/build/memory/memory_assigner.cc index 570aae07..41171164 100755 --- a/ge/graph/build/memory/memory_assigner.cc +++ b/ge/graph/build/memory/memory_assigner.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "memory/memory_assigner.h" +#include "framework/memory/memory_assigner.h" #include #include "framework/common/debug/ge_log.h" #include "graph/build/memory/graph_mem_assigner.h" @@ -29,9 +29,10 @@ Status MemoryAssigner::AssignMemory(bool is_loop_graph, map &m } // Reassign memory for special nodes - if (graph_mem_assigner.ReAssignMemory(is_loop_graph, mem_offset) != ge::SUCCESS) { + Status ret = graph_mem_assigner.ReAssignMemory(is_loop_graph, mem_offset); + if (ret != ge::SUCCESS) { GELOGE(ge::FAILED, "[ReAssign][Memory] failed, graph:%s", compute_graph_->GetName().c_str()); - return ge::FAILED; + return ret; } // Assign memory (block and offset) for zero copy nodes diff --git a/ge/graph/build/memory/var_mem_assign_util.cc b/ge/graph/build/memory/var_mem_assign_util.cc index b8138a30..dc7c3b01 100755 --- a/ge/graph/build/memory/var_mem_assign_util.cc +++ b/ge/graph/build/memory/var_mem_assign_util.cc @@ -16,14 +16,14 @@ #include "graph/build/memory/var_mem_assign_util.h" #include -#include "common/types.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" -#include "graph/tensor.h" -#include "graph/types.h" +#include "external/graph/tensor.h" +#include "external/graph/types.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" diff --git a/ge/graph/build/memory/var_mem_assign_util.h b/ge/graph/build/memory/var_mem_assign_util.h index 9528dbdb..26da9111 100644 --- a/ge/graph/build/memory/var_mem_assign_util.h +++ b/ge/graph/build/memory/var_mem_assign_util.h @@ -17,8 +17,8 @@ #ifndef GE_GRAPH_BUILD_MEMORY_VAR_MEM_ASSIGN_UTIL_H_ #define GE_GRAPH_BUILD_MEMORY_VAR_MEM_ASSIGN_UTIL_H_ #include -#include "common/debug/log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/utils/node_utils.h" namespace ge { diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 431e4882..897be1f8 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -22,20 +22,19 @@ #include "common/dump/dump_manager.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" -#include "graph/attr_value.h" +#include "external/graph/attr_value.h" #include "graph/buffer.h" #include "graph/build/stream_allocator.h" -#include "graph/common/omg_util.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" +#include "common/omg_util.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" #include "graph/ge_context.h" -#include "graph/ge_error_codes.h" -#include "graph/manager/graph_mem_allocator.h" +#include "external/graph/ge_error_codes.h" #include "graph/manager/graph_var_manager.h" #include "graph/optimize/common/params.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" @@ -43,8 +42,8 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -#include "memory/memory_assigner.h" -#include "omg/version.h" +#include "framework/memory/memory_assigner.h" +#include 
"framework/omg/version.h" #include "register/op_registry.h" #include "graph/passes/set_input_output_offset_pass.h" #include "graph/build/memory/block_mem_assigner.h" @@ -707,7 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { GE_CHECK_NOTNULL(kernel_buffer.GetData()); std::vector data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); - tbe_kernel = std::make_shared(kernel_name, std::move(data)); + tbe_kernel = MakeShared(kernel_name, std::move(data)); GE_CHECK_NOTNULL(tbe_kernel); GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); diff --git a/ge/graph/build/model_builder.h b/ge/graph/build/model_builder.h index 6f097329..d87976dd 100644 --- a/ge/graph/build/model_builder.h +++ b/ge/graph/build/model_builder.h @@ -23,17 +23,17 @@ #include #include #include -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "common/tbe_kernel_store.h" #include "common/cust_aicpu_kernel_store.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/compute_graph.h" #include "graph/manager/graph_manager_utils.h" #include "graph/model.h" #include "graph/node.h" -#include "model/ge_model.h" -#include "omg/omg_inner_types.h" +#include "common/model/ge_model.h" +#include "framework/omg/omg_inner_types.h" namespace ge { class ModelBuilder { diff --git a/ge/graph/build/run_context.cc b/ge/graph/build/run_context.cc index 05e40b63..e629bddc 100644 --- a/ge/graph/build/run_context.cc +++ b/ge/graph/build/run_context.cc @@ -15,10 +15,10 @@ */ #include "graph/build/run_context.h" -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { RunContextUtil::~RunContextUtil() { DestroyRtModelResources(); } diff --git a/ge/graph/build/run_context.h b/ge/graph/build/run_context.h index 82f799aa..20ba76d4 100755 --- a/ge/graph/build/run_context.h +++ b/ge/graph/build/run_context.h @@ -18,7 +18,7 @@ #define GE_GRAPH_BUILD_RUN_CONTEXT_H_ #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_types.h" #include "framework/common/types.h" #include "graph/compute_graph.h" diff --git a/ge/graph/build/stream_allocator.cc b/ge/graph/build/stream_allocator.cc index dae36b83..987a77f7 100644 --- a/ge/graph/build/stream_allocator.cc +++ b/ge/graph/build/stream_allocator.cc @@ -22,12 +22,12 @@ #include "framework/common/fmk_error_codes.h" #include "framework/common/types.h" #include "graph/build/logical_stream_allocator.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/utils/graph_utils.h" #include "init/gelib.h" -#include "common/string_util.h" +#include "framework/common/string_util.h" #include "common/util/error_manager/error_manager.h" using std::map; @@ -432,7 +432,11 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() { // Insert the send/recv event id to the graph Status StreamAllocator::InsertSyncEvents() { - for (const auto &cur_node : 
whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { + auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { + return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); + }; + + for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) { // Take the adjacent points, then judge whether need to insert the event for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { @@ -531,6 +535,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const Status StreamAllocator::InsertEventsForSubgraph() { for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { GE_CHECK_NOTNULL(subgraph); + const auto parent_node = subgraph->GetParentNode(); + if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { + GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str()); + continue; + } for (const auto &node : subgraph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); diff --git a/ge/graph/build/stream_graph_optimizer.cc b/ge/graph/build/stream_graph_optimizer.cc index 30142c2b..acf91ad5 100644 --- a/ge/graph/build/stream_graph_optimizer.cc +++ b/ge/graph/build/stream_graph_optimizer.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "stream_graph_optimizer.h" +#include "graph/build/stream_graph_optimizer.h" #include -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" diff --git a/ge/graph/build/stream_graph_optimizer.h b/ge/graph/build/stream_graph_optimizer.h index d69fa7ba..ec32f7fb 100644 --- a/ge/graph/build/stream_graph_optimizer.h +++ b/ge/graph/build/stream_graph_optimizer.h @@ -18,7 +18,7 @@ #define GE_GRAPH_BUILD_OPTIMIZE_STREAM_GRAPH_H_ #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_types.h" #include "framework/common/types.h" #include "graph/compute_graph.h" diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index 12da803d..abb409c4 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -18,8 +18,8 @@ #include #include #include "common/profiling/profiling_manager.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" @@ -29,10 +29,10 @@ #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "init/gelib.h" #include "graph/ge_local_context.h" -#include "ge/ge_api_types.h" +#include "external/ge/ge_api_types.h" #include "opskernel_manager/ops_kernel_builder_manager.h" using domi::LogTimeStampDef; @@ -50,6 +50,7 @@ const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kProfilingMode = "PROFILING_MODE"; const char *const kIteratorV2 = "IteratorV2"; +const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl"; const uint32_t kProfilingArStep = 2; const uint64_t 
kProfilingFpStartLogid = 1; const uint64_t kProfilingBpEndLogid = 2; @@ -354,7 +355,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra }; GE_MAKE_GUARD(release, callback); - for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { + auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { + return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); + }; + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); node_index++; @@ -380,20 +384,16 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); continue; } - if (op_kernel_lib_name.empty()) { - GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); - continue; - } + GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue, + "Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); - if (kernel_info_store == nullptr) { - REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s", - node->GetName().c_str(), node->GetType().c_str(), op_kernel_lib_name.c_str()); - GELOGE(INTERNAL_ERROR, "[Call][GetOpsKernelInfoStore] No ops kernel store or ops kernel builder found. " - "node:%s(%s), op_kernel_lib_name=%s.", name.c_str(), type.c_str(), op_kernel_lib_name.c_str()); - return INTERNAL_ERROR; - } + GE_CHECK_NOTNULL(kernel_info_store); GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(), type.c_str()); + if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { + GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed", + name.c_str(), type.c_str()); + } // Profiling task size_t task_list_size_before = task_def_list.size(); GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); @@ -432,14 +432,15 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra } // Reset stream id to ge stream id, as graph load must use ge stream to reassign stream - void *ops_kernel_info_store_ptr = kernel_info_store.get(); for (size_t idx = task_list_size_before; idx < task_list_size_after; ++idx) { task_def_list[idx].set_stream_id(static_cast(stream_id)); op_name_map[idx] = name; - // Set opsKernelInfoStorePtr and op_index, the two fields be use in DistributeTask and InitTaskInfo TaskDef *task_def_ptr = &task_def_list[idx]; GE_CHECK_NOTNULL(task_def_ptr); - task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); + // Set opsKernelInfoStorePtr for hccl which will be use in DistributeTask and InitTaskInfo + if (op_kernel_lib_name == kKernelInfoNameHccl) { + task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(kernel_info_store.get())); + } } GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, @@ -571,7 +572,24 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info return ret; } +Status TaskGenerator::UpdateAnchorStatusForFfts(const NodePtr &node) { + GELOGD("Start UpdateAnchorStatusForFfts for %s.", 
node->GetName().c_str()); + if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { + for (size_t i = 0; i < node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { + auto sub_graph = NodeUtils::GetSubgraph(*node, i); + GE_CHECK_NOTNULL(sub_graph); + GELOGD("Start update anchor status for %s.", sub_graph->GetName().c_str()); + for (auto &ffts_node : sub_graph->GetDirectNode()) { + GE_CHK_STATUS_RET(UpdateAnchorStatus(ffts_node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", + ffts_node->GetName().c_str(), ffts_node->GetType().c_str()); + } + } + } + return SUCCESS; +} + Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) { + GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str()); if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) { REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)", node->GetName().c_str(), node->GetType().c_str()); @@ -771,7 +789,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP GELOGI("Start AutoFindBpOpIndex"); NodePtr bp_node = nullptr; uint32_t current_idx = 0; - uint32_t netoutput_idx = 0; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -789,7 +806,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { if (bp_node == nullptr) { bp_node = node; - netoutput_idx = current_idx - 1; } } if (graph->GetNeedIteration()) { @@ -814,34 +830,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP if (bp_node == nullptr) { GELOGW("not find bp_node."); return SUCCESS; - } else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) { - profiling_point.bp_index = netoutput_idx; - GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx); - } else { - profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node); } - return SUCCESS; + return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index); } -uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const { - uint32_t last_bp = 0; +Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node, + uint32_t &bp_index) const { + bp_index = 0; + auto target_desc = target_node->GetOpDesc(); + GE_CHECK_NOTNULL(target_desc); OpDescPtr bp_op_desc = nullptr; - for (auto &in_anchor : bp_node->GetAllInDataAnchors()) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { - continue; - } - auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(out_node_desc); - if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) { - bp_op_desc = out_node_desc; + for (auto &in_node : target_node->GetInAllNodes()) { + GE_CHECK_NOTNULL(in_node); + auto in_node_desc = in_node->GetOpDesc(); + GE_CHECK_NOTNULL(in_node_desc); + if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) && + (in_node_desc->GetStreamId() == target_desc->GetStreamId())){ + bp_op_desc = in_node_desc; } - GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); } if (bp_op_desc == nullptr) { - return last_bp; + GELOGI("Did not find bp node."); + return SUCCESS; } uint32_t current_idx = 0; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { @@ -849,12 +861,14 @@ uint32_t 
TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const GE_CHECK_NOTNULL(op_desc); current_idx++; if (op_desc->GetName() == bp_op_desc->GetName()) { - last_bp = current_idx; - GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); + bp_index = current_idx; + GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index); break; } } - return last_bp; + GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(), + bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId()); + return SUCCESS; } Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, diff --git a/ge/graph/build/task_generator.h b/ge/graph/build/task_generator.h index 9f12d568..5d204c3c 100755 --- a/ge/graph/build/task_generator.h +++ b/ge/graph/build/task_generator.h @@ -21,7 +21,7 @@ #include #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_types.h" #include "framework/common/types.h" #include "graph/compute_graph.h" @@ -80,6 +80,7 @@ class TaskGenerator { Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, std::vector &all_reduce_nodes); private: + Status UpdateAnchorStatusForFfts(const NodePtr &node); Status UpdateAnchorStatus(const NodePtr &node); Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id); @@ -115,7 +116,7 @@ class TaskGenerator { Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, vector &all_reduce_nodes) const; - uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const; + Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const; Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, ProfilingPoint &profiling_point) const; diff --git a/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc index 02d7d3ca..03abf91f 100755 --- a/ge/graph/execute/graph_execute.cc +++ b/ge/graph/execute/graph_execute.cc @@ -21,6 +21,7 @@ #include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/davinci_model.h" +#include "common/profiling/profiling_manager.h" namespace ge { using Uint32Pair = pair; @@ -31,7 +32,6 @@ GraphExecutor::GraphExecutor() sync_run_mutex_(nullptr), condition_(nullptr), graph_run_listener_(nullptr), - graph_context_(nullptr), last_graph_id_(UINT32_MAX), malloc_flag_(false) {} @@ -79,16 +79,6 @@ Status GraphExecutor::SetCondition(std::mutex *mutex, std::condition_variable *c return SUCCESS; } -Status GraphExecutor::SetGraphContext(GraphContextPtr graph_context_ptr) { - if (graph_context_ptr == nullptr) { - REPORT_INNER_ERROR("E19999", "Check param graph_context_ptr nullptr"); - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[Check][Param] input param graph_context_ptr is nullptr"); - return GE_GRAPH_PARAM_NULLPTR; - } - graph_context_ = graph_context_ptr; - return SUCCESS; -} - Status GraphExecutor::SetDynamicSize(uint32_t model_id, const std::vector &batch_num, int32_t dynamic_type) { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); @@ -376,7 +366,11 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_ro GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[SyncExecute][Model] Error! 
graph id:%u", graph_id); return GE_GRAPH_SYNC_MODEL_FAILED; } - + ret = ModelSubscribe(graph_id); + if (ret != SUCCESS) { + GELOGE(ret, "[Call][ModelSubscribe] failed, graph_id:%u", graph_id); + return ret; + } return SUCCESS; } @@ -787,4 +781,41 @@ Status GraphExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint } return SUCCESS; } + +Status GraphExecutor::GetModelByID(uint32_t model_id, std::shared_ptr &davinci_model) { + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + davinci_model = model_manager->GetModel(static_cast(model_id)); + if (davinci_model == nullptr) { + REPORT_INNER_ERROR("E19999", "GetModel from model_manager fail, model_id:%u", model_id); + GELOGE(ge::FAILED, "[Get][Model] failed, Model id:%d is invaild or model is not loaded.", model_id); + return ge::FAILED; + } + return ge::SUCCESS; +} + +Status GraphExecutor::ModelSubscribe(uint32_t graph_id) { + auto &profiling_manager = ProfilingManager::Instance(); + const auto &subcribe_info = profiling_manager.GetSubscribeInfo(); + if (subcribe_info.is_subscribe) { + std::shared_ptr davinci_model = nullptr; + uint32_t model_id = 0; + Status ret = profiling_manager.GetModelIdFromGraph(graph_id, model_id); + if (ret != SUCCESS) { + GELOGE(ret, "[Call][GetModelIdFromGraph] failed, graph_id:%u", graph_id); + return ret; + } + ret = GetModelByID(model_id, davinci_model); + if (ret != SUCCESS) { + GELOGE(ret, "[Call][GetModelByID] failed, model_id:%u", model_id); + return ret; + } + ret = profiling_manager.ProfModelSubscribe(subcribe_info.prof_switch, davinci_model.get()); + if (ret != SUCCESS) { + GELOGE(ret, "[Call][ProfModelSubscribe] failed"); + return ret; + } + } + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/execute/graph_execute.h b/ge/graph/execute/graph_execute.h index aa791c9b..56e322f1 100755 --- a/ge/graph/execute/graph_execute.h +++ b/ge/graph/execute/graph_execute.h @@ -24,20 +24,21 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/debug/memory_dumper.h" -#include "common/ge_types.h" +#include "framework/common/ge_types.h" #include "common/properties_manager.h" -#include "common/string_util.h" -#include "common/types.h" -#include "common/util.h" -#include "ge/ge_api_types.h" +#include "framework/common/string_util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" +#include "external/ge/ge_api_types.h" #include "graph/compute_graph.h" #include "graph/manager/graph_context.h" #include "graph/manager/graph_manager_utils.h" #include "graph/model.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" +#include "graph/load/model_manager/davinci_model.h" namespace ge { class GraphExecutor { @@ -60,8 +61,6 @@ class GraphExecutor { Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr listener); - Status SetGraphContext(GraphContextPtr graph_context_ptr); - static Status SetDynamicSize(uint32_t model_id, const std::vector &batch_num, int32_t dynamic_type); void SetTrainFlag(bool is_train_graph); @@ -150,6 +149,10 @@ class GraphExecutor { static Status SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model, const RunAsyncCallback &callback); + Status ModelSubscribe(uint32_t graph_id); + + Status GetModelByID(uint32_t model_id, std::shared_ptr &davinci_model); + bool init_flag_; bool train_graph_flag_; @@ -160,8 +163,6 @@ class GraphExecutor { // Run graph asynchronous call back listener std::shared_ptr 
-  GraphContextPtr graph_context_;
-
   std::vector outputs_desc_;
   GraphId last_graph_id_;
diff --git a/ge/graph/execute/model_executor.cc b/ge/graph/execute/model_executor.cc
new file mode 100644
index 00000000..993ba8c3
--- /dev/null
+++ b/ge/graph/execute/model_executor.cc
@@ -0,0 +1,565 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "graph/execute/model_executor.h"
+
+#include "graph/ge_context.h"
+#include "graph/debug/ge_attr_define.h"
+#include "common/ge_call_wrapper.h"
+#include "common/local_context.h"
+#include "graph/manager/graph_var_manager.h"
+#include "graph/utils/tensor_adapter.h"
+#include "graph/load/graph_loader.h"
+#include "graph/load/model_manager/model_manager.h"
+#include "common/math/math_util.h"
+#include "common/formats/utils/formats_trans_utils.h"
+
+namespace {
+constexpr int32_t kBase = 10;
+constexpr uint8_t kNeverLoaded = 0;
+}
+
+namespace ge {
+///
+/// @ingroup ge
+/// @brief graph executor init
+/// @param [in] options user config params
+/// @return Status result of function
+///
+Status ModelExecutor::Initialize(const map &options, uint64_t session_id) {
+  graph_run_listener_ = MakeShared(sync_run_mutex_, condition_);
+  if (graph_run_listener_ == nullptr) {
+    REPORT_CALL_ERROR("E19999", "New GraphModelListener fail");
+    GELOGE(MEMALLOC_FAILED, "[New][GraphModelListener] failed");
+    return MEMALLOC_FAILED;
+  }
+
+  const auto model_manager = ModelManager::GetInstance();
+  GE_CHECK_NOTNULL(model_manager);
+  Status status = model_manager->EnableExceptionDump(options);
+  if (status != SUCCESS) {
+    return status;
+  }
+
+  session_id_ = session_id;
+  train_graph_flag_ = ParseTrainGraphFlag();
+  thread_run_flag_.store(true);
+  run_thread_ = std::thread(&ModelExecutor::RunThread, this);
+
+  init_flag_ = true;
+  return SUCCESS;
+}
+
+///
+/// @ingroup ge
+/// @brief graph executor finalize
+/// @return Status result of function
+///
+Status ModelExecutor::Finalize() {
+  if (!init_flag_) {
+    GELOGW("ModelExecutor has not been initialized.");
+    return SUCCESS;
+  }
+
+  StopQueue();
+  if (run_thread_.joinable()) {
+    run_thread_.join();
+  }
+
+  if (graph_executor_.FreeExecuteMemory() != SUCCESS) {
+    GELOGW("Graph executor FreeExecuteMemory failed, resources may not be released correctly.");
+  }
+
+  ModelManager::GetInstance()->DestroyAicpuSession(session_id_);
+  return SUCCESS;
+}
+
+// OPTION_GRAPH_RUN_MODE is supposed to be a session-level option, but it used to be set at the global level in
+// the past. If it cannot be parsed from the session options, it is parsed from the global options via GetContext().
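+// For example (a sketch, assuming the public GraphRunMode enum where PREDICTION = 0 and TRAIN = 1):
+// a session created with option {OPTION_GRAPH_RUN_MODE, "1"} makes std::strtol below return 1,
+// GraphRunMode(1) >= TRAIN holds, and the executor treats the session as training; a missing or
+// non-numeric value parses to 0 and leaves the train flag false.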
+bool ModelExecutor::ParseTrainGraphFlag() { + string run_mode; + if (GetContext().GetOption(OPTION_GRAPH_RUN_MODE, run_mode) == SUCCESS && !run_mode.empty()) { + if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, kBase)) >= TRAIN) { + GELOGI("Graph train flag set."); + return true; + } + } + return false; +} + +void ModelExecutor::AddGraphNode(GraphId graph_id, const GraphNodePtr &graph_node) { + std::lock_guard<std::mutex> lock(mutex_); + graph_nodes_.emplace(graph_id, graph_node); +} + +void ModelExecutor::RemoveGraphNode(GraphId graph_id) { + std::lock_guard<std::mutex> lock(mutex_); + graph_nodes_.erase(graph_id); +} + +/// +/// @ingroup ge +/// @brief Load model for graph. +/// @param [in] GeRootModel: root model of graph compiled. +/// @param [in] GraphNode: node of graph. +/// @return Status result of function +/// +Status ModelExecutor::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { + GE_CHECK_NOTNULL(graph_node); + if (ge_root_model == nullptr) { + return SUCCESS; + } + + UpdateLocalOmeContext(graph_node); + return graph_node->IsAsync() ? ModelLoadAsync(ge_root_model, graph_node) : ModelLoadSync(ge_root_model, graph_node); +} + +/// +/// @ingroup ge +/// @brief Unload model for graph. +/// @param [in] GeRootModel: root model of graph compiled. +/// @param [in] graph_id: graph identifier. +/// @return Status result of function +/// +Status ModelExecutor::UnloadGraph(const GeRootModelPtr &ge_root_model, uint32_t graph_id) { + GE_CHECK_NOTNULL(ge_root_model); + rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("[GraphExecutor] rtSetDevice failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), graph_id); + return FAILED; + } + + RemoveGraphNode(graph_id); + Status ret = UnloadModel(ge_root_model, graph_id); + if (ret != SUCCESS) { + GELOGW("[GraphExecutor] unload model failed, graph_id=%u.", graph_id); + } + rt_ret = rtDeviceReset(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("[GraphExecutor] rtDeviceReset failed, graphId=%u.", graph_id); + } + + return ret; +} + +Status ModelExecutor::UnloadModel(const GeRootModelPtr &ge_root_model, uint32_t graph_id) { + GE_CHECK_NOTNULL(ge_root_model); + for (size_t i = 0; i < ge_root_model->GetAllModelId().size(); ++i) { + uint32_t model_id = ge_root_model->GetAllModelId()[i]; + GELOGI("Unload model %u.", model_id); + Status ret = GraphLoader::UnloadModel(model_id); + if (ret != SUCCESS) { + GELOGE(ret, "[GraphExecutor] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); + return ret; + } + } + return SUCCESS; +} + +void ModelExecutor::StopQueue() { + thread_run_flag_.store(false); + run_args_q_.Stop(); +} + +void ModelExecutor::ReturnError(RunAsyncCallback callback, Status ret, const string &log) { + StopQueue(); + GELOGE(ret, "%s.", log.c_str()); + std::vector outputs; + if (callback != nullptr) { + callback(ret, outputs); + } +} + +void ModelExecutor::UpdateLocalOmeContext(const GraphNodePtr &graph_node) { + std::lock_guard<std::mutex> lock(mutex_); + SetLocalOmeContext(graph_node->GetOmeContext()); +} + +/// +/// @ingroup ge +/// @brief Push model execution params to queue. +/// @param [in] args: RunArgs for model execution. +/// @return Status result of function +/// +Status ModelExecutor::PushGraph(const RunArgs &args) { + return run_args_q_.Push(args) ?
SUCCESS : FAILED; +} + +void ModelExecutor::RunThread() { + ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); + if (mmSetCurrentThreadName("GE_Run") != EN_OK) { + GELOGW("Set thread name failed."); + } + + RunArgs args; + while (thread_run_flag_) { + if (!run_args_q_.Pop(args)) { + continue; + } + + GELOGI("[RunThread] A new loop starts, graph_id:%u.", args.graph_id); + ErrorManager::GetInstance().SetErrorContext(args.error_context); + GetContext().SetSessionId(args.session_id); + GetThreadLocalContext() = args.context; + UpdateLocalOmeContext(args.graph_node); + + // parse inputs.dims to vector<vector<int64_t>> dynamic_dims + Status ret = ParseInputsDims(args.input_tensor); + if (ret != SUCCESS) { + ReturnError(args.callback, ret, "ParseInputsDims failed, thread exit."); + args.graph_node->Unlock(); + return; + } + + args.graph_node->UpdateLoadFlag(); + if (!args.graph_node->GetLoadFlag()) { + ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); + args.ge_root_model->SetTrainFlag(train_graph_flag_); + ret = ModelLoadAsync(args.ge_root_model, args.graph_node); + if (ret != SUCCESS || args.ge_root_model == nullptr) { + StopQueue(); + ReturnError(args.callback, ret, "LoadGraphAsync failed, thread exit."); + args.graph_node->Unlock(); + return; + } + // control the times of graph loading in multi-thread scenario + args.graph_node->DecreaseLoadCount(); + args.graph_node->IncreaseLoadRecord(); + + args.graph_node->SetLoadFlag(true); + GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), + args.ge_root_model->GetModelId()); + } + + ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); + if (train_graph_flag_) { + graph_executor_.SetTrainFlag(train_graph_flag_); + } + + ret = graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), + args.input_tensor, args.callback); + args.graph_node->SetRunFlag(false); + if (ret != SUCCESS) { + ReturnError(args.callback, ret, "ExecuteGraphAsync failed, thread exit."); + args.graph_node->Unlock(); + return; + } + args.graph_node->Unlock(); + GELOGI("[GraphExecutor] Run graph async success, graph_id=%u.", args.graph_id); + } +} + +/// +/// @ingroup ge +/// @brief Run graph for synchronous model execution. +/// @param [in] graph_node: node of graph. +/// @param [in] graph_id: graph identifier. +/// @param [in] inputs: input data for the graph running. +/// @param [out] outputs: output data of the graph running. +/// @return Status result of function +/// +Status ModelExecutor::RunGraph(const GraphNodePtr &graph_node, GraphId graph_id, + const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { + Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); + if (ret != SUCCESS) { + GELOGE(GE_GRAPH_RUNGRAPH_FAILED, "[Set][Condition] failed, graph_id = %u.", graph_id); + graph_node->SetRunFlag(false); + return GE_GRAPH_RUNGRAPH_FAILED; + } + + if (train_graph_flag_) { + graph_executor_.SetTrainFlag(train_graph_flag_); + } + ret = graph_executor_.ExecuteGraph(graph_id, graph_node->GetGeRootModel(), inputs, outputs); + + graph_node->SetRunFlag(false); + if (ret != SUCCESS) { + GELOGE(ret, "[Execute][Graph] failed, graph_id = %u.", graph_id); + return ret; + } + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Run graph for synchronous model execution on a given stream. +/// @param [in] graph_node: node of graph. +/// @param [in] graph_id: graph identifier.
+/// @param [in] stream: Stream for model running. +/// @param [in] inputs: input data for the graph running. +/// @param [out] outputs: output data of the graph running. +/// @return Status result of function +/// +Status ModelExecutor::RunGraphWithStream(const GraphNodePtr &graph_node, GraphId graph_id, rtStream_t stream, + const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { + auto ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); + if (ret != SUCCESS) { + GELOGE(GE_GRAPH_RUNGRAPH_FAILED, "[Set][Condition] failed, graph id = %u, stream = %p.", graph_id, stream); + graph_node->SetRunFlag(false); + return GE_GRAPH_RUNGRAPH_FAILED; + } + + ret = graph_executor_.ExecuteGraphWithStream(graph_id, stream, graph_node->GetGeRootModel(), inputs, outputs); + graph_node->SetRunFlag(false); + graph_node->SetIsSpecificStream(false); + if (ret != SUCCESS) { + GELOGE(ret, "[Execute][Graph] With Stream failed, graph id = %u, stream = %p.", graph_id, stream); + return ret; + } + GELOGI("[Run][GraphWithStreamAsync] run graph success, graph id = %u, stream = %p.", graph_id, stream); + return SUCCESS; +} + +Status ModelExecutor::ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { + ge_root_model->SetIsSpecificStream(graph_node->IsSpecificStream()); + return ModelLoad(ge_root_model, graph_node, graph_run_listener_); +} + +Status ModelExecutor::ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { + auto listener = MakeShared<RunAsyncListener>(); + GE_CHECK_NOTNULL(listener); + return ModelLoad(ge_root_model, graph_node, listener); +} + +Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, + const std::shared_ptr<ModelListener> &listener) { + ge_root_model->SetTrainFlag(train_graph_flag_); + bool is_unknown_shape = false; + GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape)); + if (!is_unknown_shape) { + if (getenv(kEnvGeuseStaticMemory) != nullptr) { + GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is set."); + } else { + auto root_graph = ge_root_model->GetRootGraph(); + GE_CHECK_NOTNULL(root_graph); + auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); + GeModelPtr ge_model = name_to_model[root_graph->GetName()]; + GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); + } + } + GE_TIMESTAMP_START(LoadModelOnline); + uint32_t model_id = INVALID_MODEL_ID; + Status ret = GraphLoader::LoadModelOnline(model_id, ge_root_model, listener); + GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline"); + if (ret != SUCCESS) { + GELOGE(ret, "[Load][ModelOnline] Failed, model_id:%u", model_id); + graph_node->SetRunFlag(false); + return ret; + } + graph_node->SetLoadFlag(true); + ge_root_model->SetModelId(model_id); + graph_node->SetGeRootModel(ge_root_model); + AddGraphNode(graph_node->GetGraphId(), graph_node); + return SUCCESS; +} + +void ModelExecutor::ReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node, + const std::vector<uint32_t> &model_ids, uint32_t graph_id, uint64_t session_id) { + rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u", GetContext().DeviceId()); + GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id=%u.", GetContext().DeviceId()); + return; + } + for (auto model_id : model_ids) { + uint64_t max_memory_size = 0; + Status result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size); + if (result !=
SUCCESS) { + continue; + } + GELOGI("try to UnloadGraph[%u], model[%u] whose MaxUsedMemory is [%lu].", graph_id, model_id, max_memory_size); + if (model_ids.size() > 1) { + result = ge_model->GetSessionId(model_id, session_id); + if (result != SUCCESS) { + GELOGW("[GraphExecutor:] get session failed when dynamic memory, modelId=%u, graphId=%u.", model_id, + graph_id); + continue; + } + } + result = GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); + if (result != SUCCESS) { + GELOGW("[GraphExecutor:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, + graph_id); + } + result = GraphLoader::UnloadModel(model_id); + if (result != SUCCESS) { + GELOGW("[GraphExecutor:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); + } + GELOGI("UnloadGraph[%u], model[%u] success.", graph_id, model_id); + } + graph_node->SetLoadFlag(false); + // Allow the model to be loaded again without re-adding the graph + graph_node->SetLoadCount(graph_node->GetLoadRecord()); + graph_node->SetLoadRecord(kNeverLoaded); + GeRootModelPtr ge_root_model = graph_node->GetGeRootModel(); + if (ge_root_model == nullptr) { + GELOGW("ge_root_model is null, graph_id:%u", graph_id); + return; + } + ge_root_model->ClearAllModelId(); + rt_ret = rtDeviceReset(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u", GetContext().DeviceId()); + GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u.", GetContext().DeviceId()); + return; + } +} + +Status ModelExecutor::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { + GELOGI("graph_id[%u]", graph_node->GetGraphId()); + int64_t free_memory = 0; + Status result = GraphLoader::GetMemoryInfo(free_memory); + if (result != SUCCESS) { + return result; + } + + int64_t value = 0; + int64_t memory_size = AttrUtils::GetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, value) ? value : 0; + int64_t weight_size = AttrUtils::GetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, value) ? value : 0; + int64_t session_id = AttrUtils::GetInt(ge_model, MODEL_ATTR_SESSION_ID, value) ?
value : 0; + + GELOGI("Graph[%u] need memory_size[%ld], weight_size[%ld], Device[%u] free_memory_size[%ld]", + graph_node->GetGraphId(), memory_size, weight_size, GetContext().DeviceId(), free_memory); + if (CheckInt64AddOverflow(memory_size, weight_size) != SUCCESS) { + REPORT_INNER_ERROR("E19999", "memory_size:%ld and weight_size:%ld will overflow after add, check invalid", + memory_size, weight_size); + GELOGE(INTERNAL_ERROR, "[Check][Param] memory_size:%ld and weight_size:%ld will overflow after add", + memory_size, weight_size); + return INTERNAL_ERROR; + } + if (free_memory >= (memory_size + weight_size)) { + return SUCCESS; + } + + std::lock_guard<std::mutex> lock(mutex_); + for (const auto &it : graph_nodes_) { + auto graph_id = it.second->GetGraphId(); + auto model = it.second->GetGeRootModel(); + if (model == nullptr) { + continue; + } + auto model_id = model->GetModelId(); + auto model_ids = model->GetAllModelId(); + // unknown-shape model: do not release memory + bool is_unknown_shape = false; + GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); + if (is_unknown_shape) { + GELOGD("model_id[%u] graph_id[%u] is unknown model, not release memory", model_id, graph_id); + continue; + } + // not loaded, no need to unload + if (!it.second->GetLoadFlag()) { + GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); + continue; + } + ReleaseMemory(ge_model, it.second, model_ids, graph_id, static_cast<uint64_t>(session_id)); + } + + return SUCCESS; +} + +void ModelExecutor::ParseInputsDimsForData(const std::vector<ge::Tensor> &input_tensor) { + GELOGD("Start parse input dims from data."); + for (size_t i = 0; i < input_tensor.size(); ++i) { + const TensorDesc &tensor_desc = input_tensor[i].GetTensorDesc(); + const Shape &shape = tensor_desc.GetShape(); + const auto &shape_dims = shape.GetDims(); + GELOGD("Input tensor dims is %s.", formats::JoinToString(shape_dims).c_str()); + GetLocalOmeContext().user_real_input_dims.emplace_back(shape_dims); + } +} + +Status ModelExecutor::ParseInputsDimsForGetNextNoSinkAndData(const vector<NodePtr> &dynamic_nodes, + const std::vector<ge::Tensor> &input_tensor) { + GELOGD("Start to parse input dims when data and getnext sink coexist."); + for (size_t i = 0; i < dynamic_nodes.size(); ++i) { + auto op_desc = dynamic_nodes.at(i)->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + GeAttrValue::INT index = 0; + if (!(AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, index))) { + REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) fail", ATTR_NAME_INDEX.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + GELOGE(PARAM_INVALID, "[Get][Attr] %s from op:%s(%s) fail", ATTR_NAME_INDEX.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return PARAM_INVALID; + } + if (static_cast<size_t>(index) > input_tensor.size()) { + REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s) value:%ld > param input_tensor.size:%zu, " + "check invalid", ATTR_NAME_INDEX.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str(), + index, input_tensor.size()); + GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s) value:%ld > param input_tensor.size:%zu", + ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), + index, input_tensor.size()); + return PARAM_INVALID; + } + + const TensorDesc &tensor_desc = input_tensor[i].GetTensorDesc(); + const Shape &shape = tensor_desc.GetShape(); + const auto &shape_dims = shape.GetDims(); + GELOGI("Shape dims of %ld data is %s.", index, formats::JoinToString(shape_dims).c_str()); +
GetLocalOmeContext().user_real_input_dims.emplace_back(shape_dims); + } + return SUCCESS; +} + +Status ModelExecutor::ParseInputsDims(const std::vector<ge::Tensor> &input_tensor) { + GELOGI("Start parse input dims of %zu input tensor.", input_tensor.size()); + GetLocalOmeContext().user_real_input_dims.clear(); + if (GetLocalOmeContext().dynamic_node_type.empty()) { + return SUCCESS; + } + + const vector<NodePtr> &data_nodes = GetLocalOmeContext().data_nodes; + const vector<NodePtr> &getnext_nosink_nodes = GetLocalOmeContext().getnext_nosink_nodes; + GELOGD("Data nodes count is %zu, getnext nosink nodes count is %zu.", data_nodes.size(), + getnext_nosink_nodes.size()); + if (GetLocalOmeContext().dynamic_node_type == DATA) { + if (getnext_nosink_nodes.empty()) { + // just data or data+getnext_sink + ParseInputsDimsForData(input_tensor); + } else { + // data+getnext_nosink, but only need to get shape_dims of data + if (ParseInputsDimsForGetNextNoSinkAndData(data_nodes, input_tensor) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Parse][Dims] from data failed, when data coexist with getnext nosink."); + return PARAM_INVALID; + } + } + } else { + if (getnext_nosink_nodes.empty()) { + // just getnext_sink or getnext_sink+data, need to get shape_dims from aicpu op + GELOGI("Need to get dims from aicpu op: GETDYNAMICDIMS."); + return SUCCESS; + } else { + if (data_nodes.empty()) { + // just getnext_nosink + ParseInputsDimsForData(input_tensor); + } else { + // getnext_nosink + data, but only need to get shape_dims of getnext_nosink + if (ParseInputsDimsForGetNextNoSinkAndData(getnext_nosink_nodes, input_tensor) != SUCCESS) { + GELOGE(PARAM_INVALID, "[Parse][Dims] from getnext nosink failed, when data coexist with getnext nosink"); + return PARAM_INVALID; + } + } + } + } + + GELOGI("Parse %zu inputs dims success.", GetLocalOmeContext().user_real_input_dims.size()); + return SUCCESS; +} } // namespace ge
diff --git a/ge/graph/execute/model_executor.h b/ge/graph/execute/model_executor.h new file mode 100644 index 00000000..f11441e9 --- /dev/null +++ b/ge/graph/execute/model_executor.h @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GE_GRAPH_EXECUTE_MODEL_EXECUTOR_H +#define GE_GRAPH_EXECUTE_MODEL_EXECUTOR_H + +#include <thread> + +#include "common/executor.h" +#include "graph/execute/graph_execute.h" + +namespace ge { +class ModelExecutor : public Executor { + public: + /// + /// @ingroup ge + /// @brief model executor init + /// @param [in] options user config params + /// @return Status result of function + /// + Status Initialize(const map<string, string> &options, uint64_t session_id); + + /// + /// @ingroup ge + /// @brief model executor finalize + /// @return Status result of function + /// + Status Finalize(); + + /// + /// @ingroup ge + /// @brief Load model for graph. + /// @param [in] GeRootModel: root model of graph compiled. + /// @param [in] GraphNode: node of graph.
+ /// @return Status result of function + /// + Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); + + /// + /// @ingroup ge + /// @brief Unload model for graph. + /// @param [in] GeRootModel: root model of graph compiled. + /// @param [in] graph_id: graph identifier. + /// @return Status result of function + /// + Status UnloadGraph(const GeRootModelPtr &ge_root_model, uint32_t graph_id); + + /// + /// @ingroup ge + /// @brief Push model execution params to queue. + /// @param [in] args: RunArgs for model execution. + /// @return Status result of function + /// + Status PushGraph(const RunArgs &args); + + /// + /// @ingroup ge + /// @brief Run graph for synchronous model execution. + /// @param [in] graph_node: node of graph. + /// @param [in] graph_id: graph identifier. + /// @param [in] inputs: input data for the graph running. + /// @param [out] outputs: output data of the graph running. + /// @return Status result of function + /// + Status RunGraph(const GraphNodePtr &graph_node, GraphId graph_id, + const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs); + + /// + /// @ingroup ge + /// @brief Run graph for synchronous model execution on a given stream. + /// @param [in] graph_node: node of graph. + /// @param [in] graph_id: graph identifier. + /// @param [in] stream: Stream for model running. + /// @param [in] inputs: input data for the graph running. + /// @param [out] outputs: output data of the graph running. + /// @return Status result of function + /// + Status RunGraphWithStream(const GraphNodePtr &graph_node, GraphId graph_id, rtStream_t stream, + const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs); + + private: + bool ParseTrainGraphFlag(); + + void AddGraphNode(GraphId graph_id, const GraphNodePtr &graph_node); + void RemoveGraphNode(GraphId graph_id); + + Status ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); + Status ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); + Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, + const std::shared_ptr<ModelListener> &listener); + + Status UnloadModel(const GeRootModelPtr &ge_root_model, uint32_t graph_id); + + void ReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node, const std::vector<uint32_t> &model_ids, + uint32_t graph_id, uint64_t session_id); + Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); + + void UpdateLocalOmeContext(const GraphNodePtr &graph_node); + + void RunThread(); + void StopQueue(); + void ReturnError(RunAsyncCallback callback, Status ret, const string &log); + + void ParseInputsDimsForData(const std::vector<ge::Tensor> &input_tensor); + Status ParseInputsDimsForGetNextNoSinkAndData(const vector<NodePtr> &dynamic_nodes, + const std::vector<ge::Tensor> &input_tensor); + Status ParseInputsDims(const std::vector<ge::Tensor> &input_tensor); + + bool init_flag_{false}; + bool train_graph_flag_{false}; + uint64_t session_id_{0}; + GraphExecutor graph_executor_; + + std::mutex mutex_; + std::map<GraphId, GraphNodePtr> graph_nodes_; + + std::thread run_thread_; + std::atomic_bool thread_run_flag_{false}; + BlockingQueue<RunArgs> run_args_q_; + + // for run graph synchronous return + std::mutex sync_run_mutex_; + std::condition_variable condition_; + // run graph synchronization call back listener + std::shared_ptr<GraphModelListener> graph_run_listener_; +}; +}  // namespace ge +#endif // GE_GRAPH_EXECUTE_MODEL_EXECUTOR_H \ No newline at end of file
diff --git a/ge/graph/label/case_label_maker.cc b/ge/graph/label/case_label_maker.cc index 3fdb1783..88d698d1 100644 --- a/ge/graph/label/case_label_maker.cc
+++ b/ge/graph/label/case_label_maker.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "case_label_maker.h" +#include "graph/label/case_label_maker.h" -#include "common/util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "framework/common/op/ge_op_utils.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/graph/label/if_label_maker.cc b/ge/graph/label/if_label_maker.cc index 72b33015..df911e70 100644 --- a/ge/graph/label/if_label_maker.cc +++ b/ge/graph/label/if_label_maker.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "if_label_maker.h" +#include "graph/label/if_label_maker.h" -#include "common/util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "framework/common/op/ge_op_utils.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/graph/label/label_maker.cc b/ge/graph/label/label_maker.cc index 638cbbae..47eeda86 100644 --- a/ge/graph/label/label_maker.cc +++ b/ge/graph/label/label_maker.cc @@ -16,8 +16,8 @@ #include "graph/label/label_maker.h" -#include "common/util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "framework/common/op/ge_op_utils.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/graph/label/partitioned_call_label_maker.cc b/ge/graph/label/partitioned_call_label_maker.cc index 7b4bcbd8..ec8b8c89 100644 --- a/ge/graph/label/partitioned_call_label_maker.cc +++ b/ge/graph/label/partitioned_call_label_maker.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "partitioned_call_label_maker.h" +#include "graph/label/partitioned_call_label_maker.h" -#include "common/util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/label/while_label_maker.cc b/ge/graph/label/while_label_maker.cc index cd6b3743..7e6b8a98 100644 --- a/ge/graph/label/while_label_maker.cc +++ b/ge/graph/label/while_label_maker.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "while_label_maker.h" +#include "graph/label/while_label_maker.h" -#include "common/util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "framework/common/op/ge_op_utils.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/graph/label/while_label_maker.h b/ge/graph/label/while_label_maker.h index 6c30475b..1561b860 100644 --- a/ge/graph/label/while_label_maker.h +++ b/ge/graph/label/while_label_maker.h @@ -19,57 +19,59 @@ #include "graph/node.h" #include "graph/label/label_maker.h" -/******************************************************************************* - +------------+ - | Node | - +------------+ - | Node | - +------------+ - | While | - +------------+ - +-----------+ - | Node | +------------+ - +-----------+ | LabelSet |\ - | Node | +------------+ \ - +-----------+ |StreamActive| \ - | Node | +------------+ A - +-----------+ | c | | - | While | +------------+ | - +-----------+ | o | | - | Node | +------------+ | - +-----------+ | n | | - | Node | +------------+ | - +-----------+ | d | | - | Node | +------------+ | - +-----------+ /|SwitchByIdx | | - / +------------+ | - ====> / | - | \ +------------+ | - | \|LabelSet(1) | | - | +------------+ | - | |StreamActive| | - | +------------+ | - +-----------+ +-----------+ | | b | | - | c | | b | | +------------+ | - +-----------+ +-----------+ | | o | | - | o | | o | | +------------+ | - +-----------+ +-----------+ | | d | | - | n | | d | | +------------+ | - +-----------+ +-----------+ | | y | / - | d | | y | V +------------+ / - +-----------+ +-----------+ \ | LabelGoto |/ - \ +------------+ - \|LabelSet(0) | - +------------+ - - +------------+ - | Node | - +------------+ - | Node | - +------------+ - | Node | - +------------+ -*******************************************************************************/ +/*********************************************************************************************************************** + +------------+ Step0: DavinciModel::InitNodes + | Node | + +------------+ rtLabelCreateExV2 + | Node | + +------------+ + | Node | + +------------+ + | While | + +------------+ + +-----------+ Step1: TaskInfo::Init + | Node | +------------+ + +-----------+ | LabelSet(0)|\ LabelSetTaskInfo --> id=0 + | Node | +------------+ \ + +-----------+ |StreamActive| \ If active_stream_list empty, not task. + | Node | +------------+ A + +-----------+ | c | | + | While | +------------+ | + +-----------+ | o | | + | Node | +------------+ | + +-----------+ | n | | + | Node | +------------+ | + +-----------+ | d | | + | Node | +------------+ | + +-----------+ /|SwitchByIdx | | LabelSwitchByIndexTaskInfo --> rtLabelListCpy({1,2}) + / +------------+ | + ====> / | + | \ +------------+ | + | \| LabelSet(1)| | LabelSetTaskInfo --> id=1 + | +------------+ | + | |StreamActive| | If active_stream_list empty, not task. 
+ | +------------+ | + +-----------+ +-----------+ | | b | | + | c | | b | | +------------+ | + +-----------+ +-----------+ | | o | | + | o | | o | | +------------+ | + +-----------+ +-----------+ | | d | | + | n | | d | | +------------+ | + +-----------+ +-----------+ | | y | / + | d | | y | V +------------+ / + +-----------+ +-----------+ \ | LabelGoto |/ LabelGotoExTaskInfo --> GetLabelGotoAddr(id=0) + \ +------------+ + \| LabelSet(2)| LabelSetTaskInfo --> id=2 + +------------+ + Step2: TaskInfo::Distribute + +------------+ + | Node | LabelSetTaskInfo --> rtLabelSet + +------------+ LabelSwitchByIndexTaskInfo --> rtLabelSwitchByIndex + | Node | LabelSetTaskInfo --> rtLabelSet + +------------+ LabelGotoExTaskInfo --> rtLabelSwitchByIndex + | Node | LabelSetTaskInfo --> rtLabelSet + +------------+ +***********************************************************************************************************************/ namespace ge { class WhileOpLabelMaker : public LabelMaker { public: diff --git a/ge/graph/load/graph_loader.cc b/ge/graph/load/graph_loader.cc index 94b90d69..b2a61106 100755 --- a/ge/graph/load/graph_loader.cc +++ b/ge/graph/load/graph_loader.cc @@ -19,7 +19,7 @@ #include #include -#include "common/helper/model_helper.h" +#include "framework/common/helper/model_helper.h" #include "common/model_parser/model_parser.h" #include "graph/ge_context.h" #include "graph/load/model_manager/model_manager.h" diff --git a/ge/graph/load/graph_loader.h b/ge/graph/load/graph_loader.h index e11af749..f6324c98 100755 --- a/ge/graph/load/graph_loader.h +++ b/ge/graph/load/graph_loader.h @@ -21,9 +21,9 @@ #include #include -#include "common/debug/log.h" -#include "common/fmk_types.h" -#include "common/ge_types.h" +#include "framework/common/debug/log.h" +#include "framework/common/fmk_types.h" +#include "framework/common/ge_types.h" #include "graph/compute_graph.h" #include "graph/manager/graph_manager_utils.h" #include "graph/model.h" diff --git a/ge/graph/load/model_manager/aipp_utils.cc b/ge/graph/load/model_manager/aipp_utils.cc index 8a18c421..a9f885f8 100755 --- a/ge/graph/load/model_manager/aipp_utils.cc +++ b/ge/graph/load/model_manager/aipp_utils.cc @@ -18,8 +18,8 @@ #include -#include "common/debug/log.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/debug/log.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/attr_utils.h" diff --git a/ge/graph/load/model_manager/aipp_utils.h b/ge/graph/load/model_manager/aipp_utils.h index 78107f3e..237eeced 100755 --- a/ge/graph/load/model_manager/aipp_utils.h +++ b/ge/graph/load/model_manager/aipp_utils.h @@ -19,8 +19,8 @@ #include -#include "common/ge_inner_error_codes.h" -#include "common/ge_types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" #include "graph/op_desc.h" #include "proto/insert_op.pb.h" diff --git a/ge/graph/load/model_manager/cpu_queue_schedule.cc b/ge/graph/load/model_manager/cpu_queue_schedule.cc index 9821aa73..0ec80b34 100644 --- a/ge/graph/load/model_manager/cpu_queue_schedule.cc +++ b/ge/graph/load/model_manager/cpu_queue_schedule.cc @@ -15,8 +15,8 @@ */ #include "graph/load/model_manager/cpu_queue_schedule.h" -#include "common/debug/ge_log.h" -#include "common/debug/log.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" namespace { const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch diff --git 
a/ge/graph/load/model_manager/cpu_queue_schedule.h b/ge/graph/load/model_manager/cpu_queue_schedule.h index 8dc44538..d3c8915e 100644 --- a/ge/graph/load/model_manager/cpu_queue_schedule.h +++ b/ge/graph/load/model_manager/cpu_queue_schedule.h @@ -19,7 +19,7 @@ #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/load/model_manager/task_info/task_info.h" #include "graph/load/model_manager/zero_copy_offset.h" #include "runtime/kernel.h" diff --git a/ge/graph/load/model_manager/data_dumper.cc b/ge/graph/load/model_manager/data_dumper.cc index c96b3885..7b5d9df9 100644 --- a/ge/graph/load/model_manager/data_dumper.cc +++ b/ge/graph/load/model_manager/data_dumper.cc @@ -24,7 +24,7 @@ #include "common/debug/memory_dumper.h" #include "common/properties_manager.h" -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" #include "graph/anchor.h" diff --git a/ge/graph/load/model_manager/data_dumper.h b/ge/graph/load/model_manager/data_dumper.h index d1714950..2851e63a 100755 --- a/ge/graph/load/model_manager/data_dumper.h +++ b/ge/graph/load/model_manager/data_dumper.h @@ -29,7 +29,7 @@ #include "proto/ge_ir.pb.h" #include "proto/op_mapping.pb.h" #include "runtime/mem.h" -#include "task_info/task_info.h" +#include "graph/load/model_manager/task_info/task_info.h" #include "framework/common/ge_types.h" #include "runtime/base.h" diff --git a/ge/graph/load/model_manager/data_inputer.cc b/ge/graph/load/model_manager/data_inputer.cc index d286b9b4..d68e95aa 100755 --- a/ge/graph/load/model_manager/data_inputer.cc +++ b/ge/graph/load/model_manager/data_inputer.cc @@ -18,9 +18,9 @@ #include -#include "common/debug/log.h" -#include "common/scope_guard.h" -#include "common/types.h" +#include "framework/common/debug/log.h" +#include "framework/common/scope_guard.h" +#include "framework/common/types.h" namespace ge { domi::Status InputDataWrapper::Init(const InputData &input, const OutputData &output) { diff --git a/ge/graph/load/model_manager/data_inputer.h b/ge/graph/load/model_manager/data_inputer.h index b8d145d4..28b6fb26 100755 --- a/ge/graph/load/model_manager/data_inputer.h +++ b/ge/graph/load/model_manager/data_inputer.h @@ -22,8 +22,8 @@ #include #include "common/blocking_queue.h" -#include "common/ge_types.h" -#include "common/types.h" +#include "framework/common/ge_types.h" +#include "framework/common/types.h" namespace ge { /// diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 5b67c205..495ec28e 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -21,22 +21,22 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" #include "common/math/math_util.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "common/profiling/profiling_manager.h" #include "common/properties_manager.h" -#include "common/scope_guard.h" +#include "framework/common/scope_guard.h" #include "common/thread_pool.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" -#include 
"graph/graph.h" +#include "external/graph/graph.h" #include "graph/load/model_manager/cpu_queue_schedule.h" #include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/tbe_handle_store.h" @@ -57,11 +57,12 @@ #include "runtime/rt_model.h" #include "runtime/stream.h" #include "securec.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "common/formats/utils/formats_trans_utils.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/build/memory/block_mem_assigner.h" #include "graph/manager/session_scope_mem_allocator.h" +#include "framework/omg/omg_inner_types.h" // create std::thread, catch exceptions using try/catch #define CREATE_STD_THREAD(thread_id, func, args) \ @@ -99,6 +100,9 @@ const uint32_t kEndOfSequenceNew = 507005; const int32_t kModelAbortNormal = 0x0704000e; const int32_t kModelAbortNormalNew = 507024; const uint32_t kInteval = 2; +const uint32_t kFftsTbeHandleElementSize = 2; +const uint32_t kNonTailBlock = 0; +const uint32_t kTailBlock = 1; const char *const kModelName = "model_name"; const char *const kModeleId = "model_id"; const char *const kLoadStartTime = "load_start_time"; @@ -116,14 +120,15 @@ const char *const kWorkSpaceSize = "workspace_size"; const char *const kTotalSize = "total_size"; const char *const kTaskCount = "task_count"; const char *const kTaskId = "task_id"; -const char* const kRequestId = "request_id"; -const char* const kThreadId = "thread_id"; -const char* const kInputBeginTime = "input_begin_time"; -const char* const kInputEndTime = "input_end_time"; -const char* const kInferBeginTime = "infer_begin_time"; -const char* const kInferEndTime = "infer_end_time"; -const char* const kOutputBeginTime = "output_start_time"; -const char* const kOutputEndTime = "output_end_time"; +const char *const kRequestId = "request_id"; +const char *const kThreadId = "thread_id"; +const char *const kInputBeginTime = "input_begin_time"; +const char *const kInputEndTime = "input_end_time"; +const char *const kInferBeginTime = "infer_begin_time"; +const char *const kInferEndTime = "infer_end_time"; +const char *const kOutputBeginTime = "output_start_time"; +const char *const kOutputEndTime = "output_end_time"; +const char *const kStubFuncName = "_register_stub_func"; const uint32_t kStringHeadElems = 2; const uint32_t kPlacementHostData = 0; const size_t kAlignment = 64; @@ -233,6 +238,12 @@ DavinciModel::~DavinciModel() { GE_LOGW_IF(rtEventDestroy(event_list_[i]) != RT_ERROR_NONE, "Destroy event failed, index: %zu", i); } + for (const auto &it : stream_2_event_) { + if (rtEventDestroy(it.second) != RT_ERROR_NONE) { + GELOGW("Destroy event failed"); + } + } + FreeWeightsMem(); FreeFeatureMapMem(); @@ -383,8 +394,8 @@ Status DavinciModel::InitWeightMem(void *dev_ptr, void *weight_ptr, size_t weigh Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { if (is_feature_map_mem_has_inited_) { - REPORT_INNER_ERROR("E19999", "Call InitFeatureMapMem more than once, model_id:%u, check invalid", model_id_); - GELOGE(PARAM_INVALID, "[Check][Param] call InitFeatureMapMem more than once, model_id:%u", model_id_); + REPORT_INNER_ERROR("E19999", "InitFeatureMapMem is called more than once, model_id:%u, check invalid", model_id_); + GELOGE(PARAM_INVALID, "[Check][Param] InitFeatureMapMem is called more than once, model_id:%u", model_id_); return PARAM_INVALID; } is_feature_map_mem_has_inited_ = true; @@ -452,8 +463,7 @@ Status DavinciModel::InitVariableMem() { void 
DavinciModel::InitRuntimeParams() { int64_t value = 0; - bool ret; - ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_MEMORY_SIZE, value); + bool ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_MEMORY_SIZE, value); runtime_param_.mem_size = ret ? (uint64_t)value : 0; ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_WEIGHT_SIZE, value); runtime_param_.weight_size = ret ? (uint64_t)value : 0; @@ -760,8 +770,16 @@ void DavinciModel::SaveSpecifyAttrValues(const OpDescPtr &op_desc) { } Status DavinciModel::ReportProfilingData() { - ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo()); - GE_CHK_STATUS(SinkModelProfile(), "[Sink][ModelProfile] failed, model_id:%u.", model_id_); + bool is_train = domi::GetContext().train_flag; + auto model_id = model_id_; + auto &profiling_manager = ProfilingManager::Instance(); + auto graph_id = runtime_param_.graph_id; + if (is_train) { + GELOGD("Replace model_id:%u with graph_id:%u, when training.", model_id, graph_id); + model_id = graph_id; + } + profiling_manager.ReportProfilingData(model_id, GetTaskDescInfo()); + GE_CHK_STATUS(SinkModelProfile(), "[Sink][ModelProfile] failed, model_id:%u.", model_id); return SUCCESS; } @@ -902,10 +920,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { SetLabelForDynamic(node); auto it = op_desc_handle.find(op_desc->GetType()); if (it != op_desc_handle.end()) { - if ((this->*it->second)(op_desc) != SUCCESS) { - GELOGE(PARAM_INVALID, "[Init][Node] failed, Name:%s", op_desc->GetName().c_str()); - return PARAM_INVALID; - } + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((this->*it->second)(op_desc) != SUCCESS, return PARAM_INVALID, + "[Init][Node] failed, Name:%s", op_desc->GetName().c_str()); continue; } @@ -935,7 +951,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GE_TIMESTAMP_RESTART(InitTbeHandle); if (IsTbeTask(op_desc)) { - Status status = InitTbeHandle(op_desc); + Status status = + op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc); if (status != SUCCESS) { GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str()); return status; @@ -980,7 +997,7 @@ Status DavinciModel::InitDataOp(const ComputeGraphPtr &graph, const NodePtr &nod // op_desc Checked by Init: Data, valid. 
auto op_desc = node->GetOpDesc(); if (node->GetOwnerComputeGraph() != graph) { - GELOGI("Skip subgraph Data node: %s.", op_desc->GetName().c_str()); + GELOGI("Skip Data node: %s in subgraph.", op_desc->GetName().c_str()); return SUCCESS; } @@ -1153,7 +1170,6 @@ Status DavinciModel::InitNetOutput(const ComputeGraphPtr &graph, const NodePtr & } size_t num = output_data_info_.size(); - bool fusion_flag = false; size_t input_count = input_size_list.size(); is_getnext_sink_dynamic_ = false; @@ -1163,6 +1179,7 @@ Status DavinciModel::InitNetOutput(const ComputeGraphPtr &graph, const NodePtr & } for (size_t idx = 0; idx < input_count; ++idx) { ZeroCopyOffset zero_copy_offset; + bool fusion_flag = false; Status ret = zero_copy_offset.InitOutputDataInfo(input_size_list, virtual_addr_list, op_desc, idx, fusion_flag); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(PARAM_INVALID, "[Init][DataInfo] of input_info %s failed.", op_desc->GetName().c_str()); @@ -1192,7 +1209,7 @@ Status DavinciModel::InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_gra GELOGD("No need to get size and shape of netoutput in subgraph."); return SUCCESS; } - GELOGD("Start init real size and shape info of %s.", node->GetName().c_str()); + GELOGD("Start to initialize real size and shape info of %s.", node->GetName().c_str()); GetAllGearsInfo(node); if (is_getnext_sink_dynamic_) { GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, @@ -1235,7 +1252,7 @@ void DavinciModel::GetAllGearsInfo(const NodePtr &node) { } if (!gear_info.empty()) { all_gears_info_.emplace_back(gear_info); - GELOGD("Init all gears info from %s, gaer info is %s", node->GetName().c_str(), + GELOGD("Init all gears info from %s, gear info is %s", node->GetName().c_str(), formats::JoinToString(gear_info).c_str()); } } @@ -1315,7 +1332,7 @@ Status DavinciModel::GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, con Status DavinciModel::GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, const NodePtr &case_node) { - GELOGD("Start get output size of %s, which is %zu input to netoutput", case_node->GetName().c_str(), input_index); + GELOGD("Start to get output size of %s, which is %zu input to netoutput", case_node->GetName().c_str(), input_index); const auto &func_desc = case_node->GetOpDesc(); GE_CHECK_NOTNULL(func_desc); std::map, int64_t> gear_and_real_out_size_info; @@ -1477,6 +1494,11 @@ Status DavinciModel::GetLabelGotoAddr(uint32_t label_index, rtMemType_t mem_type return SUCCESS; } +void DavinciModel::SetGlobalStep(void *global_step, uint64_t global_step_size) { + global_step_addr_ = global_step; + global_step_size_ = global_step_size; +} + /// @ingroup ge /// @brief LabelSet Op Initialize. /// @param [in] op_desc: LabelSet Op descriptor. 
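A side note on the `SetGlobalStep` hunk above: the model now simply records the address and size of a caller-owned global-step buffer (the older `SetKnownShapeGlobalStep` accessor is removed from davinci_model.h further below). A minimal, hypothetical usage sketch — the helper name and buffer are illustrative, not part of this patch:

```cpp
// Illustrative only: hand a caller-owned device buffer holding the global-step
// counter to a loaded model. SetGlobalStep() stores the pointer and size; it
// does not take ownership or copy the buffer.
void BindGlobalStep(ge::DavinciModel &model, void *step_buf, uint64_t step_buf_size) {
  model.SetGlobalStep(step_buf, step_buf_size);
}
```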
@@ -1539,14 +1561,16 @@ Status DavinciModel::InitLabelSet(const OpDescPtr &op_desc) { } Status DavinciModel::InitVariable(const OpDescPtr &op_desc, map &variable_by_name) { - if (op_desc->GetName() == NODE_NAME_GLOBAL_STEP) { - const auto output_sizes = ModelUtils::GetOutputSize(op_desc); - if (!output_sizes.empty()) { - global_step_size_ = output_sizes[0]; - } - const auto output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); - if (!output_addrs.empty()) { - global_step_addr_ = output_addrs[0]; + if (!known_node_) { + if (op_desc->GetName() == NODE_NAME_GLOBAL_STEP) { + const auto output_sizes = ModelUtils::GetOutputSize(op_desc); + if (!output_sizes.empty()) { + global_step_size_ = output_sizes[0]; + } + const auto output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); + if (!output_addrs.empty()) { + global_step_addr_ = output_addrs[0]; + } } } @@ -2217,10 +2241,10 @@ void DavinciModel::CreateOutput(uint32_t index, const OpDescPtr &op_desc, InputO dims[i] = shape.GetDim(i); } } else { // FOR FORMAT_NHWC or FORMAT_NCHW - dims[0] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N); // 0: first dim - dims[1] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C); // 1: second dim - dims[2] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H); // 2: third dim - dims[3] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W); // 3: forth dim + dims[0] = shape.GetDim((format == FORMAT_NHWC) ? NHWC_DIM_N : NCHW_DIM_N); // 0: first dim + dims[1] = shape.GetDim((format == FORMAT_NHWC) ? NHWC_DIM_C : NCHW_DIM_C); // 1: second dim + dims[2] = shape.GetDim((format == FORMAT_NHWC) ? NHWC_DIM_H : NCHW_DIM_H); // 2: third dim + dims[3] = shape.GetDim((format == FORMAT_NHWC) ? NHWC_DIM_W : NCHW_DIM_W); // 3: forth dim } output.shape_info.num = dims[0]; // 0: first dim output.shape_info.channel = dims[1]; // 1: second dim @@ -2731,7 +2755,7 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b } if (!has_output_node_) { - GELOGW("Output tensor list is empty, model id: %u", model_id_); + GELOGW("The tensor list of output is empty, model id: %u", model_id_); GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "[Call][OnComputeDone] failed, model_id:%u, data_id:%u.", model_id_, data_id); return INTERNAL_ERROR; @@ -3061,7 +3085,7 @@ Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const GELOGI("output %zu, v addr %p, r addr %p, p addr %p", i, addr_list[i], addr, outputs[i]); } - GELOGI("success, known input data info size: %zu, known output data info size: %zu", + GELOGI("create map for zero copy success, known input data info size: %zu, known output data info size: %zu", known_input_data_info_.size(), known_output_data_info_.size()); return SUCCESS; } @@ -3096,12 +3120,12 @@ Status DavinciModel::UpdateKnownZeroCopyAddr(vector &total_io_addrs, boo total_io_addrs[i] = known_output_data_info_.at(total_io_addrs[i]); } } - GELOGI("success, total io addrs size: %zu", total_io_addrs.size()); + GELOGI("update known zero copy addr success, total io addrs size: %zu", total_io_addrs.size()); return SUCCESS; } Status DavinciModel::UpdateKnownNodeArgs(const vector &inputs, const vector &outputs) { - GELOGI("DavinciModel::UpdateKnownNodeArgs in"); + GELOGI("DavinciModel::UpdateKnownNodeArgs begin"); GE_CHK_STATUS_RET(CreateKnownZeroCopyMap(inputs, outputs), "[Call][CreateKnownZeroCopyMap] failed, model_id:%u.", model_id_); total_io_addrs_.clear(); @@ -3463,11 +3487,11 
@@ bool DavinciModel::CheckUserAndModelSize(const int64_t &size, const int64_t &op_ } // The input and model input size can not be exactly equal because user input is not definite. if ((size + kDataMemAlignSizeCompare) < op_size) { - REPORT_INNER_ERROR("E19999", "%s size:%ld from user add align:%u < input_op_size:%ld in model, model_id:%u, " + REPORT_INNER_ERROR("E19999", "%s size:%ld from user add align:%u < op_size:%ld in model, model_id:%u, " "check invalid", input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_); GELOGE(ACL_ERROR_GE_PARAM_INVALID, - "[Check][Param] %s size:%ld from user add align:%u < input_op_size:%ld in model, model_id:%u", + "[Check][Param] %s size:%ld from user add align:%u < op_size:%ld in model, model_id:%u", input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_); return false; } @@ -3673,6 +3697,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) { elem_num = 1; } uint64_t *buff = reinterpret_cast<uint64_t *>(tensor->MutableData().data()); + GE_CHECK_NOTNULL(buff); if (ge::CheckInt64Uint32MulOverflow(elem_num, kBytes * kStringHeadElems) != SUCCESS) { GELOGE(FAILED, "[Call][CheckInt64Uint32MulOverflow] Shape size:%ld is invalid", elem_num); return FAILED; @@ -3700,6 +3725,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) { /// @return Status /// Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { + string bin_file = op_desc->GetName(); auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName()); auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); if (tbe_kernel == nullptr) { @@ -3708,12 +3734,61 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str()); return INTERNAL_ERROR; } + GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed", + bin_file.c_str()); + return SUCCESS; +} - - std::string session_graph_model_id; - GetUniqueId(op_desc, session_graph_model_id); - const char *bin_file_key = GetRegisterStub(op_desc->GetName(), session_graph_model_id); // from set, always valid.
- TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); +Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) { + std::vector<OpKernelBinPtr> tbe_kernel; + tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel); + GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size()); + if (tbe_kernel.size() != kFftsTbeHandleElementSize) { + REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); + GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts", + op_desc->GetName().c_str(), tbe_kernel.size()); + return INTERNAL_ERROR; + } + if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) { + REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + vector<string> bin_file_keys; + (void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys); + if (bin_file_keys.size() != kFftsTbeHandleElementSize) { + REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts", + op_desc->GetName().c_str(), bin_file_keys.size()); + return INTERNAL_ERROR; + } + GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true, + kNonTailBlock), + "Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str()); + GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock), + "Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str()); + return SUCCESS; +} +Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, + bool is_ffts, size_t thread_index) { + if (thread_index > 1) { + GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu should not be greater than 1.", thread_index); + return INTERNAL_ERROR; + } + const char *bin_file_key; + if (is_ffts) { + bin_file_key = GetRegisterStub(bin_file, ""); + GELOGI("Node:%s inherits func name:%s directly.", op_desc->GetName().c_str(), bin_file_key); + } else { + std::string session_graph_model_id; + GetUniqueId(op_desc, session_graph_model_id); + bin_file_key = GetRegisterStub(bin_file, session_graph_model_id); // from set, always valid.
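+ // Note: for the non-FFTS path, folding session_graph_model_id (from GetUniqueId) into the stub
+ // key keeps registered function names unique per session/graph, so identical kernel names from
+ // different models do not collide in the registry.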
+ } + + TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); std::lock_guard<std::mutex> lock(tvm_bin_mutex_); if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) { void *bin_handle = nullptr; @@ -3721,59 +3796,115 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key); rtDevBinary_t binary; - std::string json_string; - GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string), - GELOGD("Get original type of session_graph_id.")); - if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICPU") { - binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICPU; - } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF") { - binary.magic = RT_DEV_BINARY_MAGIC_ELF; - } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") { - binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; - } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") { - binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE; - } else { - REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", - TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); - GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", - TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); - return PARAM_INVALID; - } - + GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.", + op_desc->GetName().c_str()); binary.version = 0; binary.data = tbe_kernel->GetBinData(); binary.length = tbe_kernel->GetBinDataSize(); - GELOGD("TBE: binary.length: %lu", binary.length); GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle)); - std::string meta_data; - GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data), - GELOGI("Get original type of json_string")); - GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str()); - GE_IF_BOOL_EXEC(!meta_data.empty(), GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str()))); - + GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.", + op_desc->GetName().c_str()); kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel); } else { GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key); kernel_store.ReferTBEHandle(bin_file_key); } std::string kernel_name; - GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name), - GELOGD("Get original type of kernel_name")); + GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.", + op_desc->GetName().c_str()); GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0)); used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1. return SUCCESS; } - // Kernel registered, increase used num in store. StoreTbeHandle(bin_file_key); return SUCCESS; } +Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, + rtDevBinary_t &binary) { + string json_string; + const string &tvm_magic = is_ffts ?
TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC; + const static std::map<std::string, uint32_t> binary_magics = { + {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}, + {"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, + {"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC}, + {"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE} + }; + if (is_ffts) { + vector<string> json_list; + (void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list); + if (json_list.size() != kFftsTbeHandleElementSize) { + GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Attr is %s, thread index is %zu, json list size is %zu.", + tvm_magic.c_str(), thread_index, json_list.size()); + return INTERNAL_ERROR; + } + json_string = json_list[thread_index]; + } else { + (void)AttrUtils::GetStr(op_desc, tvm_magic, json_string); + } + auto iter = binary_magics.find(json_string); + if (iter == binary_magics.end()) { + REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", + tvm_magic.c_str(), json_string.c_str(), op_desc->GetName().c_str(), + op_desc->GetType().c_str(), model_id_); + GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid", + tvm_magic.c_str(), json_string.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); + return PARAM_INVALID; + } + binary.magic = iter->second; + return SUCCESS; +} + +Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) { + string meta_data; + const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA; + if (is_ffts) { + vector<string> meta_data_list; + (void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list); + if (meta_data_list.size() != kFftsTbeHandleElementSize) { + GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.", + tvm_metadata.c_str(), thread_index, meta_data_list.size()); + return INTERNAL_ERROR; + } + meta_data = meta_data_list[thread_index]; + } else { + (void)AttrUtils::GetStr(op_desc, tvm_metadata, meta_data); + } + GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str()); + if (!meta_data.empty()) { + GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str())); + } + return SUCCESS; +} + +Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) { + if (is_ffts) { + // delete prefix, e.g. *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile + vector<string> kernel_name_list; + auto pos = op_desc->GetName().find("/"); + if (pos == std::string::npos) { + GELOGE(INTERNAL_ERROR, "[Check][Param] failed, subgraph node name: %s.", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname"; + (void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list); + if (kernel_name_list.size() != kFftsTbeHandleElementSize) { + GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.", + attr_kernel_name.c_str(), thread_index, kernel_name_list.size()); + return INTERNAL_ERROR; + } + kernel_name = kernel_name_list[thread_index]; + } else { + string attr_kernel_name = op_desc->GetName() + "_kernelname"; + (void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name); + } + return SUCCESS; +} + void DavinciModel::StoreTbeHandle(const std::string &handle_key) { // Online mode FE may call rtFunctionRegister. TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); @@ -4256,7 +4387,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &graph, const map void *{ @@ -4523,4 +4654,50 @@ Status DavinciModel::GetTotalMemSizeExcludeZeroCopy(int64_t &total_useful_size) total_useful_size = runtime_param_.mem_size - runtime_param_.zero_copy_size; return SUCCESS; } + +Status DavinciModel::GetEventIdForBlockingAicpuOp(const OpDescPtr &op_desc, rtStream_t stream, uint32_t &event_id) { + GELOGI("Get event id for aicpu blocking op:%s", op_desc->GetName().c_str()); + auto it = stream_2_event_.find(stream); + if (it != stream_2_event_.end()) { + auto rt_ret = rtGetEventID(it->second, &event_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetEventID failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetEventID] failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + } else { + rtEvent_t rt_event = nullptr; + auto rt_ret = rtEventCreateWithFlag(&rt_event, RT_EVENT_WITH_FLAG); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventCreateWithFlag failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + GELOGE(RT_FAILED, "[Call][rtEventCreateWithFlag] failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtGetEventID(rt_event, &event_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetEventID failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetEventID] failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + stream_2_event_.emplace(stream, rt_event); + } + return SUCCESS; +} + +Status DavinciModel::GetEventByStream(const rtStream_t &stream, rtEvent_t &rt_event) { + auto it =
stream_2_event_.find(stream); + if (it == stream_2_event_.end()) { + REPORT_INNER_ERROR("E19999", "Get event failed"); + GELOGE(FAILED, "[Get][Event] Get event failed"); + return FAILED; + } + rt_event = it->second; + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index 819a2ea2..76b0beef 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -24,14 +24,14 @@ #include #include -#include "common/ge_types.h" -#include "common/helper/model_helper.h" -#include "common/helper/om_file_helper.h" +#include "framework/common/ge_types.h" +#include "framework/common/helper/model_helper.h" +#include "framework/common/helper/om_file_helper.h" #include "common/opskernel/ge_task_info.h" #include "common/properties_manager.h" #include "common/dump/exception_dumper.h" #include "common/dump/opdebug_register.h" -#include "common/types.h" +#include "framework/common/types.h" #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "graph/load/model_manager/aipp_utils.h" @@ -43,13 +43,13 @@ #include "graph/model.h" #include "graph/node.h" #include "graph/op_desc.h" -#include "graph/operator.h" +#include "external/graph/operator.h" #include "graph/utils/attr_utils.h" #include "graph/utils/tensor_utils.h" #include "mmpa/mmpa_api.h" #include "proto/task.pb.h" -#include "task_info/task_info.h" -#include "graph/common/local_context.h" +#include "graph/load/model_manager/task_info/task_info.h" +#include "common/local_context.h" using std::mutex; using std::thread; @@ -300,6 +300,7 @@ class DavinciModel { return op_list_.at(index); } + void SetGlobalStep(void *global_step, uint64_t global_step_size); void *GetGlobalStep() const { return global_step_addr_; } // get task info for profiling @@ -498,10 +499,6 @@ class DavinciModel { return exception_dumper_.DumpExceptionInfo(exception_infos); } - void SetKnownShapeGlobalStep(void *global_step) { - known_shape_global_step_ = global_step; - } - void DumperShrink() { data_dumper_.DumpShrink(); } @@ -585,6 +582,10 @@ class DavinciModel { void SetRunningFlag(bool flag) { running_flg_ = flag; } Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); + // for blocking aicpu op + Status GetEventByStream(const rtStream_t &stream, rtEvent_t &rt_event); + Status GetEventIdForBlockingAicpuOp(const OpDescPtr &op_desc, rtStream_t stream, uint32_t &event_id); + private: // memory address of weights uint8_t *weights_mem_base_; @@ -771,6 +772,12 @@ class DavinciModel { /// @return Status /// Status InitTbeHandle(const OpDescPtr &op_desc); + Status InitTbeHandleWithFfts(const OpDescPtr &op_desc); + Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts, + size_t thread_index = 0); + Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary); + Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle); + Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name); void StoreTbeHandle(const string &handle_key); void CleanTbeHandle(); @@ -1102,11 +1109,10 @@ class DavinciModel { vector output_descs_; vector output_formats_; - // known shape node for dump - void *known_shape_global_step_; - // op name to attrs mapping std::map>> op_name_to_attrs_; + + std::map stream_2_event_; }; } // namespace ge #endif // 
GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_
diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc
index 6a563d2f..d0d88e66 100755
--- a/ge/graph/load/model_manager/model_manager.cc
+++ b/ge/graph/load/model_manager/model_manager.cc
@@ -21,11 +21,11 @@
 #include "aicpu/aicpu_schedule/aicpu_op_type_list.h"
 #include "common/model_parser/model_parser.h"
 #include "common/dump/dump_manager.h"
-#include "common/l2_cache_optimize.h"
+#include "framework/common/l2_cache_optimize.h"
 #include "common/profiling/profiling_manager.h"
-#include "graph/common/ge_call_wrapper.h"
+#include "common/ge_call_wrapper.h"
 #include "graph/load/model_manager/davinci_model.h"
-#include "model/ge_root_model.h"
+#include "common/model/ge_root_model.h"
 #include "common/formats/utils/formats_trans_utils.h"

 namespace ge {
@@ -368,7 +368,17 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrGetRuntimeParam().graph_id;
+    if (subscribe_info.graph_id == graph_id) {
+      profiling_manager.SetGraphIdToModelMap(graph_id, model_id);
+    } else {
+      GELOGW("graph_id:%u is not in subscribe info.", graph_id);
+    }
+  }
   return ret;
 }
@@ -513,8 +523,7 @@ Status ModelManager::GetCurDynamicDims(const vector> &user_real_
   }
   GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims).c_str());
   bool cur_dynamic_dims_valid = false;
-  std::vector<std::string> shape_strs = ge::StringUtils::Split(GetLocalOmgContext().dynamic_dims, ';');
-  for (auto dynamic_dim : shape_strs) {
+  for (auto dynamic_dim : GetLocalOmeContext().dynamic_shape_dims) {
     if (dynamic_dim == formats::JoinToString(cur_dynamic_dims)) {
       cur_dynamic_dims_valid = true;
       break;
@@ -556,10 +565,10 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector cur_dynamic_dims;
-  if (!GetLocalOmgContext().user_real_input_dims.empty()) {
-    if (GetCurDynamicDims(GetLocalOmgContext().user_real_input_dims, GetLocalOmgContext().user_input_dims,
+  if (!GetLocalOmeContext().user_real_input_dims.empty()) {
+    if (GetCurDynamicDims(GetLocalOmeContext().user_real_input_dims, GetLocalOmeContext().user_input_dims,
                           cur_dynamic_dims) != SUCCESS) {
       GELOGE(INTERNAL_ERROR, "[Get][CurDynamicDims] [Train_Dynamic] Failed to Parse real_dynamic_dims.");
       return INTERNAL_ERROR;
@@ -570,6 +579,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector(cur_dynamic_dims.size() * sizeof(int32_t));
     GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK,
                      REPORT_CALL_ERROR("E19999", "memcpy data failed, size:%u", length);
+                     delete[] reinterpret_cast(data.data);
                      return INTERNAL_ERROR, "[Memcpy][Data] failed, size:%u.", length);
     data.length = length;
     input_data.blobs.push_back(data);
@@ -758,12 +768,15 @@ Status ModelManager::HandleProfModelUnsubscribeCommand(const Command &command) {
   if (ret != SUCCESS) {
     return ret;
   }
-
-  if (ProfilingManager::Instance().ProfModelUnsubscribe(static_cast(davinci_model.get())) != SUCCESS) {
+  auto &profiling_manager = ProfilingManager::Instance();
+  if (profiling_manager.ProfModelUnsubscribe(static_cast(davinci_model.get())) != SUCCESS) {
     GELOGE(FAILED, "[Handle][ProfModelUnsubscribe] failed.");
     return FAILED;
   }
-
+  auto is_subscribe = profiling_manager.GetSubscribeInfo().is_subscribe;
+  if (is_subscribe) {
+    profiling_manager.CleanSubscribeInfo();
+  }
   return SUCCESS;
 }
@@ -1378,7 +1391,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_
 Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
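The hunk below moves the `GE_MAKE_GUARD` cleanup in `LaunchKernelCustAicpuSo` from the tail of the function up to just after the declarations, so every early error return now frees `allocated_mem` and destroys `stream` (the callback checks `stream != nullptr` before `rtStreamDestroy`). A minimal standalone sketch of the same scope-guard idea, assuming nothing from GE (`ScopeGuard` and `LaunchSomething` are illustrative names, not GE APIs):

```
#include <cstdlib>
#include <functional>
#include <utility>
#include <vector>

// Runs the callback on every exit path, like GE_MAKE_GUARD in the hunk below.
class ScopeGuard {
 public:
  explicit ScopeGuard(std::function<void()> cb) : cb_(std::move(cb)) {}
  ~ScopeGuard() { if (cb_) { cb_(); } }
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;

 private:
  std::function<void()> cb_;
};

bool LaunchSomething() {
  std::vector<void *> allocated_mem;
  ScopeGuard release([&allocated_mem]() {
    for (void *p : allocated_mem) { std::free(p); }  // freed on all returns
  });
  void *buf = std::malloc(64);
  if (buf == nullptr) { return false; }  // guard still runs here
  allocated_mem.push_back(buf);
  // ... do the launch work ...
  return true;  // and here
}

int main() { return LaunchSomething() ? 0 : 1; }
```

Installing the guard before the first allocation is what closes the leak: previously the callback was only created after `rtStreamSynchronize`, so any failure before that point returned without releasing anything.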
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); std::lock_guard lock(cust_aicpu_mutex_); - if (cust_aicpu_so_.size() == 0) return SUCCESS; + if (cust_aicpu_so_.empty()) { + return SUCCESS; + } // get current context rtContext_t rt_cur_ctx = nullptr; auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); @@ -1394,9 +1409,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { return SUCCESS; } + rtStream_t stream = nullptr; vector allocated_mem; + std::function callback = [&]() { + for (auto mem : allocated_mem) { + GE_CHK_RT(rtFree(mem)); + } + if (stream != nullptr) { + GE_CHK_RT(rtStreamDestroy(stream)); + } + }; + GE_MAKE_GUARD(release, callback); + rtError_t status; - rtStream_t stream = nullptr; vector v_cust_so; void *args = nullptr; @@ -1471,13 +1496,6 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status); return RT_ERROR_TO_GE_STATUS(status); } - std::function callback = [&]() { - for (auto mem : allocated_mem) { - GE_CHK_RT(rtFree(mem)); - } - GE_CHK_RT(rtStreamDestroy(stream)); - }; - GE_MAKE_GUARD(release, callback); GELOGI("Cpu kernel launch task success."); return SUCCESS; } @@ -1786,7 +1804,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector &aicpu_op std::vector op_name; op_name.clear(); op_name.resize(kOpNameMaxSize); - GE_CHK_RT(rtMemcpy(op_name.data(), aicpu_info.opLen, reinterpret_cast(aicpu_info.opType), + GE_CHK_RT(rtMemcpy(op_name.data(), aicpu_info.opLen, + reinterpret_cast(static_cast(aicpu_info.opType)), aicpu_info.opLen, RT_MEMCPY_DEVICE_TO_HOST)); std::string kernel_type = (static_cast(aicpu_info.kernelsType) == TF_KERNEL) ? "TF_KERNEL" : "CPU_KERNEL"; @@ -1820,5 +1839,4 @@ Status ModelManager::CheckAicpuOpList(GeModelPtr ge_model) { "[Call][LaunchKernelCheckAicpuOp] failed."); return SUCCESS; } - } // namespace ge diff --git a/ge/graph/load/model_manager/model_manager.h b/ge/graph/load/model_manager/model_manager.h index e35bb7aa..6389d6db 100755 --- a/ge/graph/load/model_manager/model_manager.h +++ b/ge/graph/load/model_manager/model_manager.h @@ -17,7 +17,7 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_MODEL_MANAGER_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_MODEL_MANAGER_H_ -#include +#include #include #include #include @@ -26,13 +26,13 @@ #include #include #include "cce/aicpu_engine_struct.h" -#include "common/ge_inner_error_codes.h" -#include "common/ge_types.h" -#include "common/helper/model_helper.h" -#include "common/helper/om_file_helper.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" +#include "framework/common/helper/model_helper.h" +#include "framework/common/helper/om_file_helper.h" #include "common/properties_manager.h" -#include "common/types.h" -#include "ge/ge_api_types.h" +#include "framework/common/types.h" +#include "external/ge/ge_api_types.h" #include "graph/ge_context.h" #include "graph/model.h" #include "hybrid/hybrid_davinci_model.h" diff --git a/ge/graph/load/model_manager/model_utils.cc b/ge/graph/load/model_manager/model_utils.cc index 224a3331..a31837ca 100755 --- a/ge/graph/load/model_manager/model_utils.cc +++ b/ge/graph/load/model_manager/model_utils.cc @@ -16,11 +16,11 @@ #include "graph/load/model_manager/model_utils.h" #include -#include "common/debug/log.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/debug/log.h" +#include "framework/common/op/ge_op_utils.h" #include "graph/utils/tensor_utils.h" #include 
"graph/manager/graph_var_manager.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "graph/build/memory/block_mem_assigner.h" #include "common/math/math_util.h" diff --git a/ge/graph/load/model_manager/model_utils.h b/ge/graph/load/model_manager/model_utils.h index 8ce1b060..0eadc7a8 100755 --- a/ge/graph/load/model_manager/model_utils.h +++ b/ge/graph/load/model_manager/model_utils.h @@ -19,8 +19,8 @@ #include -#include "common/ge_inner_error_codes.h" -#include "common/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" #include "graph/load/model_manager/task_info/task_info.h" #include "graph/op_desc.h" #include "graph/utils/tensor_adapter.h" diff --git a/ge/graph/load/model_manager/task_info/ffts_task_info.cc b/ge/graph/load/model_manager/task_info/ffts_task_info.cc new file mode 100644 index 00000000..e311ccac --- /dev/null +++ b/ge/graph/load/model_manager/task_info/ffts_task_info.cc @@ -0,0 +1,393 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/load/model_manager/task_info/ffts_task_info.h" + +#include + +#include "graph/load/model_manager/davinci_model.h" + +namespace { +constexpr uint32_t kAddrLen = sizeof(void *); +} +namespace ge { +FftsTaskInfo::~FftsTaskInfo() { + GE_FREE_RT_LOG(args_); +} + +Status FftsTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GELOGI("FftsTaskInfo Init Start."); + GE_CHECK_NOTNULL(davinci_model); + davinci_model_ = davinci_model; + GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList())); + + const domi::FftsTaskDef &ffts_task_def = task_def.ffts_task(); + OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_task_def.op_index()); + GE_CHECK_NOTNULL(op_desc); + + if ((ffts_task_def.sub_task_size() > static_cast(RT_FFTS_MAX_SUB_TASK_NUM)) || + (ffts_task_def.ticket_cache_size() > static_cast(RT_FFTS_MAX_TICKET_CACHE_NUM))) { + GELOGE(INTERNAL_ERROR, "[Check][Param] failed. 
Node: %s, sub task desc size: %d, ticket cache size: %d", + op_desc->GetName().c_str(), ffts_task_def.sub_task_size(), ffts_task_def.ticket_cache_size()); + return INTERNAL_ERROR; + } + args_size_ = kAddrLen * ffts_task_def.addr_size(); + GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM)); + InitFftsDescInfo(ffts_task_def.ffts_desc(), sub_task_info_.fftsDesc); + + sub_task_info_.fftsType = static_cast(ffts_task_def.ffts_type()); + sub_task_info_.subTaskNum = ffts_task_def.sub_task_size(); + for (int idx = 0; idx < ffts_task_def.sub_task_size(); ++idx) { + GE_CHK_STATUS_RET_NOLOG(InitSubTaskInfo(ffts_task_def.sub_task(idx), sub_task_info_.subTask[idx])); + } + + sub_task_info_.tickCacheNum = ffts_task_def.ticket_cache_size(); + for (int idx = 0; idx < ffts_task_def.ticket_cache_size(); ++idx) { + GE_CHK_STATUS_RET_NOLOG(InitTicketCache(ffts_task_def.ticket_cache(idx), sub_task_info_.ticketCache[idx])); + } + + size_t data_size = kAddrLen * io_addrs_.size(); + GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE)); + GELOGI("FftsTaskInfo::Init Success. Node: %s, input/output size: %zu", op_desc->GetName().c_str(), io_addrs_.size()); + return SUCCESS; +} + +void FftsTaskInfo::InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc) { + ffts_desc.tm = static_cast(ffts_desc_def.tm()); + ffts_desc.di = static_cast(ffts_desc_def.di()); + ffts_desc.dw = static_cast(ffts_desc_def.dw()); + ffts_desc.df = static_cast(ffts_desc_def.df()); + ffts_desc.dataSplitUnit = static_cast(ffts_desc_def.data_split_unit()); + ffts_desc.prefetchOstNum = static_cast(ffts_desc_def.prefetch_ost_num()); + ffts_desc.cacheMaintainOstNum = static_cast(ffts_desc_def.cache_maintain_ost_num()); + ffts_desc.aicPrefetchUpper = static_cast(ffts_desc_def.aic_prefetch_upper()); + ffts_desc.aicPrefetchLower = static_cast(ffts_desc_def.aic_prefetch_lower()); + ffts_desc.aivPrefetchUpper = static_cast(ffts_desc_def.aiv_prefetch_upper()); + ffts_desc.aivPrefetchLower = static_cast(ffts_desc_def.aiv_prefetch_lower()); +} + +Status FftsTaskInfo::InitSubTaskInfo(const domi::FftsSubTaskDef &sub_task_def, rtFftsSubTaskInfo_t &sub_task_desc) { + if ((sub_task_def.dst_tick_cache_id_size() > static_cast(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) || + (sub_task_def.src_tick_cache_id_size() > static_cast(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) { + GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, dst tick cache id size: %d, src tick cache id size: %d", + sub_task_def.dst_tick_cache_id_size(), sub_task_def.src_tick_cache_id_size()); + return FAILED; + } + + if (sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv()) { + GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, auto thread aic/aiv: %d, manual thread aic/aiv: %d", + sub_task_def.has_auto_thread_aic_aiv(), sub_task_def.has_manual_thread_aic_aiv()); + return FAILED; + } + + thread_dim_ = sub_task_def.thread_dim(); + GE_CHK_BOOL_RET_STATUS(thread_dim_ != 0, FAILED, "[Get][thread_dim] failed, Invalid thread dim: %u!", thread_dim_); + + sub_task_desc.subTaskType = static_cast(sub_task_def.sub_task_type()); + sub_task_desc.threadDim = sub_task_def.thread_dim(); + + sub_task_desc.dstTickCacheVldBitmap = sub_task_def.dst_tick_cache_vld_bitmap(); + sub_task_desc.srcTickCacheVldBitmap = sub_task_def.src_tick_cache_vld_bitmap(); + sub_task_desc.srcDataOutOfSubGraphBitmap = sub_task_def.src_data_out_of_subgraph_bitmap(); + + for (int idx = 0; idx < 
sub_task_def.dst_tick_cache_id_size(); ++idx) { + sub_task_desc.dstTickCacheID[idx] = sub_task_def.dst_tick_cache_id(idx); + } + + for (int idx = 0; idx < sub_task_def.src_tick_cache_id_size(); ++idx) { + sub_task_desc.srcTickCacheID[idx] = sub_task_def.src_tick_cache_id(idx); + } + + if (sub_task_def.has_auto_thread_aic_aiv()) { + GE_CHK_STATUS_RET_NOLOG(InitAutoAicAiv(sub_task_def.auto_thread_aic_aiv(), sub_task_desc.custom.autoThreadAicAiv)); + } + + if (sub_task_def.has_manual_thread_aic_aiv()) { + GE_CHK_STATUS_RET_NOLOG( + InitManualAicAiv(sub_task_def.manual_thread_aic_aiv(), sub_task_desc.custom.manualThreadAicAiv)); + } + + if (sub_task_def.has_manual_thread_nop()) { + GE_CHK_STATUS_RET_NOLOG(InitManualNop(sub_task_def.manual_thread_nop(), sub_task_desc.custom.manualThreadNop)); + } + + return SUCCESS; +} + +Status FftsTaskInfo::InitTicketCache(const domi::TicketCacheDef &ticket_cache_def, rtTicketCache_t &ticket_cache) { + if (ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache()) { + GELOGE(FAILED, "[Check][Param] Invalid TicketCacheDef, has auto thread cache: %d, has manual thread cache: %d", + ticket_cache_def.has_auto_thread_cache(), ticket_cache_def.has_manual_thread_cache()); + return FAILED; + } + + ticket_cache.cacheOption = static_cast(ticket_cache_def.cache_option()); + ticket_cache.ticketCacheWindow = ticket_cache_def.ticket_cache_window(); + + if (ticket_cache_def.has_auto_thread_cache()) { + InitAutoCacheInfo(ticket_cache_def.auto_thread_cache(), ticket_cache.custom.autoThreadCache); + } + if (ticket_cache_def.has_manual_thread_cache()) { + GE_CHK_STATUS_RET_NOLOG( + InitManualCacheInfo(ticket_cache_def.manual_thread_cache(), ticket_cache.custom.manualThreadCache)); + } + + return SUCCESS; +} + +// task_addr = {0,200,700,1000,2000, 3500} +// task_addr_offset = {20,40,2,100,200} +template +Status FftsTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, + uint32_t addr_count) { + for (uint32_t i = 0; i < addr_count; ++i) { + uintptr_t logic_addr = aic_aiv_def.task_addr(i) + thread_dim * aic_aiv_def.task_addr_offset(i); + uint8_t *io_addr = nullptr; + if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress]GetRtAddress failed."); + return INTERNAL_ERROR; + } + GELOGD("aic_aiv_def task base addr is %ld, offset is %ld, thread is %d, logic addrs is 0x%lx, io addr is %p", + aic_aiv_def.task_addr(i), aic_aiv_def.task_addr_offset(i), thread_dim, logic_addr, io_addr); + io_addrs_.emplace_back(io_addr); + } + return SUCCESS; +} + +Status FftsTaskInfo::InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv) { + if (aic_aiv_def.src_prefetch_size() > static_cast(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) { + GELOGE(FAILED, "[Check][Param] Invalid AutoThreadAicAivInfo, prefetch size: %d", aic_aiv_def.src_prefetch_size()); + return FAILED; + } + + aic_aiv.taskParamAddr = reinterpret_cast(args_) + kAddrLen * io_addrs_.size(); + GELOGD("AutoThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr); + const auto &rts_param = davinci_model_->GetRuntimeParam(); + for (uint32_t i = 0; i < thread_dim_ - 1; ++i) { + GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i, + static_cast(aic_aiv_def.task_addr_offset_size()))); + } + GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count())); + int last_thread_workspace_size = 
aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size(); + for (int k = 0; k < last_thread_workspace_size; ++k) { + uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k); + uint8_t *io_addr = nullptr; + GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr)); + GELOGD("logic addr is 0x%lx, io addr is %p.", logic_addr, io_addr); + io_addrs_.emplace_back(io_addr); + } + + aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset(); + GELOGD("args_: %p, io_addrs size: %zu, task param offset: %u.", args_, io_addrs_.size(), aic_aiv.taskParamOffset); + aic_aiv.satMode = aic_aiv_def.sat_mode(); + aic_aiv.scheduleMode = aic_aiv_def.schedule_mode(); + aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt(); + + aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); + aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); + + aic_aiv.tailBlkDim = aic_aiv_def.tail_blk_dim(); + aic_aiv.nonTailBlkDim = aic_aiv_def.non_tail_blk_dim(); + + aic_aiv.nonTailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.non_tail_task_func_stub(), ""); + aic_aiv.tailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.tail_task_func_stub(), ""); + + GELOGI("Set func name[%s][%s] succ.", aic_aiv.nonTailTaskFuncStub, aic_aiv.tailTaskFuncStub); + for (int idx = 0; idx < aic_aiv_def.src_prefetch_size(); ++idx) { + InitAutoPrefetch(aic_aiv_def.src_prefetch(idx), aic_aiv.srcPrefetch[idx]); + } + + return SUCCESS; +} + +void FftsTaskInfo::InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache) { + cache.dataAddr = cache_def.data_addr(); + cache.dataAddrOffset = cache_def.data_addr_offset(); + cache.nonTailDataLen = cache_def.non_tail_data_len(); + cache.tailDataLen = cache_def.tail_data_len(); + cache.ticketCacheRefCnt = cache_def.ticket_cache_ref_cnt(); +} + +void FftsTaskInfo::InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch) { + prefetch.dataAddr = prefetch_def.data_addr(); + prefetch.dataAddrOffset = prefetch_def.data_addr_offset(); + prefetch.nonTailDataLen = prefetch_def.non_tail_data_len(); + prefetch.tailDataLen = prefetch_def.tail_data_len(); +} + +Status FftsTaskInfo::InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, + rtManualThreadAicAivInfo_t &aic_aiv) { + if ((aic_aiv_def.thread_prefetch_dmu_idx_size() > static_cast(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || + (aic_aiv_def.thread_blk_dim_size() > static_cast(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || + (aic_aiv_def.thread_task_func_stub_size() > static_cast(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || + (aic_aiv_def.src_dep_tbl_size() > static_cast(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) { + GELOGE(FAILED, "[Check][Param] Invalid ManualThreadAicAivInfo, thread prefetch dmu desc size: %d, " + "thread blk dim size: %d, thread task func stub size: %d, src dep tbl size: %d", + aic_aiv_def.thread_prefetch_dmu_idx_size(), aic_aiv_def.thread_blk_dim_size(), + aic_aiv_def.thread_task_func_stub_size(), aic_aiv_def.src_dep_tbl_size()); + return FAILED; + } + aic_aiv.taskParamAddr = reinterpret_cast(args_) + kAddrLen * io_addrs_.size(); + GELOGD("ManualThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr); + const auto &rts_param = davinci_model_->GetRuntimeParam(); + for (uint32_t i = 0; i < thread_dim_ - 1; ++i) { + GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i, + static_cast(aic_aiv_def.task_addr_offset_size()))); + } + 
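As a sanity check of the addressing scheme `InitIoAddrs` uses above (`logic_addr = task_addr[i] + thread_index * task_addr_offset[i]`, with `task_addr` entries beyond the offset list taken verbatim as the last thread's workspace), here is a self-contained sketch built on the sample arrays from the comment earlier in this file; the `thread_dim` of 3 is made up for illustration:

```
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Sample data from the InitIoAddrs comment: six bases, five strides.
  const std::vector<int64_t> task_addr = {0, 200, 700, 1000, 2000, 3500};
  const std::vector<int64_t> task_addr_offset = {20, 40, 2, 100, 200};
  const uint32_t thread_dim = 3;  // illustrative value

  for (uint32_t t = 0; t < thread_dim; ++t) {
    for (size_t i = 0; i < task_addr_offset.size(); ++i) {
      const int64_t logic_addr =
          task_addr[i] + static_cast<int64_t>(t) * task_addr_offset[i];
      std::printf("thread %u io %zu -> %lld\n", t, i,
                  static_cast<long long>(logic_addr));
    }
  }
  // The sixth entry (3500) has no stride: it is the last thread's
  // workspace, appended as-is by the trailing loop in the code above.
  return 0;
}
```

So thread 0 sees the bases unchanged, thread 1 sees {20, 240, 702, 1100, 2200}, and so on.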
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count())); + int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size(); + for (int k = 0; k < last_thread_workspace_size; ++k) { + uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k); + uint8_t *io_addr = nullptr; + GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr)); + io_addrs_.emplace_back(io_addr); + } + aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset(); + + aic_aiv.satMode = aic_aiv_def.sat_mode(); + aic_aiv.scheduleMode = aic_aiv_def.schedule_mode(); + aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt(); + + aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); // 8 bit bitmap 1 0 1 0 + aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); // 8 bit bitmap 1 0 1 0 + aic_aiv.prefetchOnceDmuNum = aic_aiv_def.prefetch_once_dmu_num(); + + for (int idx = 0; idx < aic_aiv_def.thread_prefetch_dmu_idx_size(); ++idx) { + aic_aiv.threadPrefetchDmuIdx[idx] = aic_aiv_def.thread_prefetch_dmu_idx(idx); + } + for (int idx = 0; idx < aic_aiv_def.thread_blk_dim_size(); ++idx) { + aic_aiv.threadBlkDim[idx] = aic_aiv_def.thread_blk_dim(idx); + } + for (int idx = 0; idx < aic_aiv_def.thread_task_func_stub_size(); ++idx) { + aic_aiv.threadTaskFuncStub[idx] = aic_aiv_def.thread_task_func_stub(idx).c_str(); + } + + InitManualDmuInfo(aic_aiv_def, aic_aiv.prefetchList); + for (int idx = 0; idx < aic_aiv_def.src_dep_tbl_size(); ++idx) { + GE_CHK_STATUS_RET_NOLOG(InitManualDependency(aic_aiv_def.src_dep_tbl(idx), aic_aiv.srcDepTbl[idx])); + } + + return SUCCESS; +} + +Status FftsTaskInfo::InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, + rtManualThreadCacheInfo_t &cache_info) { + if ((cache_def.slice_dmu_idx_size() > static_cast(RT_FFTS_MAX_MANUAL_THREAD_NUM)) || + (cache_def.ticket_cache_ref_cnt_tbl_size() > static_cast(RT_FFTS_MAX_MANUAL_THREAD_NUM))) { + GELOGE(FAILED, "[Check][Param] Invalid ManualThreadCacheInfo slice dum desc index %d, ticket cache ref cnt %d", + cache_def.slice_dmu_idx_size(), cache_def.ticket_cache_ref_cnt_tbl_size()); + return FAILED; + } + + InitManualDmuInfo(cache_def, cache_info.dmuList); + for (int idx = 0; idx < cache_def.slice_dmu_idx_size(); ++idx) { + cache_info.sliceDmuIdx[idx] = cache_def.slice_dmu_idx(idx); + } + + for (int idx = 0; idx < cache_def.ticket_cache_ref_cnt_tbl_size(); ++idx) { + cache_info.ticketCacheRefCntTbl[idx] = cache_def.ticket_cache_ref_cnt_tbl(idx); + } + + return SUCCESS; +} + +Status FftsTaskInfo::InitManualDependency(const domi::ManualThreadDependencyDef &dependency_def, + rtManualThreadDependency_t &dependency) { + if (dependency_def.dependency_size() > static_cast(RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN)) { + GELOGE(FAILED, "[Check][Param] Invalid ManualThreadDependency size: %d", dependency_def.dependency_size()); + return FAILED; + } + + for (int idx = 0; idx < dependency_def.dependency_size(); ++idx) { + dependency.dependency[idx] = dependency_def.dependency(idx); + } + + return SUCCESS; +} + +Status FftsTaskInfo::InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop_info) { + if (nop_def.src_dep_tbl_size() > static_cast(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) { + GELOGE(FAILED, "[Check][Param] Invalid ManualThreadNopInfo, src dep tbl size: %d", nop_def.src_dep_tbl_size()); + return FAILED; + } + + for (int idx = 0; idx < nop_def.src_dep_tbl_size(); ++idx) { + 
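One caveat in the `InitManualDmuInfo` overloads just below: they build the DMU array in a function-local `std::vector` and return a pointer into it through the out-parameter, so the backing storage is already destroyed by the time `Distribute` hands `sub_task_info_` to `rtFftsTaskLaunch`. A lifetime-safe variant would keep the buffer in the owning object; a sketch under that assumption (`Holder`, `FillOwned`, and `buffers_` are hypothetical names, not part of this patch):

```
#include <cstdint>
#include <vector>

struct Dmu { uint64_t data_addr; };

class Holder {
 public:
  // Unsafe pattern: 'buffer' dies at return, so 'out' dangles.
  static void FillDangling(Dmu *&out, int n) {
    std::vector<uint8_t> buffer(sizeof(Dmu) * n);
    out = reinterpret_cast<Dmu *>(buffer.data());  // dangling after return
  }

  // Safe pattern: the object owns the storage, so 'out' stays valid
  // for as long as the Holder lives.
  void FillOwned(Dmu *&out, int n) {
    buffers_.emplace_back(sizeof(Dmu) * n);
    out = reinterpret_cast<Dmu *>(buffers_.back().data());
  }

 private:
  std::vector<std::vector<uint8_t>> buffers_;  // hypothetical member store
};

int main() {
  Holder h;
  Dmu *dmu = nullptr;
  h.FillOwned(dmu, 4);
  dmu[0].data_addr = 1;  // fine: backed by h.buffers_
  return 0;
}
```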
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(nop_def.src_dep_tbl(idx), nop_info.srcDepTbl[idx])); + } + + return SUCCESS; +} + +void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu) { + if (aic_aiv_def.prefetch_list().empty()) { + return; + } + + std::vector buffer(sizeof(rtManualThreadDmuInfo_t) * aic_aiv_def.prefetch_list_size()); + dmu = reinterpret_cast(buffer.data()); + for (int idx = 0; idx < aic_aiv_def.prefetch_list_size(); ++idx) { + InitManualDmuInfo(aic_aiv_def.prefetch_list(idx), dmu[idx]); + } +} + +void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu) { + if (cache_def.dmu_list().empty()) { + return; + } + + std::vector buffer(sizeof(rtManualThreadDmuInfo_t) * cache_def.dmu_list_size()); + dmu = reinterpret_cast(buffer.data()); + for (int idx = 0; idx < cache_def.dmu_list_size(); ++idx) { + InitManualDmuInfo(cache_def.dmu_list(idx), dmu[idx]); + } +} + +void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu) { + dmu.dataAddr = dmu_def.data_addr(); + dmu.numOuter = dmu_def.num_outer(); + dmu.numInner = dmu_def.num_inner(); + dmu.strideOuter = dmu_def.stride_outer(); + dmu.lenInner = dmu_def.len_inner(); + dmu.strideInner = dmu_def.stride_inner(); +} + +Status FftsTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + return SUCCESS; +} + +Status FftsTaskInfo::UpdateArgs() { + GE_CHECK_NOTNULL(davinci_model_); + std::vector io_addrs = io_addrs_; + davinci_model_->UpdateKnownZeroCopyAddr(io_addrs); + auto addr_size = kAddrLen * io_addrs.size(); + GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE)); + return SUCCESS; +} + +Status FftsTaskInfo::Distribute() { + GELOGI("FftsTaskInfo Distribute Start."); + rtError_t rt_ret = rtFftsTaskLaunch(&sub_task_info_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsTaskLaunch failed, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + + GELOGI("FftsTaskInfo Distribute Success."); + return SUCCESS; +} + +REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_TASK, FftsTaskInfo); +} // namespace ge diff --git a/ge/graph/load/model_manager/task_info/ffts_task_info.h b/ge/graph/load/model_manager/task_info/ffts_task_info.h new file mode 100644 index 00000000..ffc286f9 --- /dev/null +++ b/ge/graph/load/model_manager/task_info/ffts_task_info.h @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ +#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ + +#include "graph/load/model_manager/task_info/task_info.h" +#include "graph/op_desc.h" + +namespace ge { +class FftsTaskInfo : public TaskInfo { + public: + FftsTaskInfo() = default; + ~FftsTaskInfo() override; + + Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + Status Distribute() override; + + Status UpdateArgs() override; + + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + private: + void InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc); + Status InitSubTaskInfo(const domi::FftsSubTaskDef &task_def, rtFftsSubTaskInfo_t &task); + Status InitTicketCache(const domi::TicketCacheDef &cache_def, rtTicketCache_t &cache); + + Status InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv); + void InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache); + void InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch); + + Status InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadAicAivInfo_t &aic_aiv); + Status InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadCacheInfo_t &cache); + Status InitManualDependency(const domi::ManualThreadDependencyDef &depend_def, rtManualThreadDependency_t &depend); + Status InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop); + + void InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu); + void InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu); + void InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu); + + template + Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, uint32_t addr_count); + + DavinciModel *davinci_model_{nullptr}; + rtFftsTaskInfo_t sub_task_info_; + std::vector io_addrs_; + uint32_t thread_dim_{0}; + void *args_{nullptr}; // runtime args memory + uint32_t args_size_{0}; // runtime args memory length +}; +} // namespace ge +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_ diff --git a/ge/graph/load/model_manager/task_info/hccl_task_info.cc b/ge/graph/load/model_manager/task_info/hccl_task_info.cc index c3c5c8b7..a3cef836 100644 --- a/ge/graph/load/model_manager/task_info/hccl_task_info.cc +++ b/ge/graph/load/model_manager/task_info/hccl_task_info.cc @@ -329,7 +329,7 @@ void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { // Get privateDef and opsKernelStorePtr from taskDef and save them in taskInfo GELOGI("get custom info in modelTaskDef."); ops_kernel_store_ = nullptr; - void *ops_kernel_store_name_temp = reinterpret_cast(task.ops_kernel_store_ptr()); + void *ops_kernel_store_name_temp = reinterpret_cast(static_cast(task.ops_kernel_store_ptr())); if (ops_kernel_store_name_temp != nullptr) { ops_kernel_store_ = std::move(ops_kernel_store_name_temp); std::string private_def_temp = task.private_def(); diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc index 356919f6..fe9cd0cc 100644 --- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc +++ 
b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc
@@ -23,11 +23,11 @@
 #include "common/properties_manager.h"
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/fmk_error_codes.h"
-#include "graph/attr_value.h"
+#include "external/graph/attr_value.h"
 #include "graph/load/model_manager/davinci_model.h"
 #include "graph/load/model_manager/model_manager.h"
-#include "hybrid/node_executor/aicpu/aicpu_ext_info.h"
 #include "framework/common/debug/log.h"
+#include "runtime/rt.h"

 namespace {
 const char *const kAicpuAllshape = "_AllShape";
@@ -43,7 +43,7 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe
   UnknowShapeOpType unknown_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
   uint32_t num_inputs = op_desc->GetInputsSize();
   uint32_t num_outputs = op_desc->GetOutputsSize();
-  std::unique_ptr<::ge::hybrid::AicpuExtInfoHandler> ext_handle(
+  std::shared_ptr<::ge::hybrid::AicpuExtInfoHandler> ext_handle(
       new(std::nothrow) ::ge::hybrid::AicpuExtInfoHandler(op_desc->GetName(), num_inputs, num_outputs,
@@ -76,6 +76,16 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe
       }
     }
   }
+
+  AttrUtils::GetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_);
+  GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc->GetName().c_str(), is_blocking_aicpu_op_);
+
+  if (UpdateEventIdForAicpuBlockingOp(op_desc, ext_handle) != SUCCESS) {
+    GELOGE(FAILED, "[Call][UpdateEventIdForAicpuBlockingOp] failed for op:%s(%s)",
+           op_desc->GetName().c_str(), op_desc->GetType().c_str());
+    return FAILED;
+  }
+
   auto rt_ret = rtMalloc(&ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM);
   GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE,
                   REPORT_CALL_ERROR("E19999", "Call rtMalloc failed, size:%zu, ret:0x%X", ext_info.size(), rt_ret);
@@ -106,6 +116,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin
   // 1. Copy context from kernelExDef.private to workspace
   uint32_t op_index = kernel_ex_def.op_index();
   OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index);
+  op_desc_ = op_desc;
   if (op_desc == nullptr) {
     REPORT_INNER_ERROR("E19999", "Can't get op_desc from davinci_model by index:%u", op_index);
     GELOGE(INTERNAL_ERROR, "[Get][Op] by index failed, index:%u is out of range!", op_index);
@@ -420,9 +431,9 @@ Status KernelExTaskInfo::Distribute() {
   // xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY
   // xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST
   if (topic_type_flag_ > 0) {
-    dump_flag_ = dump_flag_ | topic_type_flag_;
+    dump_flag_ = dump_flag_ | static_cast<uint32_t>(topic_type_flag_);
   }
-  rtError_t rt_ret = rtKernelLaunchEx(kernel_buf_, kernel_buf_size_, dump_flag_, stream_);
+  rtError_t rt_ret = rtKernelLaunchFwk(op_desc_->GetName().c_str(), kernel_buf_, kernel_buf_size_, dump_flag_, stream_);
   if (rt_ret != RT_ERROR_NONE) {
     REPORT_CALL_ERROR("E19999", "Call rtKernelLaunchFwk failed, ret:0x%X", rt_ret);
     GELOGE(RT_FAILED, "[Call][RtKernelLaunchFwk] failed, ret:0x%X", rt_ret);
@@ -447,6 +458,101 @@ Status KernelExTaskInfo::Distribute() {
   stream_id_ = stream_id;
   GELOGI("KernelExTaskInfo Distribute Success.
task id: %u, stream id: %u", task_id_, stream_id_); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp() != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } + return SUCCESS; +} + +Status KernelExTaskInfo::CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support) { + int32_t device_id = 0; + auto rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDevice failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDevice] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + int32_t value = 0; + rt_ret = rtGetDeviceCapability(device_id, FEATURE_TYPE_BLOCKING_OPERATOR, RT_MODULE_TYPE_AICPU, &value); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDeviceCapability failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDeviceCapability] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + if (value != RT_AICPU_BLOCKING_OP_NOT_SUPPORT && value != RT_AICPU_BLOCKING_OP_SUPPORT) { + REPORT_INNER_ERROR("E19999", "Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + GELOGE(FAILED, "[Check][Value] Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + return FAILED; + } + is_support = (value == RT_AICPU_BLOCKING_OP_SUPPORT ? true : false); + return SUCCESS; +} + +Status KernelExTaskInfo::UpdateEventIdForAicpuBlockingOp(const OpDescPtr &op_desc, + std::shared_ptr &ext_handle) { + if (is_blocking_aicpu_op_) { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process"); + return SUCCESS; + } + uint32_t event_id = 0; + if (davinci_model_->GetEventIdForBlockingAicpuOp(op_desc, stream_, event_id) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get event id failed for op:%s(%s).", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + GELOGE(FAILED, "[Get][EventId] Get event id failed for op:%s(%s)", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + return FAILED; + } + if (ext_handle->UpdateEventId(event_id) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update event id failed for op:%s(%s).", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + GELOGE(FAILED, "[Update][EventId] Update event id failed for op:%s(%s)", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + return FAILED; + } + GELOGI("Update event_id=%u success", event_id); + } + return SUCCESS; +} + +Status KernelExTaskInfo::DistributeWaitTaskForAicpuBlockingOp() { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process."); + return SUCCESS; + } + GELOGD("Distribute wait task begin"); + rtEvent_t rt_event = nullptr; + if (davinci_model_->GetEventByStream(stream_, rt_event) != SUCCESS) { + GELOGE(FAILED, "[Call][GetEventByStream] Call GetEventByStream failed"); + return FAILED; + } + auto rt_ret = rtStreamWaitEvent(stream_, rt_event); + if 
(rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtStreamWaitEvent failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtEventReset(rt_event, stream_); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventReset failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } return SUCCESS; } diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h index 1b77b715..eb411576 100644 --- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h +++ b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h @@ -19,6 +19,7 @@ #include "graph/load/model_manager/task_info/task_info.h" #include "graph/op_desc.h" +#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" namespace ge { class KernelExTaskInfo : public TaskInfo { @@ -65,11 +66,18 @@ class KernelExTaskInfo : public TaskInfo { void InitDumpArgs(void *addr, const OpDescPtr &op_desc); Status InitTaskExtInfo(const std::string &ext_info, const OpDescPtr &op_desc); + // for blocking aicpu op + Status DistributeWaitTaskForAicpuBlockingOp(); + Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support); + Status UpdateEventIdForAicpuBlockingOp(const OpDescPtr &op_desc, + std::shared_ptr &ext_handle); + uint32_t task_id_; uint32_t stream_id_; uint32_t dump_flag_; uint32_t kernel_buf_size_; DavinciModel *davinci_model_; + OpDescPtr op_desc_; void *kernel_buf_; void *input_output_addr_; void *ext_info_addr_; @@ -78,6 +86,7 @@ class KernelExTaskInfo : public TaskInfo { uint32_t args_offset_ = 0; int64_t fixed_addr_offset_ = 0; int32_t topic_type_flag_ = -1; + bool is_blocking_aicpu_op_ = false; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_task_info.cc index 919a56cd..6bbfe58e 100755 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_task_info.cc @@ -28,11 +28,10 @@ #include "graph/load/model_manager/davinci_model.h" #include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/model_utils.h" -#include "runtime/kernel.h" -#include "super_kernel/super_kernel.h" -#include "super_kernel/super_kernel_factory.h" +#include "runtime/rt.h" +#include "graph/load/model_manager/task_info/super_kernel/super_kernel.h" +#include "graph/load/model_manager/task_info/super_kernel/super_kernel_factory.h" #include "cce/aicpu_engine_struct.h" -#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" #include "framework/common/debug/log.h" namespace { @@ -436,13 +435,14 @@ Status KernelTaskInfo::Distribute() { // xxxxxxxx xxxxxxxx xxxxxxxx xx01xxxx: DEVICE_FIRST // xxxxxxxx xxxxxxxx xxxxxxxx xx10xxxx: HOST_ONLY // xxxxxxxx xxxxxxxx xxxxxxxx xx11xxxx: HOST_FIRST - dump_flag_ = dump_flag_ | topic_type_flag_; + dump_flag_ = dump_flag_ | static_cast(topic_type_flag_); } GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); // blockDim is reserved parameter, set to 1 - rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), - reinterpret_cast(kernel_name_.c_str()), 1, args_, args_size_, - nullptr, stream_, dump_flag_); + std::string op_name = op_desc_->GetName(); + rtKernelLaunchNames_t 
launch_name = {so_name_.c_str(), kernel_name_.c_str(), op_name.c_str()}; + rt_ret = rtAicpuKernelLaunchWithFlag(&launch_name, 1, args_, args_size_, + nullptr, stream_, dump_flag_); call_save_dump_ = true; } else { /* default: not skt launch */ @@ -473,6 +473,12 @@ Status KernelTaskInfo::Distribute() { } // set for task_id_ UpdateTaskId(); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp() != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } GELOGD( "KernelTaskInfo Distribute Success. sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " "blockdim:%d stream:%p", @@ -481,6 +487,91 @@ Status KernelTaskInfo::Distribute() { return SUCCESS; } +Status KernelTaskInfo::CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support) { + int32_t device_id = 0; + auto rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDevice failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDevice] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + int32_t value = 0; + rt_ret = rtGetDeviceCapability(device_id, FEATURE_TYPE_BLOCKING_OPERATOR, RT_MODULE_TYPE_AICPU, &value); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDeviceCapability failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDeviceCapability] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + if (value != RT_AICPU_BLOCKING_OP_NOT_SUPPORT && value != RT_AICPU_BLOCKING_OP_SUPPORT) { + REPORT_INNER_ERROR("E19999", "Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + GELOGE(FAILED, "[Check][Value] Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + return FAILED; + } + is_support = (value == RT_AICPU_BLOCKING_OP_SUPPORT ? 
true : false); + return SUCCESS; +} + +Status KernelTaskInfo::UpdateEventIdForAicpuBlockingOp(std::shared_ptr &ext_handle) { + if (is_blocking_aicpu_op_) { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process"); + return SUCCESS; + } + uint32_t event_id = 0; + if (davinci_model_->GetEventIdForBlockingAicpuOp(op_desc_, stream_, event_id) != SUCCESS) { + GELOGE(FAILED, "[Get][EventId] Get event id failed for op:%s(%s)", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str()); + return FAILED; + } + if (ext_handle->UpdateEventId(event_id) != SUCCESS) { + GELOGE(FAILED, "[Update][EventId] Update event id failed for op:%s(%s)", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str()); + return FAILED; + } + GELOGI("Update event_id=%u success", event_id); + } + return SUCCESS; +} + +Status KernelTaskInfo::DistributeWaitTaskForAicpuBlockingOp() { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("device not support blocking aicpu op process."); + return SUCCESS; + } + GELOGD("Distribute wait task begin"); + rtEvent_t rt_event = nullptr; + if (davinci_model_->GetEventByStream(stream_, rt_event) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Call GetEventByStream failed"); + GELOGE(FAILED, "[Call][GetEventByStream] Call GetEventByStream failed"); + return FAILED; + } + auto rt_ret = rtStreamWaitEvent(stream_, rt_event); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtStreamWaitEvent failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtEventReset(rt_event, stream_); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventReset failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + return SUCCESS; +} + void KernelTaskInfo::SetIoAddrs(const OpDescPtr &op_desc) { const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); @@ -645,6 +736,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne GE_CHECK_NOTNULL(op_desc); args_addr = std::unique_ptr(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_addr); errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); if (sec_ret != EOK) { REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); @@ -1000,6 +1092,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k // copy args to new host memory args_addr = std::unique_ptr(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_addr); GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); if (sec_ret != EOK) { @@ -1106,7 +1199,7 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { UnknowShapeOpType unknown_type = 
static_cast(unknown_shape_type_val); uint32_t num_inputs = op_desc_->GetInputsSize(); uint32_t num_outputs = op_desc_->GetOutputsSize(); - std::unique_ptr ext_handle( + std::shared_ptr ext_handle( new(std::nothrow) ::ge::hybrid::AicpuExtInfoHandler(op_desc_->GetName(), num_inputs, num_outputs, @@ -1142,6 +1235,16 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { j, op_desc_->GetName().c_str()); } } + + AttrUtils::GetBool(op_desc_, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc_->GetName().c_str(), is_blocking_aicpu_op_); + + if (UpdateEventIdForAicpuBlockingOp(ext_handle) != SUCCESS) { + GELOGE(FAILED, "[Call][UpdateEventIdForAicpuBlockingOp] failed for op:%s(%s)", + op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); + return FAILED; + } + auto rt_ret = rtMalloc(&aicpu_ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "Call rtMalloc failed for op:%s(%s), size:%zu, ret:0x%X", diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.h b/ge/graph/load/model_manager/task_info/kernel_task_info.h index d9dd30bb..59a91aee 100644 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.h +++ b/ge/graph/load/model_manager/task_info/kernel_task_info.h @@ -24,6 +24,8 @@ #include "graph/load/model_manager/task_info/task_info.h" #include "graph/op_desc.h" +#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" + namespace ge { class KernelTaskInfo : public TaskInfo { public: @@ -148,6 +150,11 @@ class KernelTaskInfo : public TaskInfo { bool DoubleCallSKTSaveCheck(); void SetArgs(); + // for blocking aicpu op + Status DistributeWaitTaskForAicpuBlockingOp(); + Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support); + Status UpdateEventIdForAicpuBlockingOp(std::shared_ptr &ext_handle); + void *stub_func_; void *args_; void *sm_desc_; @@ -187,6 +194,7 @@ class KernelTaskInfo : public TaskInfo { uint32_t skt_dump_flag_ = RT_KERNEL_DEFAULT; void *superkernel_device_args_addr_ = nullptr; void *superkernel_dev_nav_table_ = nullptr; + bool is_blocking_aicpu_op_ = false; struct AICPUCustomInfo { void *input_descs = nullptr; diff --git a/ge/graph/load/model_manager/task_info/memcpy_async_task_info.h b/ge/graph/load/model_manager/task_info/memcpy_async_task_info.h index 728305ff..4ae03967 100755 --- a/ge/graph/load/model_manager/task_info/memcpy_async_task_info.h +++ b/ge/graph/load/model_manager/task_info/memcpy_async_task_info.h @@ -47,7 +47,7 @@ class MemcpyAsyncTaskInfo : public TaskInfo { uint64_t count_; uint32_t kind_; vector io_addrs_; - int64_t fixed_addr_offset_; + int64_t fixed_addr_offset_ = 0; DavinciModel *davinci_model_ = nullptr; uint32_t args_offset_ = 0; }; diff --git a/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc b/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc index 44aac465..b5db845d 100644 --- a/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc +++ b/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "super_kernel.h" +#include "graph/load/model_manager/task_info/super_kernel/super_kernel.h" #include "framework/common/debug/ge_log.h" namespace ge { diff --git a/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc b/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc index 07dc5d19..d1f53cc4 100644 --- a/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc +++ b/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "super_kernel_factory.h" +#include "graph/load/model_manager/task_info/super_kernel/super_kernel_factory.h" #include "framework/common/debug/ge_log.h" namespace ge { diff --git a/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.h b/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.h index c5058b6a..741d1c13 100644 --- a/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.h +++ b/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.h @@ -18,7 +18,7 @@ #define SUPER_KERNEL_FACTORY_H #include -#include "super_kernel.h" +#include "graph/load/model_manager/task_info/super_kernel/super_kernel.h" #include "framework/common/debug/log.h" namespace ge { diff --git a/ge/graph/load/model_manager/tbe_handle_store.cc b/ge/graph/load/model_manager/tbe_handle_store.cc index 36207aa2..d20b1bbf 100755 --- a/ge/graph/load/model_manager/tbe_handle_store.cc +++ b/ge/graph/load/model_manager/tbe_handle_store.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "tbe_handle_store.h" +#include "graph/load/model_manager/tbe_handle_store.h" #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "runtime/kernel.h" diff --git a/ge/graph/load/model_manager/tbe_handle_store.h b/ge/graph/load/model_manager/tbe_handle_store.h index 6c3ad750..ba934fc3 100644 --- a/ge/graph/load/model_manager/tbe_handle_store.h +++ b/ge/graph/load/model_manager/tbe_handle_store.h @@ -25,7 +25,7 @@ #include #include -#include "common/fmk_types.h" +#include "framework/common/fmk_types.h" #include "graph/op_kernel_bin.h" namespace ge { diff --git a/ge/graph/load/model_manager/zero_copy_offset.cc b/ge/graph/load/model_manager/zero_copy_offset.cc index 4a57a899..2a0423c7 100644 --- a/ge/graph/load/model_manager/zero_copy_offset.cc +++ b/ge/graph/load/model_manager/zero_copy_offset.cc @@ -62,7 +62,8 @@ Status ZeroCopyOffset::InitInputDataInfo(int64_t output_size, void *virtual_addr for (size_t index = 0; index < zero_copy_basic_offset_.size(); ++index) { if (zero_copy_basic_offset_.at(index) == virtual_addr_offset) { out_count++; - uint64_t out_offset = reinterpret_cast(virtual_addr) + zero_copy_relative_offset_.at(index); + uint64_t out_offset = static_cast(reinterpret_cast(virtual_addr)) + + zero_copy_relative_offset_.at(index); data_info_.emplace_back(output_size, reinterpret_cast(static_cast(out_offset))); relative_offset_.emplace_back(zero_copy_relative_offset_.at(index)); GELOGI("[ZCPY] virtual_addr: %p has been l2-fusion to %lu, need copy data_size is %ld.", basic_addr_, @@ -117,7 +118,8 @@ Status ZeroCopyOffset::InitOutputDataInfo(const vector &input_size_list for (size_t index = 0; index < zero_copy_basic_offset_.size(); ++index) { if (zero_copy_basic_offset_.at(index) == virtual_addr_offset) { 
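The two `zero_copy_offset.cc` hunks around this point make the same change: rather than `reinterpret_cast`-ing a pointer directly to `uint64_t`, they convert through `uintptr_t` first, which keeps the pointer-to-integer step well-defined on any pointer width and makes the widening explicit. A tiny standalone illustration of the round-trip:

```
#include <cstdint>
#include <cstdio>

int main() {
  int value = 42;
  void *virtual_addr = &value;

  // Pointer -> uintptr_t is always valid; widening to uint64_t is then
  // an ordinary integer conversion, as in the hunks above and below.
  const uint64_t out_offset =
      static_cast<uint64_t>(reinterpret_cast<uintptr_t>(virtual_addr)) + 16u;

  // And back: narrow to uintptr_t before reinterpreting as a pointer.
  void *addr = reinterpret_cast<void *>(static_cast<uintptr_t>(out_offset));
  std::printf("%p\n", addr);
  return 0;
}
```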
in_count++; - uint64_t in_offset = reinterpret_cast(virtual_addr_list[idx]) + zero_copy_relative_offset_.at(index); + uint64_t in_offset = static_cast(reinterpret_cast(virtual_addr_list[idx])) + + zero_copy_relative_offset_.at(index); int64_t real_data_size = ModelUtils::GetInputSize(op_desc).at(idx); data_info_.emplace_back(real_data_size, reinterpret_cast(static_cast(in_offset))); relative_offset_.emplace_back(zero_copy_relative_offset_.at(index)); diff --git a/ge/graph/load/model_manager/zero_copy_offset.h b/ge/graph/load/model_manager/zero_copy_offset.h index 2dea5666..f3dd07a8 100644 --- a/ge/graph/load/model_manager/zero_copy_offset.h +++ b/ge/graph/load/model_manager/zero_copy_offset.h @@ -29,7 +29,7 @@ #include "graph/utils/attr_utils.h" #include "graph/utils/tensor_utils.h" #include "runtime/mem.h" -#include "task_info/task_info.h" +#include "graph/load/model_manager/task_info/task_info.h" using std::map; using std::set; diff --git a/ge/graph/load/model_manager/zero_copy_task.cc b/ge/graph/load/model_manager/zero_copy_task.cc index 4957f8ea..85be6d7b 100755 --- a/ge/graph/load/model_manager/zero_copy_task.cc +++ b/ge/graph/load/model_manager/zero_copy_task.cc @@ -19,7 +19,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" #include "graph/load/model_manager/model_utils.h" -#include "common/ge_compiler_options.h" +#include "framework/common/ge_compiler_options.h" namespace ge { ZeroCopyTask::ZeroCopyTask(const string &name, uint8_t *args, size_t size) diff --git a/ge/graph/manager/graph_caching_allocator.cc b/ge/graph/manager/graph_caching_allocator.cc index 82bfbda9..7b316fc3 100644 --- a/ge/graph/manager/graph_caching_allocator.cc +++ b/ge/graph/manager/graph_caching_allocator.cc @@ -20,7 +20,6 @@ #include #include -#include "framework/common/debug/ge_log.h" #include "graph/manager/graph_mem_manager.h" namespace ge { @@ -94,7 +93,8 @@ void IncreaseCount(std::map &count, size_t size) { } } -CachingAllocator::CachingAllocator(rtMemType_t memory_type) : memory_type_(memory_type), memory_allocator_(nullptr) { +CachingAllocator::CachingAllocator(rtMemType_t memory_type) + : memory_type_(memory_type), memory_allocator_(nullptr), called_malloc_counts_(0), called_free_counts_(0) { for (uint32_t i = 0; i < kNumBins; i++) { free_block_bins_[i] = nullptr; } @@ -121,6 +121,8 @@ Status CachingAllocator::Initialize(uint32_t device_id) { if (memory_allocator_ == nullptr) { return ACL_ERROR_GE_INTERNAL_ERROR; } + called_malloc_counts_ = 0; + called_free_counts_ = 0; return ge::SUCCESS; } @@ -133,6 +135,7 @@ void CachingAllocator::Finalize(uint32_t device_id) { uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { GELOGI("Start malloc pool memory, size = %zu, device id = %u", size, device_id); + called_malloc_counts_++; size = GetBlockSize(size); uint8_t *ptr = nullptr; Block *block = FindFreeBlock(size, org_ptr, device_id); @@ -156,6 +159,7 @@ uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) { GELOGI("Free device id = %u", device_id); + called_free_counts_++; if (ptr == nullptr) { REPORT_INNER_ERROR("E19999", "Param ptr is nullptr, device_id:%u, check invalid", device_id); GELOGE(PARAM_INVALID, "[Check][Param] Invalid memory pointer, device_id:%u", device_id); @@ -283,6 +287,7 @@ Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { if (memory_addr == nullptr) { GELOGE(ge::FAILED, "[Malloc][Memory] failed, no enough 
memory for size = %zu, device_id = %u", memory_size, device_id);
+    PrintStatics(DLOG_ERROR);
     return ge::FAILED;
   }
   GELOGT(TRACE_RUNNING, "Try to free cached memory size:%zu and malloc memory size:%zu success.",
@@ -385,14 +390,14 @@ void CachingAllocator::FreeBlockBins() {
 }
 
 void PrintCount(std::map<size_t, size_t> &count, const std::string &name, size_t total_size, size_t total_count) {
-  GELOGI("%6s total[size:%10zu count:%10zu].", name.c_str(), total_size, total_count);
+  GEEVENT("%6s total[size:%11zu count:%11zu].", name.c_str(), total_size, total_count);
   for (auto &it : count) {
-    GELOGI("  |- block[size:%10zu count:%10zu].", it.first, it.second);
+    GEEVENT("  |- block[size:%11zu count:%11zu].", it.first, it.second);
   }
 }
 
-void CachingAllocator::PrintStatics() {
-  if (!IsLogEnable(GE_MODULE_NAME, DLOG_INFO)) {
+void CachingAllocator::PrintStatics(int32_t level) {
+  if (!IsLogEnable(GE_MODULE_NAME, level)) {
     return;
   }
   size_t total_using_size = 0;
@@ -435,6 +440,7 @@ void CachingAllocator::PrintStatics() {
     }
   } while (0);
 
+  GEEVENT("Called counts[malloc:%11zu free:%11zu].", called_malloc_counts_.load(), called_free_counts_.load());
   PrintCount(malloc_block_stat, "Malloc", total_malloc_size, total_malloc_count);
   PrintCount(using_block_stat, "Using", total_using_size, total_using_count);
   PrintCount(free_block_stat, "Free", total_free_size, total_free_count);
diff --git a/ge/graph/manager/graph_caching_allocator.h b/ge/graph/manager/graph_caching_allocator.h
index 2db00ff2..d00858f3 100644
--- a/ge/graph/manager/graph_caching_allocator.h
+++ b/ge/graph/manager/graph_caching_allocator.h
@@ -27,6 +27,7 @@
 #include
 #include
 
+#include "framework/common/debug/ge_log.h"
 #include "framework/common/ge_inner_error_codes.h"
 #include "graph/node.h"
 #include "graph/manager/block_memory.h"
@@ -192,9 +193,10 @@ class CachingAllocator {
   ///
   /// @ingroup ge_graph
   /// @brief print the memory info in pool
+  /// @param [in] level  log level of the printout
   /// @return void
   ///
-  void PrintStatics();
+  void PrintStatics(int32_t level = DLOG_INFO);
 
  private:
   rtMemType_t memory_type_;
@@ -213,6 +215,12 @@ class CachingAllocator {
 
   // malloced memorys from device
   std::map malloced_memory_;
+
+  // total number of user calls to Malloc
+  std::atomic<size_t> called_malloc_counts_;
+
+  // total number of user calls to Free
+  std::atomic<size_t> called_free_counts_;
 };
 }  // namespace ge
 #endif  // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_
diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc
index bf04ed58..fa140bfe 100755
--- a/ge/graph/manager/graph_manager.cc
+++ b/ge/graph/manager/graph_manager.cc
@@ -27,10 +27,11 @@
 #include "common/math/math_util.h"
 #include "common/thread_pool.h"
 #include "common/dump/dump_manager.h"
+#include "ge_opt_info/ge_opt_info.h"
 #include "analyzer/analyzer.h"
-#include "graph/common/ge_call_wrapper.h"
-#include "graph/common/local_context.h"
-#include "graph/common/transop_util.h"
+#include "common/ge_call_wrapper.h"
+#include "common/local_context.h"
+#include "common/transop_util.h"
 #include "graph/ge_context.h"
 #include "graph/ge_global_options.h"
 #include "graph/manager/util/rt_context_util.h"
@@ -102,12 +103,13 @@
 #include "inc/pass_manager.h"
 #include "init/gelib.h"
 #include "ir_build/option_utils.h"
-#include "graph/common/local_context.h"
-#include "graph/common/omg_util.h"
+#include "common/local_context.h"
+#include "common/omg_util.h"
 #include "common/formats/utils/formats_trans_utils.h"
 #include "register/custom_pass_helper.h"
 #include "external/graph/types.h"
 #include "common/util/error_manager/error_manager.h"
+#include
"common/profiling/profiling_manager.h" namespace { const char *const kSummary = "Summary"; @@ -122,13 +124,12 @@ const char *const kVectorEngine = "VectorEngine"; const char *const kAIcoreEngine = "AIcoreEngine"; const int32_t kDynamicDimsTypeIsGetNext = 0; const int32_t kDynamicDimsTypeIsData = 1; +const int32_t kBase = 10; const char *const kGetNextName = "IteratorV2"; const uint32_t kInitGraphCount = 1; const uint32_t kNotAdded = 0; const uint32_t kStartAdd = 1; const uint32_t kDoneAdded = 2; -const uint32_t kNeverLoaded = 0; -const size_t kAlignment = 64; bool IsTailingOptimization() { string is_tailing_optimization_option; @@ -162,26 +163,12 @@ ge::Status CheckFpCeilingMode() { } // namespace namespace ge { -GraphManager::GraphManager() - : thread_run_flag_(false), - graph_run_listener_(nullptr), - init_flag_(false) { -} - -Status GraphManager::Initialize(const std::map &options) { +Status GraphManager::Initialize(const std::map &options, Executor *executor) { ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); if (init_flag_) { GELOGW("[Initialize] GraphManager already initialized."); return SUCCESS; } - - // malloc - graph_run_listener_ = MakeShared(sync_run_mutex_, condition_); - if (graph_run_listener_ == nullptr) { - REPORT_CALL_ERROR("E19999", "New GraphModelListener fail"); - GELOGE(MEMALLOC_FAILED, "[New][GraphModelListener] failed"); - return MEMALLOC_FAILED; - } // graph context graph_context_ = MakeShared(); if (graph_context_ == nullptr) { @@ -209,31 +196,18 @@ Status GraphManager::Initialize(const std::map &options) { return ret; } - graph_map_.clear(); - cache_helper_map_.clear(); - graph_id_to_add_graph_cond_.clear(); - graph_count_.clear(); + executor_ = executor; init_flag_ = true; thread_run_flag_ = true; - prerun_thread_ = std::thread(GraphManager::PreRunThread, this); - run_thread_ = std::thread(GraphManager::RunThread, this); + prerun_thread_ = std::thread(&GraphManager::PreRunThread, this); return SUCCESS; } Status GraphManager::UnloadModel(GeRootModelPtr ge_root_model, uint32_t graph_id) { - Status ret = SUCCESS; - for (size_t i = 0; i < ge_root_model->GetAllModelId().size(); ++i) { - uint32_t model_id = ge_root_model->GetAllModelId()[i]; - GELOGI("Unload model %u.", model_id); - ret = GraphLoader::UnloadModel(model_id); - if (ret != SUCCESS) { - GELOGW("[GraphManager] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); - return ret; - } - } - return ret; + GE_CHECK_NOTNULL(executor_); + return executor_->UnloadGraph(ge_root_model, graph_id); } Status GraphManager::Finalize() { @@ -242,23 +216,13 @@ Status GraphManager::Finalize() { return SUCCESS; } - if (graph_executor_.FreeExecuteMemory() != SUCCESS) { - GELOGW("Graph executor FreeExecuteMemory failed, resources may not be released correctly."); - } - - StopQueue(this); - + StopQueue(); if (prerun_thread_.joinable()) { prerun_thread_.join(); } - if (run_thread_.joinable()) { - run_thread_.join(); - } // check graph whether running or not Status unload_model_ret = SUCCESS; - Status ret; - rtError_t rt_ret; for (auto iter = graph_map_.begin(); iter != graph_map_.end(); ++iter) { GraphNodePtr graph_node = iter->second; if (graph_node->GetRunFlag()) { @@ -269,22 +233,10 @@ Status GraphManager::Finalize() { // unload model auto ge_root_model = graph_node->GetGeRootModel(); if (ge_root_model != nullptr && ge_root_model->GetModelId() != INVALID_MODEL_ID && graph_node->GetLoadFlag()) { - rt_ret = rtSetDevice(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) 
{ - GELOGW("[GraphManager] rtSetDevice failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), iter->first); - unload_model_ret = FAILED; - continue; - } - ret = UnloadModel(ge_root_model, iter->first); + Status ret = UnloadModel(ge_root_model, iter->first); if (ret != SUCCESS) { - GELOGW("[GraphManager] unload model failed, graph_id=%u.", iter->first); unload_model_ret = ret; - } - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - GELOGW("[GraphManager] rtDeviceReset failed, graphId=%u.", iter->first); - unload_model_ret = FAILED; - continue; + GELOGW("[GraphManager] unload model failed, graph_id=%u.", iter->first); } } @@ -296,7 +248,6 @@ Status GraphManager::Finalize() { Analyzer::GetInstance()->DestroyGraphJsonObject(session_id, graph_id); } graph_map_.clear(); - cache_helper_map_.clear(); graph_count_.clear(); // graph context @@ -511,6 +462,9 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, const std::map &options, const OmgContext &omg_context) { IncreaseGraphCount(graph_id); + auto device_id = GetContext().DeviceId(); + GELOGD("Device id is %u", device_id); + ProfilingManager::Instance().SetGraphIdToDeviceMap(graph_id, device_id); // validation for adding graphs of same graph_id in multi-thread secenario // 1.previous thread owns same graph_id has finished the AddGraph procession if (GetAddGraphCondition(graph_id) == kDoneAdded) { @@ -949,7 +903,7 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtCtxCreate faileded, session_id:%lu, graph_id:%u, mode:%d", + REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode); GELOGE(FAILED, "[Call][RtCtxCreate] faileded, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode); return FAILED; @@ -1001,6 +955,12 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorGetGraphId(), compute_graph, ge_model); - if (save_ret != SUCCESS) { - GELOGW("Fail to save cache."); - } GEEVENT("[GEPERFTRACE] GE PreRun End"); return SUCCESS; } @@ -1102,24 +1055,16 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: graph_node->GetGraphId()); return PARAM_INVALID; } - GeModelPtr ge_model = nullptr; - // check need incre build. 
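With the IncreBuild cache path removed in the hunk below, build-and-load collapses to one unconditional sequence. A compilable sketch of the resulting flow, using stand-in types and placeholder bodies (PreRun/LoadGraph here are hypothetical simplifications, not the real implementations):

```cpp
#include <cstdio>

enum Status { SUCCESS = 0, FAILED = 1 };

static Status PreRun()    { return SUCCESS; }  // prepare, optimize and build the graph
static Status LoadGraph() { return SUCCESS; }  // hand the built model to the executor

static Status StartForRunGraphSketch() {
  Status ret = PreRun();
  // release rts generate context (unconditionally, as in the patch)
  if (ret != SUCCESS) {
    return ret;  // PreRun failed
  }
  return LoadGraph();  // one loader now serves both sync and async callers
}

int main() {
  std::printf("status = %d\n", StartForRunGraphSketch());
  return 0;
}
```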
- ret = IncreBuild(graph_node, ge_model); + + ret = PreRun(graph_node, inputs, ge_root_model, session_id); + // release rts generate context + RtContextUtil::GetInstance().DestroyRtContexts(session_id, graph_node->GetGraphId()); if (ret != SUCCESS) { - ret = PreRun(graph_node, inputs, ge_root_model, session_id); - // release rts generate context - RtContextUtil::GetInstance().DestroyRtContexts(session_id, graph_node->GetGraphId()); - if (ret != SUCCESS) { - GELOGE(ret, "[Call][PreRun] Failed, graph_id:%u, session_id:%lu.", graph_node->GetGraphId(), session_id); - return ret; - } - } - ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); - if (!graph_node->IsAsync()) { - ret = LoadGraph(ge_root_model, graph_node); - } else { - ret = LoadGraphAsync(ge_root_model, graph_node); + GELOGE(ret, "[Call][PreRun] Failed, graph_id:%u, session_id:%lu.", graph_node->GetGraphId(), session_id); + return ret; } + + ret = LoadGraph(ge_root_model, graph_node); if (ret != SUCCESS) { GELOGE(ret, "[Load][Graph] Failed, graph_id:%u.", graph_node->GetGraphId()); return ret; @@ -1127,13 +1072,8 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: graph_node->SetBuildFlag(true); var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); } else if (!graph_node->GetLoadFlag()) { - ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); GeRootModelPtr ge_root_model_ptr = graph_node->GetGeRootModel(); - if (!graph_node->IsAsync()) { - ret = LoadGraph(ge_root_model_ptr, graph_node); - } else { - ret = LoadGraphAsync(ge_root_model_ptr, graph_node); - } + ret = LoadGraph(ge_root_model, graph_node); if (ret != SUCCESS) { GELOGE(ret, "[Load][Graph] Failed, graph_id:%u.", graph_node->GetGraphId()); return ret; @@ -1141,168 +1081,28 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: } return ret; } + Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { GELOGI("[LoadGraph] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); - if (options_.run_graph_flag && ge_root_model != nullptr) { - ge_root_model->SetTrainFlag(GetTrainFlag()); - // synchronization run graph with model - std::shared_ptr model_listener = GetModelListener(); - ModelIdInfo model_id_info; - bool is_unknown_shape = false; - GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape)); - if (!is_unknown_shape) { - if (getenv(kEnvGeuseStaticMemory) != nullptr) { - GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is seted."); - } else { - auto root_graph = ge_root_model->GetRootGraph(); - GE_CHECK_NOTNULL(root_graph); - auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); - GeModelPtr ge_model = name_to_model[root_graph->GetName()]; - GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); - } - } - ge_root_model->SetIsSpecificStream(graph_node->IsSpecificStream()); - GE_TIMESTAMP_START(LoadGraph); - Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, model_listener); - GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraph"); - if (ret != SUCCESS) { - GELOGE(ret, "[Load][Model] failed, ret:%d", ret); - graph_node->SetRunFlag(false); - return ret; - } - graph_node->SetLoadFlag(true); - ge_root_model->SetModelId(model_id_info.model_id); - graph_node->SetGeRootModel(ge_root_model); - } - return SUCCESS; -} - -Status GraphManager::LoadFromCache(const GraphNodePtr &graph_node, const 
ModelCacheHelperPtr &cache_helper, - GeModelPtr &ge_model) { - auto graph_id = graph_node->GetGraphId(); - auto ret = cache_helper->LoadOmModelFromCache(ge_model); - if (ret != SUCCESS) { - GELOGW("Fail to load om model from cache."); - if (cache_helper->ClearCache(graph_id) != SUCCESS) { - GELOGW("Fail to clear cache of graph %u.", graph_id); - } - return FAILED; - } - ret = cache_helper->RecoverVarManagerFromCache(); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarManager from cache."); - if (cache_helper->ClearCache(graph_id) != SUCCESS) { - GELOGW("Fail to clear cache of graph %u.", graph_id); - } - return FAILED; - } - ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - if (compute_graph_in_model == nullptr) { - GELOGW("Error occurred when get compute graph from om, abandon."); - return FAILED; - } else { - graph_node->SetComputeGraph(compute_graph_in_model); - graph_node->SetGeModel(ge_model); - GELOGI("Load model and graph form cache om file."); - } - return SUCCESS; -} - -Status GraphManager::SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper) { - auto ret = cache_helper->SaveCacheInfoToCache(); - if (ret != SUCCESS) { - GELOGW("Fail to save cache info of graph[%d] to cache.", graph_id); - return FAILED; - } - ret = cache_helper->SaveVarManagerToCache(true); - if (ret != SUCCESS) { - GELOGW("Fail to save var manager to cache."); - cache_helper->ClearCache(graph_id); - return FAILED; - } - GELOGI("Cache files have been saved."); - return SUCCESS; -} - -Status GraphManager::SaveCacheAfterBuild(uint32_t graph_id, ge::ComputeGraphPtr graph, GeModelPtr &ge_model) { - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if ((instance_ptr == nullptr) || !instance_ptr->InitFlag()) { - GELOGW("GELib not initialized."); - return FAILED; + if (!options_.run_graph_flag) { + return SUCCESS; } - if (instance_ptr->IsIncreBuild()) { - std::lock_guard lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter == cache_helper_map_.end()) { - GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); - return FAILED; - } else { - ModelCacheHelperPtr cache_helper = iter->second; - auto ret = cache_helper->RefreshComputeGraph(graph); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to refresh cache helper's compute graph"); - return FAILED; - } - ret = cache_helper->SaveVarManagerToCache(false); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to save VarManager to cache"); - return FAILED; - } - ret = cache_helper->SaveOmModelToCache(ge_model); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to save om model to cache"); - return FAILED; - } - } - } - return SUCCESS; + ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); + GE_CHECK_NOTNULL(executor_); + return executor_->LoadGraph(ge_root_model, graph_node); } Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector &inputs, std::vector &outputs) { - Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); - if (ret != SUCCESS) { - GELOGE(GE_GRAPH_RUNGRAPH_FAILED, "[Set][Condition] failed, graph_id = %u.", graph_id); - graph_node->SetRunFlag(false); - return GE_GRAPH_RUNGRAPH_FAILED; - } - - if (GetTrainFlag()) { - GE_CHK_STATUS_RET(graph_executor_.SetGraphContext(GetGraphContext())); - 
graph_executor_.SetTrainFlag(options_.train_graph_flag); - } - ret = graph_executor_.ExecuteGraph(graph_id, graph_node->GetGeRootModel(), inputs, outputs); - - graph_node->SetRunFlag(false); - if (ret != SUCCESS) { - GELOGE(ret, "[Execute][Graph] failed, graph_id = %u.", graph_id); - return ret; - } - return SUCCESS; + GE_CHECK_NOTNULL(executor_); + return executor_->RunGraph(graph_node, graph_id, inputs, outputs); } Status GraphManager::InnerRunGraphWithStream(GraphNodePtr &graph_node, const GraphId &graph_id, rtStream_t stream, const std::vector &inputs, std::vector &outputs) { - auto ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); - if (ret != SUCCESS) { - GELOGE(GE_GRAPH_RUNGRAPH_FAILED, "[Set][Condition] failed, graph id = %u, stream = %p.", graph_id, stream); - graph_node->SetRunFlag(false); - return GE_GRAPH_RUNGRAPH_FAILED; - } - - ret = graph_executor_.ExecuteGraphWithStream(graph_id, stream, graph_node->GetGeRootModel(), inputs, outputs); - graph_node->SetRunFlag(false); - graph_node->SetIsSpecificStream(false); - if (ret != SUCCESS) { - GELOGE(ret, "[Execute][Graph] With Stream failed, graph id = %u, stream = %p.", graph_id, stream); - return ret; - } - GELOGI("[Run][GraphWithStreamAsync] run graph success, graph id = %u, stream = %p.", graph_id, stream); - return SUCCESS; + GE_CHECK_NOTNULL(executor_); + return executor_->RunGraphWithStream(graph_node, graph_id, stream, inputs, outputs); } Status GraphManager::RunGraphWithStreamAsync(const GraphId &graph_id, rtStream_t stream, uint64_t session_id, @@ -1343,8 +1143,6 @@ Status GraphManager::RunGraphWithStreamAsync(const GraphId &graph_id, rtStream_t graph_node->SetIsSpecificStream(true); ComputeGraphPtr compute_graph_tmp = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); - // when set incre build, add cache helper map - AddModelCacheHelperToMap(graph_id, session_id, compute_graph_tmp); if (options_.local_fmk_op_flag) { GetCompilerStages(graph_id).optimizer.TranFrameOp(compute_graph_tmp); } @@ -1403,9 +1201,6 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vectorSetRunFlag(false); if (ret != SUCCESS) { - GELOGE(GE_GRAPH_PRERUN_FAILED, "[Call][StartForRunGraph] failed! graph_id:%u.", graph_id); - return GE_GRAPH_PRERUN_FAILED; + GELOGE(ret, "[Call][StartForRunGraph] failed! 
graph_id:%u.", graph_id); + return ret; } GELOGI("[BuildGraph] build graph success, graph_id=%u.", graph_id); @@ -1622,16 +1417,6 @@ Status GraphManager::SaveParams(ge::GeModel &model, const std::string &type, con return SUCCESS; } -void GraphManager::RemoveModelCacheHelper(const GraphId &graph_id) { - std::lock_guard lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter != cache_helper_map_.end()) { - cache_helper_map_.erase(iter); - } else { - GELOGW("[GraphManager] cache helper does not exist, graph_id = %u", graph_id); - } -} - bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load_flag) { return ((ge_root_model != nullptr) && (ge_root_model->GetModelId() != INVALID_MODEL_ID) && load_flag); } @@ -1657,38 +1442,17 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { std::lock_guard lock(unload_model_mutex_); - Status middle_ret; - rtError_t rt_ret; var_acc_ctrl_.RemoveGraph(graph_id); RemoveGraphNode(graph_id); - RemoveModelCacheHelper(graph_id); - auto ge_root_model = graph_node->GetGeRootModel(); if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { - rt_ret = rtSetDevice(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, graph_id:%u", - GetContext().DeviceId(), graph_id); - GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, modelId=%u, graphId=%u.", ge_root_model->GetModelId(), - graph_id); - return FAILED; - } - // same graph may be added for several times, different models were created separately, - // unload them respectively. - middle_ret = UnloadModel(ge_root_model, graph_id); + Status middle_ret = UnloadModel(ge_root_model, graph_id); if (middle_ret != SUCCESS) { REPORT_INNER_ERROR("E19999", "UnloadModel for graph:%u failed, check invalid", graph_id); GELOGE(middle_ret, "[Unload][Model] model failed, graph_id=%u.", graph_id); ret = middle_ret; } - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, graph_id:%u", - GetContext().DeviceId(), graph_id); - GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, graph_id:%u", GetContext().DeviceId(), graph_id); - ret = FAILED; - } } RemoveCompilerStages(graph_id); @@ -1788,7 +1552,7 @@ Status GraphManager::ParseOptions(const std::map &opti return GE_GRAPH_OPTIONS_INVALID); // ge.graphType - ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); + ret = ParseTrainGraphFlag(options_.train_graph_flag); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); return GE_GRAPH_OPTIONS_INVALID); @@ -1833,19 +1597,18 @@ Status GraphManager::ParseOptions(const std::map &opti return SUCCESS; } -Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { - std::shared_ptr ge_instance_ptr = ge::GELib::GetInstance(); - if (ge_instance_ptr == nullptr) { - GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); - train_flag = false; - } else if (!ge_instance_ptr->isTrainMode()) { - train_flag = false; - } else { // ge_instance_ptr->isTrainMode() is true - train_flag = true; - if (!run_flag) { - GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); +// OPTION_GRAPH_RUN_MODE is supposed to be a session-level option, but it used to be set to global-level in the past. 
+// If it cannot be parsed from the session options, fall back to the global options via GetContext().
+Status GraphManager::ParseTrainGraphFlag(bool &train_flag) {
+  train_flag = false;
+  string run_mode;
+  if (GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == SUCCESS && !run_mode.empty()) {
+    if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, kBase)) >= TRAIN) {
+      train_flag = true;
     }
   }
+  domi::GetContext().train_flag = train_flag;
+  GELOGI("Train flag is %d.", train_flag);
   return SUCCESS;
 }
@@ -2114,8 +1877,6 @@ Status GraphManager::SummaryHandle(const GraphId &graph_id, std::vector &outputs) {
   GELOGI("[GraphManager] CheckpointHandle, outputsSize=%zu.", outputs.size());
-  std::vector outputs_desc = graph_executor_.GetOutputsDesc();
-  GELOGI("[GraphManager] CheckpointHandle, outputsDescSize=%zu.", outputs_desc.size());
 
   std::map save_results;
   NodePtr netoutput = nullptr;
@@ -2780,160 +2541,6 @@ void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_gr
   }
 }
 
-Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
-  GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId());
-  if (options_.run_graph_flag && ge_root_model != nullptr) {
-    ge_root_model->SetTrainFlag(GetTrainFlag());
-    // synchronization run graph with model
-    ModelIdInfo model_id_info;
-    bool is_unknown_shape = false;
-    GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape));
-    if (!is_unknown_shape) {
-      if (getenv(kEnvGeuseStaticMemory) != nullptr) {
-        GELOGI("[LoadGraphAsync] GE_USE_STATIC_MEMORY is seted.");
-      } else {
-        auto root_graph = ge_root_model->GetRootGraph();
-        GE_CHECK_NOTNULL(root_graph);
-        auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
-        GeModelPtr ge_model = name_to_model[root_graph->GetName()];
-        GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node));
-      }
-    }
-    GE_TIMESTAMP_START(LoadGraph);
-    auto listener = MakeShared();
-    GE_CHECK_NOTNULL(listener);
-    Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, listener);
-    GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync");
-    if (ret != SUCCESS) {
-      GELOGE(ret, "[Load][ModelOnline] Failed, model_id:%u", model_id_info.model_id);
-      graph_node->SetRunFlag(false);
-      return ret;
-    }
-    graph_node->SetLoadFlag(true);
-    ge_root_model->SetModelId(model_id_info.model_id);
-    graph_node->SetGeRootModel(ge_root_model);
-  }
-  return SUCCESS;
-}
-
-void GraphManager::ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node,
-                                 const std::vector<uint32_t> &model_ids, uint32_t graph_id, uint64_t session_id) {
-  rtError_t rt_ret = rtSetDevice(GetContext().DeviceId());
-  if (rt_ret != RT_ERROR_NONE) {
-    REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u", GetContext().DeviceId());
-    GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id=%u.", GetContext().DeviceId());
-    return;
-  }
-  for (auto model_id : model_ids) {
-    uint64_t max_memory_size = 0;
-    Status result = GraphLoader::GetMaxUsedMemory(model_id, max_memory_size);
-    if (result != SUCCESS) {
-      continue;
-    }
-    GELOGI("CheckAndReleaseMemory try to UnloadGraph[%u], model[%u] which MaxUsedMemory[%lu].", graph_id, model_id,
-           max_memory_size);
-    if (model_ids.size() > 1) {
-      result = ge_model->GetSessionId(model_id, session_id);
-      if (result != SUCCESS) {
-        GELOGW("[GraphManager:] get session failed when dynamic memory, modelId=%u, graphId=%u.", model_id,
-               graph_id);
-        continue;
-      }
-    }
-    result =
GraphLoader::DestroyAicpuKernel(session_id, model_id, 0); - if (result != SUCCESS) { - GELOGW("[GraphManager:] destroy aicpu kernel failed when dynamic memory, modelId=%u, graphId=%u.", model_id, - graph_id); - } - result = GraphLoader::UnloadModel(model_id); - if (result != SUCCESS) { - GELOGW("[GraphManager:] unload model failed, modelId=%u, graphId=%u.", model_id, graph_id); - } - GELOGI("CheckAndReleaseMemory UnloadGraph[%u], model[%u] success.", graph_id, model_id); - } - graph_node->SetLoadFlag(false); - // Allow model to be loaded agagin without adding graph again - graph_node->SetLoadCount(graph_node->GetLoadRecord()); - graph_node->SetLoadRecord(kNeverLoaded); - GeRootModelPtr ge_root_model = graph_node->GetGeRootModel(); - if (ge_root_model == nullptr) { - GELOGW("ge_root_model is null, graph_id:%u", graph_id); - return; - } - ge_root_model->ClearAllModelId(); - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u", GetContext().DeviceId()); - GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u.", GetContext().DeviceId()); - return; - } -} - -Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { - GELOGI("CheckAndReleaseMemory graph_id[%u]", graph_node->GetGraphId()); - int64_t value = 0; - bool ret = ge::AttrUtils::GetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, value); - int64_t memory_size = ret ? value : 0; - ret = ge::AttrUtils::GetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, value); - int64_t weight_size = ret ? value : 0; - ret = ge::AttrUtils::GetInt(ge_model, MODEL_ATTR_SESSION_ID, value); - uint64_t session_id = ret ? value : 0; - - int64_t free_memory = 0; - Status result = GraphLoader::GetMemoryInfo(free_memory); - if (result != SUCCESS) { - return result; - } - - GELOGI( - "CheckAndReleaseMemory Graph[%u] need memory_size[%ld], weight_size[%ld]," - " Device[%u] free_memory_size[%ld]", - graph_node->GetGraphId(), memory_size, weight_size, GetContext().DeviceId(), free_memory); - if (ge::CheckInt64AddOverflow(memory_size, weight_size) != SUCCESS) { - REPORT_INNER_ERROR("E19999", "memory_size:%ld and weight_size:%ld will overflow after add, check invalid", - memory_size, weight_size); - GELOGE(INTERNAL_ERROR, "[Check][Param] memory_size:%ld and weight_size:%ld will overflow after add", - memory_size, weight_size); - return INTERNAL_ERROR; - } - if (free_memory >= (memory_size + weight_size)) { - return SUCCESS; - } - - std::lock_guard lock(unload_model_mutex_); - - std::map graph_map; - { - std::lock_guard lock(member_mutex_); - graph_map = graph_map_; - } - - for (auto &it : graph_map) { - auto graph_id = it.second->GetGraphId(); - auto model = it.second->GetGeRootModel(); - if (model == nullptr) { - continue; - } - auto model_id = model->GetModelId(); - auto model_ids = model->GetAllModelId(); - // unload model not release - bool is_unknown_shape = false; - GE_CHK_STATUS_RET(model->CheckIsUnknownShape(is_unknown_shape)); - if (is_unknown_shape) { - GELOGD("model_id[%u] graph_id[%u] is unknown model, not release memory", model_id, graph_id); - continue; - } - // not loaded,no need unload - if (!it.second->GetLoadFlag()) { - GELOGI("CheckAndReleaseMemory graph[%u] has not been loaded.", graph_id); - continue; - } - ReleaseMemory(ge_model, it.second, model_ids, graph_id, session_id); - } - - return SUCCESS; -} - Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, const 
SubGraphInfoPtr &sub_graph_info_ptr, const std::string &root_graph_name, @@ -3008,135 +2615,76 @@ Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { - std::lock_guard lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter == cache_helper_map_.end()) { - ModelCacheHelperPtr cache_helper = MakeShared(session_id, graph_id, compute_graph); - if (cache_helper != nullptr) { - cache_helper_map_.emplace(std::make_pair(graph_id, cache_helper)); - } else { - GELOGW("Cache helper make shared failed, graph_id = %u.", graph_id); - } - } - } -} - -ModelCacheHelperPtr GraphManager::FindModelCacheHelper(GraphId graph_id) { - std::lock_guard lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter != cache_helper_map_.end()) { - return iter->second; - } - - return nullptr; -} - -Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model) { - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->IsIncreBuild()) { - return FAILED; - } - const uint32_t graph_id = graph_node->GetGraphId(); - ModelCacheHelperPtr cache_helper = FindModelCacheHelper(graph_id); - if (cache_helper == nullptr) { - GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); - return FAILED; - } - if (cache_helper->IsModelCacheHit()) { - GEEVENT("Model cache hit."); - Status ret = LoadFromCache(graph_node, cache_helper, ge_model); - if (ret == SUCCESS) { - return SUCCESS; - } else { - GELOGW("Error occurred when load from cache, abandon."); - } - } else { - GEEVENT("Model cache miss."); - } - if (SaveCacheBeforeBuild(graph_node->GetGraphId(), cache_helper) != SUCCESS) { - GELOGW("Error occurred when save cache."); - } - return FAILED; -} - -Status GraphManager::CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, +Status GraphManager::CheckIncreBuildAndPreRun(const PreRunArgs &args, GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) { - if (!graph_manager->IsGraphNeedBuild(graph_node)) { + if (!IsGraphNeedBuild(graph_node)) { ge_root_model = graph_node->GetGeRootModel(); return SUCCESS; } if (graph_node->GetBuildFlag()) { - ReturnError(graph_manager, args.callback, PARAM_INVALID, + ReturnError(args.callback, PARAM_INVALID, "The graph " + std::to_string(graph_node->GetGraphId()) + " need to re-build, you should remove it" " from GE first, then AddGraph again and rebuild it."); return PARAM_INVALID; } // check need incre build. 
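The rewritten helpers below also change shape: functions that used to be static and take an explicit GraphManager * become plain member functions, matching the std::thread(&GraphManager::PreRunThread, this) binding introduced in Initialize. A small self-contained sketch of that idiom (class and method names hypothetical):

```cpp
#include <cstdio>
#include <thread>

class ManagerSketch {
 public:
  void Start() {
    // Equivalent of prerun_thread_ = std::thread(&GraphManager::PreRunThread, this);
    worker_ = std::thread(&ManagerSketch::PreRunThread, this);
    worker_.join();
  }

 private:
  // Before the refactor this would have been written as
  //   static void PreRunThread(ManagerSketch *manager);
  // with every access spelled manager->member.
  void PreRunThread() { std::printf("prerun loop body\n"); }

  std::thread worker_;
};

int main() {
  ManagerSketch m;
  m.Start();
  return 0;
}
```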
- GeModelPtr ge_model = nullptr; - if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { - std::vector ge_inputs; - for (const auto &item: args.input_tensor) { - ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item)); - } - Status ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); - // release rts generate context - RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); - if (ret != SUCCESS) { - ReturnError(graph_manager, args.callback, ret, "PreRun Failed."); - return ret; - } + std::vector ge_inputs; + for (const auto &item: args.input_tensor) { + ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item)); + } + Status ret = PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); + // release rts generate context + RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); + if (ret != SUCCESS) { + ReturnError(args.callback, ret, "PreRun Failed."); + return ret; } + graph_node->SetBuildFlag(true); - graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); + var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); return SUCCESS; } -void GraphManager::PreRunThread(GraphManager *graph_manager) { +void GraphManager::PreRunThread() { if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { GELOGW("Set thread name failed."); } PreRunArgs args; - while (graph_manager->thread_run_flag_) { - bool pop_status = graph_manager->prerun_args_q_.Pop(args); - if (!pop_status) { + while (thread_run_flag_) { + if (!prerun_args_q_.Pop(args)) { continue; } GELOGI("[PreRunThread] A new loop start, graph_id:%u.", args.graph_id); - ErrorManager::GetInstance().SetErrorContext(args.error_context); ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); GetContext().SetSessionId(args.session_id); GetThreadLocalContext() = args.context; - graph_manager->UpdateLocalOmgContext(args.graph_id); + UpdateLocalOmgContext(args.graph_id); // find graph GraphNodePtr graph_node = nullptr; - Status ret = graph_manager->GetGraphNode(args.graph_id, graph_node); + Status ret = GetGraphNode(args.graph_id, graph_node); if (ret != SUCCESS) { - ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL, + ReturnError(args.callback, GE_GRAPH_GRAPH_NODE_NULL, "[RunGraph] graph not exist, graph_id=" + std::to_string(args.graph_id)); return; } // more than one graph owns same graph_id uint32_t count = 0; - if (graph_manager->GetGraphCount(args.graph_id, count) != SUCCESS) { + if (GetGraphCount(args.graph_id, count) != SUCCESS) { GELOGE(INTERNAL_ERROR, "[Get][GraphCount] failed, graph id:%u.", args.graph_id); return; } // Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency if (count > 1 && graph_node->GetBuildFlag()) { - graph_node->Lock(); GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); // In online inference concurrency senario, graph_node is allowed to be locked for 'count' times graph_node->SetSemSize(count); - graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, + graph_node->Lock(); + PushGraph(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); GELOGI("[PreRunThread] Loop end. 
Start to run with cached build model."); continue; @@ -3145,7 +2693,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { graph_node->Lock(); if (graph_node->GetRunFlag()) { - ReturnError(graph_manager, args.callback, GE_GRAPH_ALREADY_RUNNING, + ReturnError(args.callback, GE_GRAPH_ALREADY_RUNNING, "[RunGraph] graph already running, graph id=" + std::to_string(args.graph_id)); graph_node->Unlock(); return; @@ -3156,25 +2704,21 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { ComputeGraphPtr compute_graph_tmp = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); if (compute_graph_tmp == nullptr) { - ReturnError(graph_manager, args.callback, GE_GRAPH_GRAPH_NODE_NULL, + ReturnError(args.callback, GE_GRAPH_GRAPH_NODE_NULL, "[RunGraph] compute_graph_tmp is NULL, graph id = %u."); graph_node->Unlock(); return; } - // when set incre build, save cache helper. - graph_manager->AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); - - std::vector ge_models; - if (graph_manager->options_.local_fmk_op_flag) { - graph_manager->GetCompilerStages(graph_node->GetGraphId()).optimizer.TranFrameOp(compute_graph_tmp); + if (options_.local_fmk_op_flag) { + GetCompilerStages(graph_node->GetGraphId()).optimizer.TranFrameOp(compute_graph_tmp); } // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. GELOGI("Start for run graph async."); GeRootModelPtr ge_root_model = nullptr; - ret = CheckIncreBuildAndPreRun(graph_manager, args, graph_node, ge_root_model); + ret = CheckIncreBuildAndPreRun(args, graph_node, ge_root_model); if (ret != SUCCESS) { graph_node->SetRunFlag(false); if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { @@ -3187,250 +2731,49 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { continue; } } - graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, + + PushGraph(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, args.input_tensor, ge_root_model, GetThreadLocalContext(), args.callback })); GELOGI("[PreRunThread] Loop end."); } } -void GraphManager::ParseInputsDimsForData(const std::vector &input_tensor) { - GELOGD("Start parse input dims from data."); - for (size_t i = 0; i < input_tensor.size(); ++i) { - const TensorDesc &tensor_desc = input_tensor[i].GetTensorDesc(); - const Shape &shape = tensor_desc.GetShape(); - const auto &shape_dims = shape.GetDims(); - GELOGD("Input tensor dims is %s.", formats::JoinToString(shape_dims).c_str()); - GetLocalOmgContext().user_real_input_dims.emplace_back(shape_dims); - } -} - -Status GraphManager::ParseInputsDimsForGetNexNosinkAndData(const vector &dynamic_nodes, - const std::vector &input_tensor) { - GELOGD("Start parse inputs dims when coexist data and getnext sink."); - for (size_t i = 0; i < dynamic_nodes.size(); ++i) { - auto op_desc = dynamic_nodes.at(i)->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - GeAttrValue::INT index = 0; - if (!(AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, index))) { - REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) fail", ATTR_NAME_INDEX.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(PARAM_INVALID, "[Get][Attr] %s from op:%s(%s) fail", ATTR_NAME_INDEX.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return PARAM_INVALID; - } - if (static_cast(index) > input_tensor.size()) { - REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s) value:%ld > param 
input_tensor.size:%zu, "
-                       "check invalid", ATTR_NAME_INDEX.c_str(),
-                       op_desc->GetName().c_str(), op_desc->GetType().c_str(),
-                       index, input_tensor.size());
-      GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s) value:%ld > param input_tensor.size:%zu",
-             ATTR_NAME_INDEX.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
-             index, input_tensor.size());
-      return PARAM_INVALID;
-    }
-
-    const TensorDesc &tensor_desc = input_tensor[i].GetTensorDesc();
-    const Shape &shape = tensor_desc.GetShape();
-    const auto &shape_dims = shape.GetDims();
-    GELOGI("Shape dims of %zu data is %s.", index, formats::JoinToString(shape_dims).c_str());
-    GetLocalOmgContext().user_real_input_dims.emplace_back(std::move(shape_dims));
+void GraphManager::PushGraph(const RunArgs &args) {
+  if (executor_ == nullptr) {
+    GELOGW("Executor is not set: the model can only be compiled, execution is not supported.");
+    return;
   }
-  return SUCCESS;
-}
-Status GraphManager::ParseInputsDims(const std::vector<ge::Tensor> &input_tensor) {
-  GELOGI("Start parse input dims of %zu input tensor.", input_tensor.size());
-  GetLocalOmgContext().user_real_input_dims.clear();
-  if (!GetLocalOmgContext().dynamic_node_type.empty()) {
-    vector<NodePtr> data_nodes;
-    vector<NodePtr> getnext_nosink_nodes;
-    data_nodes = GetLocalOmgContext().data_nodes;
-    getnext_nosink_nodes = GetLocalOmgContext().getnext_nosink_nodes;
-    GELOGD("Data nodes count is %zu, getnext nosink nodes count is %zu.", data_nodes.size(),
-           getnext_nosink_nodes.size());
-    if (GetLocalOmgContext().dynamic_node_type == DATA) {
-      if (getnext_nosink_nodes.empty()) {
-        // just data or data+getnext_sink
-        ParseInputsDimsForData(input_tensor);
-      } else {
-        // data+getnext_nosink, but only need to get shape_dims of data
-        if (ParseInputsDimsForGetNexNosinkAndData(data_nodes, input_tensor) != SUCCESS) {
-          GELOGE(PARAM_INVALID, "[Parse][Dims] from data failed, when data coexist with getnext nosink.");
-          return PARAM_INVALID;
-        }
-      }
-    } else {
-      if (getnext_nosink_nodes.empty()) {
-        // just getnext_sink or getnext_sink+data, need to get shape_dims from aicpu op
-        GELOGI("Need to get dims from aicpu op: GETDYNAMICDIMS.");
-        return SUCCESS;
-      } else {
-        if (data_nodes.empty()) {
-          // just getnext_nosink
-          ParseInputsDimsForData(input_tensor);
-        } else {
-          // getnext_nosink + data, but only need to get shape_dims of getnext_nosink
-          if (ParseInputsDimsForGetNexNosinkAndData(getnext_nosink_nodes, input_tensor) != SUCCESS) {
-            GELOGE(PARAM_INVALID, "[Parse][Dims] from getnext nosink failed, when data coexist with getnext nosink");
-            return PARAM_INVALID;
-          }
-        }
-      }
-    }
-  }
-  GELOGI("Parse %zu inputs dims success.", GetLocalOmgContext().user_real_input_dims.size());
-  return SUCCESS;
+  (void)executor_->PushGraph(args);
 }
 
-void GraphManager::RunThread(GraphManager *graph_manager) {
-  ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute);
-  if (prctl(PR_SET_NAME, ("GE_Run")) != 0) {
-    GELOGW("Set thread name failed.");
-  }
-
-  RunArgs args;
-  while (graph_manager->thread_run_flag_) {
-    bool pop_status = graph_manager->run_args_q_.Pop(args);
-    if (!pop_status) {
-      continue;
-    }
-    GELOGI("[RunThread] A
new loop start, graph_id:%u.", args.graph_id); + ome_context.data_nodes = GetLocalOmgContext().data_nodes; + ome_context.getnext_nosink_nodes = GetLocalOmgContext().getnext_nosink_nodes; - ErrorManager::GetInstance().SetErrorContext(args.error_context); - GetContext().SetSessionId(args.session_id); - GetThreadLocalContext() = args.context; - graph_manager->UpdateLocalOmgContext(args.graph_id); - - Status ret; - // parse inputs.dims to vector> dynamic_dims - ret = graph_manager->ParseInputsDims(args.input_tensor); - if (ret != SUCCESS) { - ReturnError(graph_manager, args.callback, ret, "ParseInputsDims failed, thread exit."); - args.graph_node->Unlock(); - return; - } + ome_context.user_real_input_dims = GetLocalOmgContext().user_real_input_dims; - args.graph_node->UpdateLoadFlag(); - if (!args.graph_node->GetLoadFlag()) { - ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); - args.ge_root_model->SetTrainFlag(graph_manager->GetTrainFlag()); - ret = graph_manager->LoadGraphAsync(args.ge_root_model, args.graph_node); - if (ret != SUCCESS || args.ge_root_model == nullptr) { - StopQueue(graph_manager); - ReturnError(graph_manager, args.callback, ret, "LoadGraphAsync failed, thread exit."); - args.graph_node->Unlock(); - return; - } - // control the times of graph loading in multi-thread scenario - args.graph_node->DecreaseLoadCount(); - args.graph_node->IncreaseLoadRecord(); - - args.graph_node->SetLoadFlag(true); - GELOGI("LoadGraph[%u], model[%u] success and set LoadFlag to true.", args.graph_node->GetGraphId(), - args.ge_root_model->GetModelId()); - } - - ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); - if (graph_manager->GetTrainFlag()) { - ret = graph_manager->graph_executor_.SetGraphContext(graph_manager->GetGraphContext()); - if (ret != SUCCESS) { - GELOGW("[GraphManager] SetGraphContext failed, graph_id=%u.", args.graph_id); - } - graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag); - } - - ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeRootModel(), - args.input_tensor, args.callback); - args.graph_node->SetRunFlag(false); - if (ret != SUCCESS) { - ReturnError(graph_manager, args.callback, ret, "ExecuteGraphAsync failed, thread exit."); - args.graph_node->Unlock(); - return; - } - args.graph_node->Unlock(); - GELOGI("[GraphManager] Run graph async success, graph_id=%u.", args.graph_id); - } + graph_node->SetOmeContext(ome_context); } -void GraphManager::StopQueue(GraphManager *graph_manager) { - if (graph_manager == nullptr) { - return; - } - - graph_manager->thread_run_flag_.store(false); - graph_manager->prerun_args_q_.Stop(); - graph_manager->run_args_q_.Stop(); +void GraphManager::StopQueue() { + thread_run_flag_.store(false); + prerun_args_q_.Stop(); } -void GraphManager::ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log) { - if (graph_manager == nullptr) { - return; - } - StopQueue(graph_manager); +void GraphManager::ReturnError(RunAsyncCallback callback, Status ret, const string &log) { + StopQueue(); GELOGE(ret, "%s.", log.c_str()); std::vector outputs; - callback(ret, outputs); -} - -void GraphManager::ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback, - Status ret, const string &log) { - std::vector outputs; - auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); - if (graph_manager == nullptr || 
compute_graph == nullptr) { - REPORT_INNER_ERROR("E19999", "Param graph_manager or compute_graph in graph_node is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] compute graph or graph manager is nullptr"); - callback(GRAPH_FAILED, outputs); - return; - } - - for (const auto &node : compute_graph->GetAllNodes()) { - if (node->GetType() != "NetOutput") { - continue; - } - for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); i++) { - auto input_desc = node->GetOpDesc()->MutableInputDesc(i); - GeShape ge_shape(input_desc->GetShape().GetDims()); - GeTensorDesc ge_tensor_desc; - ge_tensor_desc.SetShape(ge_shape); - GeTensor ge_tensor(ge_tensor_desc); - int64_t len = 1; - if (input_desc->GetShape().GetDims() != std::vector({})) { - len = input_desc->GetShape().GetShapeSize(); - } - if (len < 0) { - REPORT_INNER_ERROR("E19999", "InputIndex:%zu ShapeSize:%ld of op:%s(%s) < 0, unknown shape is not support, " - "check invalid", i, len, - node->GetName().c_str(), node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] InputIndex:%zu ShapeSize:%ld of op:%s(%s) < 0, " - "unknown shape is not support", i, len, node->GetName().c_str(), node->GetType().c_str()); - callback(GRAPH_FAILED, outputs); - return; - } else if (len == 0) { - GELOGI("getted shape size is 0.Do process as empty tensor!"); - len = 1; - } - auto length = GetSizeInBytes(len, input_desc->GetDataType()); - auto aligned_ptr = MakeShared(length, kAlignment); - if (aligned_ptr == nullptr) { - REPORT_CALL_ERROR("E19999", "New AlignedPtr failed, len:%ld", length); - GELOGE(GRAPH_FAILED, "[Create][AlignedPtr] failed, len:%ld", length); - return; - } - ge_tensor.SetData(aligned_ptr, length); - ge::Tensor tensor = TensorAdapter::AsTensor(ge_tensor); - // To avoid global step too small and can not stop, totally set a bigger value - auto ptr = aligned_ptr->MutableGet(); - for (int64_t i = 0; i < length; i++) { - ptr[i] = 0x7F; // here stands for a positive max value - } - outputs.emplace_back(std::move(tensor)); - } + if (callback != nullptr) { + callback(ret, outputs); } - callback(SUCCESS, outputs); - return; } bool GraphManager::IsGraphNeedRebuild(uint32_t graph_id) { @@ -3643,6 +2986,7 @@ Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &comp GraphUtils::DumpGEGraph(compute_graph, "Build", is_always_dump); GraphUtils::DumpGEGraphToOnnx(*compute_graph, "Build"); + SetRunContext(graph_node); graph_node->SetGeRootModel(ge_root_model); return SUCCESS; } diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 945a5e5d..e7cd88a9 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -26,26 +26,24 @@ #include #include "common/blocking_queue.h" -#include "common/ge_inner_error_codes.h" -#include "common/helper/model_cache_helper.h" +#include "framework/common/ge_inner_error_codes.h" #include "external/graph/types.h" -#include "ge/ge_api_types.h" +#include "external/ge/ge_api_types.h" #include "graph/build/graph_builder.h" -#include "graph/execute/graph_execute.h" #include "graph/ge_local_context.h" -#include "graph/load/graph_loader.h" #include "graph/manager/graph_manager_utils.h" #include "graph/manager/util/variable_accelerate_ctrl.h" #include "graph/optimize/graph_optimize.h" #include "graph/partition/graph_partition.h" #include "graph/preprocess/graph_preprocess.h" #include "graph/tuning_utils.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" +#include "common/executor.h" namespace ge { class GraphManager { 
public: - GraphManager(); + GraphManager() = default; ~GraphManager() = default; /// @@ -54,7 +52,7 @@ class GraphManager { /// @param [in] options user config params /// @return Status result of function /// - Status Initialize(const std::map &options); + Status Initialize(const std::map &options, Executor *executor = nullptr); /// /// @ingroup ge_graph @@ -113,7 +111,7 @@ class GraphManager { /// @param [out] outputs output data /// @return Status result of function /// - Status RunGraphWithStreamAsync(const GraphId &graph_id, rtStream_t stream, uint64_t session_id, + Status RunGraphWithStreamAsync(const GraphId &graph_id, rtStream_t stream, uint64_t session_id, const std::vector &inputs, std::vector &outputs); /// @@ -227,34 +225,18 @@ class GraphManager { RunAsyncCallback callback; }; - struct RunArgs { - GraphNodePtr graph_node; - GraphId graph_id; - uint64_t session_id; - struct error_message::Context error_context; - std::vector input_tensor; - GeRootModelPtr ge_root_model; - GEThreadLocalContext context; - RunAsyncCallback callback; - }; - void AddGraphNode(GraphId graph_id, const GraphNodePtr &graph_node); void RemoveGraphNode(GraphId graph_id); bool HasGraphNode(GraphId graph_id); Status GetGraphNode(const GraphId &graph_id, GraphNodePtr &out); - std::shared_ptr GetModelListener() const { return graph_run_listener_; } - static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, const SubGraphInfoPtr &sub_graph_info_ptr, const std::string &root_graph_name, uint64_t session_id, const struct error_message::Context &error_context, const GEThreadLocalContext &ge_context); - Status ParseInputsDims(const std::vector &input_tensor); - void ParseInputsDimsForData(const std::vector &input_tensor); - Status ParseInputsDimsForGetNexNosinkAndData(const vector &dynamic_nodes, - const std::vector &input_tensor); + Status RunCustomPass(const GraphNodePtr &graph_node); Status PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id = INVALID_SESSION_ID); @@ -292,7 +274,7 @@ class GraphManager { static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); - static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag); + static Status ParseTrainGraphFlag(bool &train_flag); static bool IsPerfLevelInvalid(int32_t perf_level); @@ -350,30 +332,18 @@ class GraphManager { Status SubexpressionMigration(ComputeGraphPtr &compute_graph); - Status LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); - - Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); - bool CheckModelLoad(const GeRootModelPtr &ge_model, bool load_flag); Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); bool IsGraphNeedBuild(const GraphNodePtr &graph_node); - Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); - Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); - Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); - void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); - Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); - void RemoveModelCacheHelper(const GraphId &graph_id); - ModelCacheHelperPtr FindModelCacheHelper(GraphId graph_id); - - static void 
PreRunThread(GraphManager *graph_manager); - static void RunThread(GraphManager *graph_manager); - static void StopQueue(GraphManager *graph_manager); - static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log); - static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback, - Status ret, const string &log); + void SetRunContext(const GraphNodePtr &graph_node); + void PushGraph(const RunArgs &args); + + void PreRunThread(); + void StopQueue(); + void ReturnError(RunAsyncCallback callback, Status ret, const string &log); void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); @@ -409,11 +379,7 @@ class GraphManager { CompilerStages &GetCompilerStages(GraphId graph_id); void RemoveCompilerStages(GraphId graph_id); - static Status CheckIncreBuildAndPreRun(GraphManager *graph_manager, const PreRunArgs &args, GraphNodePtr &graph_node, - GeRootModelPtr &ge_root_model); - - void ReleaseMemory(const GeModelPtr &ge_model, GraphNodePtr &graph_node, const std::vector &model_ids, - uint32_t graph_id, uint64_t session_id); + Status CheckIncreBuildAndPreRun(const PreRunArgs &args, GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model); Status CheckRepeatAdd(uint32_t graph_id, bool &is_added); @@ -431,34 +397,24 @@ class GraphManager { static Status CheckGraphAdded(const GraphId &graph_id, const Graph &graph); - std::atomic_bool thread_run_flag_; + std::atomic_bool thread_run_flag_{false}; BlockingQueue prerun_args_q_{}; - BlockingQueue run_args_q_{}; std::thread prerun_thread_; - std::thread run_thread_; ComputeGraphPtr compute_graph_; std::map graph_map_; - std::map cache_helper_map_; - - // for run graph synchronous return - std::mutex sync_run_mutex_; - std::condition_variable condition_; - // run graph synchronization call back listener - std::shared_ptr graph_run_listener_; // summary and checkpoint callback function list for ME, key is summary or checkpoint std::map &)>> me_callback_map_; std::map &)>> callback_map_; - bool init_flag_; - + bool init_flag_{false}; GraphManagerOptions options_; GraphContextPtr graph_context_ = nullptr; map omg_contexts_; map compiler_stages_; - GraphExecutor graph_executor_; + Executor *executor_{nullptr}; VarAccelerateCtrl var_acc_ctrl_; diff --git a/ge/graph/manager/graph_manager_utils.cc b/ge/graph/manager/graph_manager_utils.cc index a70b15a6..225a748a 100644 --- a/ge/graph/manager/graph_manager_utils.cc +++ b/ge/graph/manager/graph_manager_utils.cc @@ -21,12 +21,12 @@ #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" -#include "common/string_util.h" +#include "framework/common/string_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/compute_graph.h" #include "graph/op_desc.h" #include "graph/optimize/common/params.h" -#include "omg/omg_inner_types.h" +#include "framework/omg/omg_inner_types.h" #include "runtime/mem.h" namespace ge { @@ -70,45 +70,9 @@ void GraphNode::IncreaseLoadCount() { ++load_count_; } -SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} +SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr) {} SubGraphInfo::~SubGraphInfo() { - if (malloc_flag_) { - for (auto &buffer_addr : buffer_addr_) { - if (buffer_addr == nullptr) { - continue; - } - rtError_t rt_ret; - rt_ret = rtFreeHost(buffer_addr); - buffer_addr = nullptr; - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, 
modelId = %u", - model_id_info_.model_id); - } - } - } -} - -Status SubGraphInfo::FreeInOutBuffer() { - if (malloc_flag_) { - for (auto iter = buffer_addr_.begin(); iter != buffer_addr_.end(); ++iter) { - rtError_t rt_ret; - rt_ret = rtFreeHost(*iter); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtFreeHost fail, ret:%d", rt_ret); - GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, modelId = %u", model_id_info_.model_id); - buffer_addr_.erase(buffer_addr_.begin(), iter); - return GE_GRAPH_FREE_FAILED; - } - } - buffer_addr_.clear(); - - malloc_flag_ = false; - return SUCCESS; - } else { - GELOGI("[GraphManager] not malloc buffer, modelId = %u", model_id_info_.model_id); - return SUCCESS; - } } GraphModelListener::GraphModelListener(std::mutex &mutex, std::condition_variable &cond) diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index d38b4321..e17d9046 100644 --- a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -27,17 +27,18 @@ #include #include "common/blocking_queue.h" -#include "common/ge_types.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/ge_types.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/compute_graph.h" -#include "graph/graph.h" +#include "common/local_context.h" +#include "external/graph/graph.h" #include "graph/model.h" -#include "model/ge_model.h" -#include "model/ge_root_model.h" -#include "register/register_fmk_types.h" +#include "common/model/ge_model.h" +#include "common/model/ge_root_model.h" +#include "external/register/register_fmk_types.h" #include "external/ge/ge_api_types.h" namespace ge { @@ -85,8 +86,6 @@ class SubGraphInfo { void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; } bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; } - Status FreeInOutBuffer(); - void SetOutputContext(const std::string &output) { output_names_ = output; } std::string GetOutputContext() const { return output_names_; } @@ -106,10 +105,7 @@ class SubGraphInfo { std::vector output_flag_; ModelIdInfo model_id_info_; GeModelPtr ge_model_ptr_; - bool malloc_flag_; - std::vector buffer_addr_; std::string output_names_; - std::vector buffer_size_; std::string stream_label_; std::unordered_map end_to_pld_; std::unordered_map pld_to_end_; @@ -154,6 +150,9 @@ class GraphNode { bool GetRunFlag() const { return run_flag_; } void SetRunFlag(bool flag) { run_flag_ = flag; } + void SetOmeContext(const OmeContext &context) { context_ = context; } + OmeContext &GetOmeContext() { return context_; } + bool IsAsync() const { return async_; } void SetAsync(bool flag) { async_ = flag; } @@ -196,6 +195,8 @@ class GraphNode { bool run_flag_; std::vector subgraph_ptr_list_; + OmeContext context_; + GraphPtr graph_; ComputeGraphPtr compute_graph_; bool build_flag_; diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index ced8465f..ce5b335e 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -20,6 +20,7 @@ #include "graph/manager/graph_mem_manager.h" #include "graph/manager/trans_var_data_utils.h" #include "graph/utils/type_utils.h" +#include "graph/ge_context.h" using std::map; using std::string; @@ -194,35 +195,6 @@ ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string 
&var_na return SUCCESS; } -ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - GE_CHECK_NOTNULL(base_ptr); - GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str()); - - VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; - uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset; - - return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr, - var_broadcast_info.input_size, session_id_); -} - -ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); - - VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name]; - // subgraph base_ptr could be nullptr, task it as base 0 - uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset; - - return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name, - var_tensor_desc, session_id_); -} - -ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr); -} - bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; } rtMemType_t VarResource::GetVarMemType(const int64_t &offset) { @@ -457,10 +429,6 @@ ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTenso return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); } -void VarManager::GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map) { - var_resource_->GetAllVarAddrMgr(var_addr_mgr_map); -} - int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { std::lock_guard lock(mutex_); MemResource *mem_resource = nullptr; @@ -481,36 +449,6 @@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { return mem_resource->GetVarMemSize(); } -Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { - std::lock_guard lock(mutex_); - MemResource *mem_resource = nullptr; - auto iter = mem_resource_map_.find(memory_type); - if (iter == mem_resource_map_.end()) { - mem_resource = MemResource::BuildMemResourceFromType(memory_type); - if (mem_resource == nullptr) { - REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", - memory_type, session_id_); - GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu", - memory_type, session_id_); - return ge::INTERNAL_ERROR; - } else { - mem_resource_map_[memory_type] = mem_resource; - } - } else { - mem_resource = iter->second; - } - - if (mem_resource == nullptr) { - REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu", - memory_type, session_id_); - GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu", - memory_type, session_id_); - return FAILED; - } - mem_resource->UpdateVarMemSize(mem_size); - return SUCCESS; -} - ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, rtMemType_t memory_type) { std::lock_guard lock(mutex_); @@ -638,16 +576,6 @@ bool VarManager::IsVarExist(const std::string &var_name) { return var_resource_->IsVarExist(var_name); } -ge::Status VarManager::SyncVarData(uint32_t 
graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr) { - std::lock_guard lock(mutex_); - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr); -} - ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) { std::lock_guard lock(mutex_); GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str()); @@ -676,16 +604,6 @@ ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastIn return SUCCESS; } -ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { - std::lock_guard lock(mutex_); - - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info); -} - ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { std::lock_guard lock(mutex_); GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); @@ -701,16 +619,6 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); } -ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) { - std::lock_guard lock(mutex_); - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr); -} - bool VarManager::IsVarAddr(const int64_t &offset) { std::lock_guard lock(mutex_); if (var_resource_ == nullptr) { @@ -816,25 +724,52 @@ Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &grap return var_resource_->GetChangedGraphId(var_name, graph_id); } +Status VarManager::GetTotalMemorySize(size_t &total_mem_size) { + rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", + GetContext().DeviceId(), rt_ret); + GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); + return RT_FAILED; + } + size_t free_mem = 0; + rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret); + return RT_FAILED; + } + rt_ret = rtDeviceReset(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", + GetContext().DeviceId(), rt_ret); + GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); + return RT_FAILED; + } + return SUCCESS; +} + Status VarManager::SetMemoryMallocSize(const map &options) { - auto it = options.find(GRAPH_MEMORY_MAX_SIZE); - if (it == options.end()) { - graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize; - } else { - string graph_memory_manager_malloc_max_size = it->second; + size_t total_mem_size = 0; + GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size)); + GEEVENT("Total memory size is %zu", total_mem_size); + + 
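The new `GetTotalMemorySize`/`SetMemoryMallocSize` path queries device HBM once (via `rtMemGetInfoEx(RT_MEMORYINFO_HBM, ...)`) and then sizes the graph and variable pools as fixed fractions of it, 26/32 and 5/32, with the `GRAPH_MEMORY_MAX_SIZE` and `VARIABLE_MEMORY_MAX_SIZE` options only overriding those defaults. A small sketch of just the budgeting arithmetic, with the device query stubbed out (the ratio constants are copied from this patch; everything else is illustrative):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>

// Ratios as introduced in graph_var_manager.h by this change.
const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0;
const double kVarMemoryManagerMallocRatio = 5.0 / 32.0;

// Hypothetical stand-in for the rtMemGetInfoEx(RT_MEMORYINFO_HBM, ...) query.
size_t QueryTotalHbmBytes() { return 32UL * 1024 * 1024 * 1024; }  // pretend 32 GiB

int main() {
  size_t total = QueryTotalHbmBytes();
  size_t graph_mem_max = static_cast<size_t>(std::floor(total * kGraphMemoryManagerMallocRatio));
  size_t var_mem_max = static_cast<size_t>(std::floor(total * kVarMemoryManagerMallocRatio));
  // GRAPH_MEMORY_MAX_SIZE / VARIABLE_MEMORY_MAX_SIZE options would overwrite these defaults.
  std::printf("graph budget: %zu, var budget: %zu\n", graph_mem_max, var_mem_max);
  return 0;
}
```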
graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio); + var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio); + + auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE); + if (it1 != options.end()) { + string graph_memory_manager_malloc_max_size = it1->second; ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); if (ret != SUCCESS) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); return ge::GE_GRAPH_OPTIONS_INVALID; } - GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); } - it = options.find(VARIABLE_MEMORY_MAX_SIZE); - if (it == options.end()) { - var_mem_max_size_ = kMemoryVarManagerMallocSize; - } else { - string memory_var_manager_malloc_size = it->second; + auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE); + if (it2 != options.end()) { + string memory_var_manager_malloc_size = it2->second; ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); if (ret != SUCCESS) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); @@ -842,6 +777,8 @@ Status VarManager::SetMemoryMallocSize(const map &options) { } } + GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_); + var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer; if (var_mem_logic_base_ > kMaxMemorySize) { REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", diff --git a/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h index 0da12f9c..f0e3b89b 100755 --- a/ge/graph/manager/graph_var_manager.h +++ b/ge/graph/manager/graph_var_manager.h @@ -30,7 +30,7 @@ #include "framework/common/l2_cache_optimize.h" #include "graph/ge_tensor.h" #include "graph/op_desc.h" -#include "graph/tensor.h" +#include "external/graph/tensor.h" #include "runtime/mem.h" namespace ge { @@ -43,6 +43,8 @@ const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; const uint64_t kSessionMemAlignSize = 512; const size_t kSessionMemAlignUnit = 2; +const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0; +const double kVarMemoryManagerMallocRatio = 5.0 / 32.0; enum MemStatus { NORMAL = 0, @@ -118,15 +120,6 @@ class VarResource { ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); - ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); - - ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, - const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr); - - ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr); - Status SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { if (var_to_trans_road_.find(var_name) != var_to_trans_road_.end()) { GELOGW("Var name: %s has already set.", var_name.c_str()); @@ -230,20 +223,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type); - void GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map); - ge::Status GetVarAddr(const std::string &var_name, const 
ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); - ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr); - ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); - ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); - - ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc, - uint8_t *base_ptr); - ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); @@ -286,8 +269,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { int64_t GetVarMemSize(rtMemType_t memory_type); - Status UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size); - bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); bool IsVarExist(const std::string &var_name); @@ -316,6 +297,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { mutable std::recursive_mutex mutex_; Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size); + Status GetTotalMemorySize(size_t &total_mem_size); }; class VarManagerPool { diff --git a/ge/graph/manager/host_mem_manager.h b/ge/graph/manager/host_mem_manager.h index 84d5aebe..6ff19edb 100644 --- a/ge/graph/manager/host_mem_manager.h +++ b/ge/graph/manager/host_mem_manager.h @@ -32,7 +32,7 @@ #include "framework/common/l2_cache_optimize.h" #include "graph/ge_tensor.h" #include "graph/op_desc.h" -#include "graph/tensor.h" +#include "external/graph/tensor.h" #include "runtime/mem.h" namespace ge { diff --git a/ge/graph/manager/model_manager/event_manager.cc b/ge/graph/manager/model_manager/event_manager.cc deleted file mode 100644 index 339e9894..00000000 --- a/ge/graph/manager/model_manager/event_manager.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/manager/model_manager/event_manager.h" - -#define RETURN_IF_COND_NOT_MET(condition, ...) 
\ - do { \ - if (!(condition)) { \ - GELOGE(FAILED, __VA_ARGS__); \ - return; \ - } \ - } while (0); - -namespace ge { -Status EventManager::Init(size_t event_num) { - if (this->inited_) { - return SUCCESS; - } - - rtEvent_t event = nullptr; - current_idx_ = 0; - for (size_t i = 0; i < event_num; ++i) { - GE_CHK_RT_RET(rtEventCreate(&event)); - this->event_list_.push_back(event); - } - - this->inited_ = true; - - return SUCCESS; -} - -void EventManager::Release() noexcept { - for (size_t i = 0; i < this->event_list_.size(); ++i) { - rtError_t rt_ret = rtEventDestroy(this->event_list_[i]); - RETURN_IF_COND_NOT_MET(rt_ret == RT_ERROR_NONE, "[Destroy][Event] failed, idx is %zu, ret is 0x%x.", i, rt_ret); - } - this->event_list_.clear(); - - this->inited_ = false; -} - -Status EventManager::EventRecord(size_t event_idx, rtStream_t stream) { - GE_CHK_BOOL_RET_STATUS_NOLOG(this->inited_, INTERNAL_ERROR); - - GE_CHK_BOOL_RET_STATUS_NOLOG(event_idx < this->event_list_.size(), PARAM_INVALID); - - GE_CHK_RT_RET(rtEventRecord(this->event_list_[event_idx], stream)); - - current_idx_ = static_cast(event_idx); - return SUCCESS; -} - -Status EventManager::EventElapsedTime(size_t start_event_idx, size_t stop_event_idx, float &time) { - GE_CHK_BOOL_RET_STATUS_NOLOG(this->inited_, INTERNAL_ERROR); - - GE_CHK_BOOL_RET_STATUS_NOLOG(start_event_idx < this->event_list_.size() && - stop_event_idx < this->event_list_.size() && start_event_idx <= stop_event_idx, - PARAM_INVALID); - - GE_CHK_RT_RET(rtEventElapsedTime(&time, this->event_list_[start_event_idx], this->event_list_[stop_event_idx])); - - return SUCCESS; -} - -Status EventManager::GetEvent(uint32_t index, rtEvent_t &event) { - GE_CHK_BOOL_RET_STATUS_NOLOG(index < this->event_list_.size(), PARAM_INVALID); - event = this->event_list_[index]; - return SUCCESS; -} -} // namespace ge diff --git a/ge/graph/manager/model_manager/event_manager.h b/ge/graph/manager/model_manager/event_manager.h deleted file mode 100644 index a7464e0c..00000000 --- a/ge/graph/manager/model_manager/event_manager.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
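The `RETURN_IF_COND_NOT_MET` macro deleted above wraps its body in `do { ... } while (0)` but keeps a trailing semicolon after `while (0)`, which defeats the idiom: the expansion becomes two statements, so the macro cannot be used safely in an unbraced `if`/`else`. A sketch of the canonical form (macro name `RETURN_IF_FALSE` is hypothetical, chosen for the example):

```cpp
#include <cstdio>

// Canonical do-while(0) wrapper: no trailing semicolon inside the macro,
// so the call site's own ';' completes exactly one statement.
#define RETURN_IF_FALSE(cond, msg)       \
  do {                                   \
    if (!(cond)) {                       \
      std::fprintf(stderr, "%s\n", msg); \
      return;                            \
    }                                    \
  } while (0)

void Demo(bool ok) {
  if (ok)
    RETURN_IF_FALSE(ok, "unreachable");  // expands to a single statement
  else
    std::puts("not ok");                 // else still binds correctly
}

int main() { Demo(false); return 0; }
```

With the deleted macro's extra semicolon, the `else` branch above would be a compile error, since an empty statement would sit between the `if` body and the `else`.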
- */ - -#ifndef GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ -#define GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ - - -#include - -#include "common/fmk_error_codes.h" -#include "common/fmk_types.h" -#include "common/util.h" -#include "runtime/event.h" - -namespace ge { -class EventManager { - public: - /// - /// @ingroup domi_ome - /// @brief constructor - /// - EventManager() : inited_(false), current_idx_(0) {} - /// - /// @ingroup domi_ome - /// @brief destructor - /// - ~EventManager() { this->Release(); } - - /// - /// @ingroup domi_ome - /// @brief init and create event list - /// @param [in] event_num event number created - /// @return exec result - /// - Status Init(size_t event_num); - - /// - /// @ingroup domi_ome - /// @brief event record - /// @param [in] event_idx event index - /// @param [in] stream related stream - /// @return exec result - /// - Status EventRecord(size_t event_idx, rtStream_t stream); - - /// - /// @ingroup domi_ome - /// @brief time between start and end in ms - /// @param [in] start_event_idx start event index - /// @param [in] stop_event_idx stop event index - /// @param [out] time - /// @return exec result - /// - Status EventElapsedTime(size_t start_event_idx, size_t stop_event_idx, float &time); - - /// - /// @ingroup domi_ome - /// @brief current event index - /// @return - /// - uint32_t CurrentIdx() const { return current_idx_; } - - /// - /// @ingroup domi_ome - /// @brief get event at specific loc - /// @param [in] index event index - /// @return - /// - Status GetEvent(uint32_t index, rtEvent_t &event); - - /// - /// @ingroup domi_ome - /// @brief release event list - /// @param [in] - /// @return - /// - void Release() noexcept; - - private: - std::vector event_list_; - bool inited_; - uint32_t current_idx_; -}; // EventManager -} // namespace ge -#endif // GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ diff --git a/ge/graph/manager/trans_var_data_utils.cc b/ge/graph/manager/trans_var_data_utils.cc index 621eba79..2e6ce454 100644 --- a/ge/graph/manager/trans_var_data_utils.cc +++ b/ge/graph/manager/trans_var_data_utils.cc @@ -16,14 +16,14 @@ #include "graph/manager/trans_var_data_utils.h" -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/debug/memory_dumper.h" #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "graph/manager/graph_var_manager.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "graph/utils/type_utils.h" #include "common/thread_pool.h" #include @@ -415,72 +415,6 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, return SUCCESS; } } // namespace -Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { - GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr."); - uint8_t *src_host_addr = nullptr; - int64_t src_addr_size = 0; - GE_MAKE_GUARD_RTMEM(src_host_addr); - GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); - - GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); - GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, - "[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld", - src_addr_size, dst_addr_size); - - 
GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); - return SUCCESS; -} - -Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { - GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. "); - uint8_t *host_addr = nullptr; - GE_MAKE_GUARD_RTMEM(host_addr); - GE_CHK_RT_RET(rtMallocHost(reinterpret_cast(&host_addr), src_addr_size)); - GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST)); - - GE_CHK_STATUS_RET( - SyncTensorToDevice(var_name, reinterpret_cast(host_addr), src_addr_size, dst_tensor_desc, session_id)); - - return SUCCESS; -} - -Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { - GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed"); - - uint8_t *src_addr = nullptr; - GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); - uint8_t *mem_addr = - src_addr - - static_cast(static_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); - GE_CHK_RT_RET(rtMallocHost(reinterpret_cast(host_addr), src_tensor_size)); - - GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); - - GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size); - return SUCCESS; -} - -Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { - uint8_t *dst_addr = nullptr; - GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); - uint8_t *mem_addr = - dst_addr - - static_cast(static_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); - GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); - - GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); - - return SUCCESS; -} - Status TransVarDataUtils::TransAllVarData(const vector &variable_nodes, uint64_t session_id, rtContext_t context, diff --git a/ge/graph/manager/trans_var_data_utils.h b/ge/graph/manager/trans_var_data_utils.h index 95ebd09a..f5a89a50 100755 --- a/ge/graph/manager/trans_var_data_utils.h +++ b/ge/graph/manager/trans_var_data_utils.h @@ -24,16 +24,10 @@ #include "graph/utils/tensor_utils.h" #include "graph/node.h" #include "runtime/context.h" -#include "graph_var_manager.h" namespace ge { class TransVarDataUtils { public: - static ge::Status SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id_); - static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); - static ge::Status TransAllVarData(const std::vector &variable_nodes, uint64_t session_id, rtContext_t context, @@ -41,12 +35,6 @@ class 
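The removed `SyncTensorToHost`/`SyncTensorToDevice` helpers rebased a variable's logical address onto the real device base before calling `rtMemcpy`: `mem_addr = addr - GetVarMemLogicBase() + GetVarMemoryBase(RT_MEMORY_HBM)`. A tiny sketch of that rebasing arithmetic with plain integers (all addresses below are made up for illustration):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uintptr_t var_mem_logic_base = 0x100000000ULL;  // logical base (assumption)
  const uintptr_t var_memory_base    = 0x7f0000000ULL;  // device base (assumption)
  const uintptr_t logical_addr       = 0x100004200ULL;  // a GetVarAddr result (assumption)

  // mem_addr = logical_addr - logic_base + memory_base, as in the deleted code.
  uintptr_t mem_addr = logical_addr - var_mem_logic_base + var_memory_base;
  std::printf("device address: 0x%llx\n", static_cast<unsigned long long>(mem_addr));
  return 0;
}
```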
TransVarDataUtils { uint32_t thread_num = 16); static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); - - private: - static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, - uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); - static ge::Status SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size, - const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); }; } // namespace ge diff --git a/ge/graph/manager/util/debug.h b/ge/graph/manager/util/debug.h index e1b13caf..02cacb72 100755 --- a/ge/graph/manager/util/debug.h +++ b/ge/graph/manager/util/debug.h @@ -33,10 +33,10 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/debug/memory_dumper.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "mmpa/mmpa_api.h" #include "proto/om.pb.h" #include "runtime/kernel.h" diff --git a/ge/graph/manager/util/hcom_util.cc b/ge/graph/manager/util/hcom_util.cc index 2da19cc9..021a458e 100644 --- a/ge/graph/manager/util/hcom_util.cc +++ b/ge/graph/manager/util/hcom_util.cc @@ -16,10 +16,10 @@ #include "graph/manager/util/hcom_util.h" -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/math/math_util.h" -#include "common/op/attr_value_util.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/attr_value_util.h" +#include "framework/common/op/ge_op_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" @@ -109,8 +109,7 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), "[Get][Size] from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i); // dynamic shape hccl op get size from output tensor desc - if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) { - GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); + if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE) && (op_desc->GetOutputDescPtr(i) != nullptr)) { GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), input_size), "[Get][Size] from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i); } diff --git a/ge/graph/manager/util/hcom_util.h b/ge/graph/manager/util/hcom_util.h index f80ced35..96ef92bf 100644 --- a/ge/graph/manager/util/hcom_util.h +++ b/ge/graph/manager/util/hcom_util.h @@ -21,11 +21,11 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/opskernel/ge_task_info.h" -#include "common/string_util.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/string_util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/op_desc.h" #include "hccl/hcom.h" #include "proto/task.pb.h" diff --git a/ge/graph/optimize/common/params.h b/ge/graph/optimize/common/params.h index d5b66b8f..fbe58c6b 100644 --- a/ge/graph/optimize/common/params.h +++ b/ge/graph/optimize/common/params.h @@ -20,7 +20,7 @@ #include #include "common/singleton.h" -#include "common/types.h" +#include "framework/common/types.h" namespace ge { class Params : public Singleton { diff --git a/ge/graph/optimize/graph_optimize.cc b/ge/graph/optimize/graph_optimize.cc index 835e257b..a321ed43 100644 --- 
a/ge/graph/optimize/graph_optimize.cc +++ b/ge/graph/optimize/graph_optimize.cc @@ -17,7 +17,7 @@ #include "graph/optimize/graph_optimize.h" #include "graph/ge_context.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/dimension_adjust_pass.h" #include "inc/pass_manager.h" #include "init/gelib.h" @@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) { GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str()); continue; } -#ifndef ONLY_COMPILE_OPEN_SRC GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str()); ret = (iter->second)->OptimizeAfterStage1(*compute_graph); -#endif if (ret != SUCCESS) { REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, " "graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str()); diff --git a/ge/graph/optimize/graph_optimize.h b/ge/graph/optimize/graph_optimize.h index 702b7e33..a3d359b6 100755 --- a/ge/graph/optimize/graph_optimize.h +++ b/ge/graph/optimize/graph_optimize.h @@ -25,13 +25,13 @@ #include #include -#include "common/ge_inner_error_codes.h" -#include "common/ge_types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" #include "common/optimizer/graph_optimizer.h" #include "graph/compute_graph.h" #include "graph/manager/graph_context.h" #include "graph/manager/graph_manager_utils.h" -#include "omg/omg_inner_types.h" +#include "framework/omg/omg_inner_types.h" namespace ge { using ComputeGraphPtr = std::shared_ptr; diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 7e7ab908..2edb1828 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -17,7 +17,7 @@ #include #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/optimize/graph_optimize.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/optimize/summary_optimize.cc b/ge/graph/optimize/summary_optimize.cc index d3c02d3e..08a27c4e 100644 --- a/ge/graph/optimize/summary_optimize.cc +++ b/ge/graph/optimize/summary_optimize.cc @@ -21,7 +21,7 @@ #include "graph/optimize/graph_optimize.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" -#include "omg/omg_inner_types.h" +#include "framework/omg/omg_inner_types.h" namespace { const char *const kSummary = "Summary"; diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc index 055b2aa4..cd98b6c5 100755 --- a/ge/graph/partition/dynamic_shape_partition.cc +++ b/ge/graph/partition/dynamic_shape_partition.cc @@ -31,7 +31,7 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #define REQUIRE(cond, ...) 
\ do { \ @@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() { auto cluster = MakeShared(rank++, type, node, this); REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); node_2_cluster_[node] = cluster; - if (cluster->IsUnknownShape()) { - ordered_cluster_.push_back(cluster); - } int64_t group_index = -1; if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { @@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() { return SUCCESS; } -Status DynamicShapePartitioner::TopologicalSortClusters() { +Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) { ordered_cluster_.clear(); // BFS topological sort clusters for known shape cluster std::queue ready_clusters; @@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { auto cluster = ready_clusters.front(); ready_clusters.pop(); cluster->UpdateRank(rank++); - if (cluster->IsKnownShape() || cluster->IsInputNode()) { + if (ordered_filter == nullptr || ordered_filter(cluster)) { ordered_cluster_.push_back(cluster); } for (const auto &out_cluster : cluster->Outputs()) { @@ -364,6 +361,7 @@ static std::string ToString(const std::vector &clusters) { } void DynamicShapePartitioner::MergeClustersControlFlow() { + std::unordered_set all_merged_clusters; for (const auto &item : control_clusters_) { const auto &control_cluster = item.second; auto rit = control_cluster.rbegin(); @@ -373,12 +371,21 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { } const auto &cluster = *rit; + if (all_merged_clusters.count(cluster) > 0) { + continue; + } + for (++rit; rit != control_cluster.rend(); ++rit) { const auto &cluster_from = *rit; + if (all_merged_clusters.count(cluster_from) > 0) { + continue; + } + auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), ToString(merged_clusters).c_str()); for (const auto &merged_cluster : merged_clusters) { + all_merged_clusters.emplace(merged_cluster); for (const auto &node : merged_cluster->Nodes()) { node_2_cluster_[node] = cluster; } @@ -459,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() { } Status DynamicShapePartitioner::MergeClusters() { + const auto filter_known = [](const ClusterPtr &cluster) { + return cluster->IsKnownShape() || cluster->IsInputNode(); + }; + const auto filter_unknown = [](const ClusterPtr &cluster) { + return cluster->IsUnknownShape(); + }; + MergeClustersControlFlow(); + REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown), + "[TopologicalSort][Clusters] after merge control flow clusters failed."); MergeClustersUnknownShape(); - REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed."); + REQUIRE_SUCCESS(TopologicalSortClusters(filter_known), + "[TopologicalSort][Clusters] after merge unknown shape clusters failed."); MergeClustersKnownShape(); MergeClustersInputData(); return SUCCESS; @@ -703,7 +720,12 @@ void Cluster::Merge(ClusterPtr other) { if (other->min_ < min_) { min_ = other->min_; } -}; + + if (!IsUnknownShape() && other->IsUnknownShape()) { + type_ = UNKNOWN_SHAPE; + } +} + bool Cluster::TryMerge(ClusterPtr other) { std::queue forward_reached; forward_reached.push(other); diff --git a/ge/graph/partition/dynamic_shape_partition.h b/ge/graph/partition/dynamic_shape_partition.h index a17c4e4b..0eb282a2 100644 --- a/ge/graph/partition/dynamic_shape_partition.h +++ 
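`TopologicalSortClusters` now takes an `OrderedFilter` predicate, so the same BFS traversal can collect either unknown-shape clusters (after merging control flow) or known-shape/input clusters (after merging unknown shape), as the `filter_unknown`/`filter_known` lambdas above show. A self-contained sketch of a filtered Kahn-style topological sort over a toy node type (not GE's `Cluster`):

```cpp
#include <cstdio>
#include <functional>
#include <queue>
#include <vector>

// Minimal DAG node standing in for GE's Cluster (hypothetical).
struct Node {
  int id;
  bool unknown_shape;
  std::vector<int> outs;  // indices of successors
};

using OrderedFilter = std::function<bool(const Node &)>;

// Topological sort that records only nodes accepted by the filter,
// mirroring TopologicalSortClusters(const OrderedFilter &).
std::vector<int> TopoSortFiltered(const std::vector<Node> &nodes, const OrderedFilter &filter) {
  std::vector<int> in_deg(nodes.size(), 0), order;
  for (const auto &n : nodes)
    for (int o : n.outs) ++in_deg[o];
  std::queue<int> ready;
  for (size_t i = 0; i < nodes.size(); ++i)
    if (in_deg[i] == 0) ready.push(static_cast<int>(i));
  while (!ready.empty()) {
    int cur = ready.front(); ready.pop();
    if (filter == nullptr || filter(nodes[cur])) order.push_back(cur);
    for (int o : nodes[cur].outs)
      if (--in_deg[o] == 0) ready.push(o);
  }
  return order;
}

int main() {
  std::vector<Node> g = {{0, true, {1}}, {1, false, {2}}, {2, true, {}}};
  auto unknown_only = [](const Node &n) { return n.unknown_shape; };
  for (int id : TopoSortFiltered(g, unknown_only)) std::printf("%d ", id);  // prints: 0 2
  return 0;
}
```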
b/ge/graph/partition/dynamic_shape_partition.h @@ -21,7 +21,7 @@ #include #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/compute_graph.h" namespace ge { @@ -111,6 +111,8 @@ class DynamicShapePartitioner { Status Partition(); + using OrderedFilter = std::function &cluster)>; + private: Status PartitionImpl(); // Collect nodes that satisfy the unknowshape rules: @@ -138,7 +140,7 @@ class DynamicShapePartitioner { // Merge clusters step3 void MergeClustersInputData(); // Topological sort clusters after merge unknown shape clusters. - Status TopologicalSortClusters(); + Status TopologicalSortClusters(const OrderedFilter &ordered_filter); // Deduplicate merged clusters void PruneUniqueClusters(); // Establish the input-output anchors for each partition of the cluster and record links to other clusters @@ -161,7 +163,7 @@ class DynamicShapePartitioner { ge::ComputeGraphPtr root_graph_; // The original graph to partition std::unordered_map> node_2_cluster_; // Record nodes and the cluster it belongs to // V1 control flow cluster, need merge to one Graph. - std::unordered_map>> control_clusters_; + std::map>> control_clusters_; // topological sorted clusters, this field will change with the splitting. // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters diff --git a/ge/graph/partition/engine_place.cc b/ge/graph/partition/engine_place.cc index 93cc3e61..8639f015 100755 --- a/ge/graph/partition/engine_place.cc +++ b/ge/graph/partition/engine_place.cc @@ -22,7 +22,7 @@ #include #include -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "common/util/error_manager/error_manager.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" diff --git a/ge/graph/partition/engine_place.h b/ge/graph/partition/engine_place.h index 5dc3e6a0..125babb6 100755 --- a/ge/graph/partition/engine_place.h +++ b/ge/graph/partition/engine_place.h @@ -20,7 +20,7 @@ #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/compute_graph.h" namespace ge { diff --git a/ge/graph/partition/graph_partition.cc b/ge/graph/partition/graph_partition.cc index c3f9480d..86c9f1fd 100755 --- a/ge/graph/partition/graph_partition.cc +++ b/ge/graph/partition/graph_partition.cc @@ -24,11 +24,11 @@ #include "analyzer/analyzer.h" #include "common/ge/ge_util.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_manager_utils.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -179,6 +179,7 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); } GE_CHECK_NOTNULL(original_compute_graph); + output_merged_compute_graph->SetName(original_compute_graph->GetName()); // partition sub graph for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) { ComputeGraphPtr merged_sub_graph = nullptr; @@ -188,8 +189,16 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr GELOGE(ret, 
"[Merge][SubGraph] Failed, ret:%d", ret); continue; } + // this means subgraph added in optimize subgraph and without partitions, so just add to root graph + if (merged_sub_graph == sub_graph) { + GELOGI("Just add subgraph %s (parent node is %s) to root graph %s.", sub_graph->GetName().c_str(), + sub_graph->GetParentNode()->GetName().c_str(), output_merged_compute_graph->GetName().c_str()); + sub_graph->SetParentGraph(sub_graph->GetParentNode()->GetOwnerComputeGraph()); + GE_IF_BOOL_EXEC(output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph) != SUCCESS, + return FAILED;) + continue; + } // add sub graph - output_merged_compute_graph->SetName(original_compute_graph->GetName()); merged_sub_graph->SetName(sub_graph->GetName()); merged_sub_graph->SetInputSize(sub_graph->GetInputSize()); merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize()); @@ -245,12 +254,9 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co } if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) || (graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) { - REPORT_INNER_ERROR("E19999", "original_compute_graph:%s is not find in graph_2_graph_partition_info_.", - original_compute_graph->GetName().c_str()); - GELOGE(GE_GRAPH_NULL_INPUT, - "[Check][Param] original_compute_graph:%s is not find in graph_2_graph_partition_info_.", - original_compute_graph->GetName().c_str()); - return FAILED; + GELOGW("[GraphPartition]: compute_graph has not found, just return original."); + output_merged_compute_graph = original_compute_graph; + return SUCCESS; } GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph]; const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph]; @@ -708,6 +714,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vectorGetName()); + (void)sub_graph->SetExtAttr("part_src_graph", compute_graph); GELOGD("set attr success. 
subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), compute_graph->GetName().c_str()); GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]); diff --git a/ge/graph/partition/graph_partition.h b/ge/graph/partition/graph_partition.h index f34c67e6..6c21fabe 100644 --- a/ge/graph/partition/graph_partition.h +++ b/ge/graph/partition/graph_partition.h @@ -28,7 +28,7 @@ #include #include "graph/compute_graph.h" #include "graph/manager/graph_manager_utils.h" -#include "graph/operator_reg.h" +#include "external/graph/operator_reg.h" #include "graph/partition/engine_place.h" namespace ge { @@ -70,6 +70,8 @@ class GraphPartitioner { // Return all subgraphs const Graph2SubGraphInfoList &GetSubGraphMap(); + const Graph2InputNodesSubGraphInfo &GetSubGraphInfoMap() {return graph_2_input_subgraph_; } + private: Status MergeSubGraph(ge::ComputeGraphPtr &output_merged_compute_graph, const ge::ComputeGraphPtr &original_compute_graph); diff --git a/ge/graph/partition/stage_partition.cc b/ge/graph/partition/stage_partition.cc index 309e24c4..68b4209f 100644 --- a/ge/graph/partition/stage_partition.cc +++ b/ge/graph/partition/stage_partition.cc @@ -21,8 +21,8 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" -#include "common/util.h" -#include "common/types.h" +#include "framework/common/util.h" +#include "framework/common/types.h" namespace ge { namespace { diff --git a/ge/graph/partition/stage_partition.h b/ge/graph/partition/stage_partition.h index bac00e6b..99aac2b9 100644 --- a/ge/graph/partition/stage_partition.h +++ b/ge/graph/partition/stage_partition.h @@ -21,7 +21,7 @@ #include #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/compute_graph.h" namespace ge { diff --git a/ge/graph/passes/addn_pass.h b/ge/graph/passes/addn_pass.h index 373d1842..075ff9fc 100644 --- a/ge/graph/passes/addn_pass.h +++ b/ge/graph/passes/addn_pass.h @@ -17,10 +17,10 @@ #ifndef GE_GRAPH_PASSES_ADDN_PASS_H_ #define GE_GRAPH_PASSES_ADDN_PASS_H_ -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "framework/common/types.h" -#include "graph/graph.h" +#include "external/graph/graph.h" #include "graph/passes/base_pass.h" #include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/aicpu_constant_folding_pass.cc b/ge/graph/passes/aicpu_constant_folding_pass.cc index 8fdb51a1..d33d4db2 100644 --- a/ge/graph/passes/aicpu_constant_folding_pass.cc +++ b/ge/graph/passes/aicpu_constant_folding_pass.cc @@ -19,9 +19,9 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/ge/ge_util.h" -#include "common/types.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/attr_utils.h" diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc index 9f202c77..13700e2e 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/ge/graph/passes/atomic_addr_clean_pass.cc @@ -22,9 +22,9 @@ #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/debug/ge_attr_define.h" 
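Many of the remaining hunks in this patch do not change behavior at all; they only normalize include paths to the exported header layout. A condensed summary of the recurring mappings seen throughout the patch (comment-only, not compilable on its own):

```cpp
// Include-path normalization recurring throughout this patch (old spelling -> new spelling):
//   "common/types.h"            -> "framework/common/types.h"
//   "graph/graph.h"             -> "external/graph/graph.h"
//   "model/ge_model.h"          -> "common/model/ge_model.h"
//   "graph/common/omg_util.h"   -> "common/omg_util.h"
//   "omg/omg_inner_types.h"     -> "framework/omg/omg_inner_types.h"
```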
#include "graph/utils/node_utils.h" #include "init/gelib.h" diff --git a/ge/graph/passes/atomic_addr_clean_pass.h b/ge/graph/passes/atomic_addr_clean_pass.h index 0d0b8fff..30162359 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.h +++ b/ge/graph/passes/atomic_addr_clean_pass.h @@ -19,7 +19,7 @@ #include -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index d5b28ec7..71d74500 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -15,8 +15,8 @@ */ #include "graph/passes/attach_stream_label_pass.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" +#include "external/ge/ge_api_types.h" +#include "common/omg_util.h" using std::string; diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 165e7e81..8b4a8b88 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -1,374 +1,475 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/base_pass.h" - -#include -#include - -#include "common/debug/log.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" - -namespace ge { -namespace { -constexpr int kMaxRePassTimes = 10000; -constexpr size_t kMaxOneInNodes = 1000; -// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later -constexpr int kMaxRecursiveDepth = 20; -struct DuringPassNodeSets { - std::unordered_set nodes_seen; - std::unordered_set nodes_deleted; - std::unordered_set nodes_re_pass; - std::unordered_set nodes_re_pass_immediately; - std::unordered_set nodes_last; - std::unordered_set nodes_suspend; - std::unordered_set nodes_resume; -}; - -void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, - std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { - nodes_last.clear(); - for (auto &node : graph->GetDirectNode()) { - if (node == nullptr) { - continue; - } - size_t in_nums = node->GetInNodes().size(); - if (in_nums == 0) { - input_edge_nodes.push_back(node); - nodes_seen.insert(node.get()); - } else if (in_nums > kMaxOneInNodes) { - nodes_last.insert(node); - } - } -} - -bool IsAllInNodesAlive(const Node::Vistor &nodes, const std::unordered_set &nodes_suspend) { - return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); -} - -void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, - DuringPassNodeSets &during_pass_node_set) { - auto &nodes_seen = during_pass_node_set.nodes_seen; - const auto &nodes_last = during_pass_node_set.nodes_last; - const auto &nodes_suspend = during_pass_node_set.nodes_suspend; - for (auto &node : nodes) { - if (node == nullptr) { - continue; - } - if (nodes_last.count(node) != 
0) { - continue; - } - if (nodes_suspend.count(node) > 0) { - GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); - continue; - } - - bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); - bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); - if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { - nodes_to_pass.push_back(node); - } - } -} - -void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { - for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { - GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); - nodes.push_front(node); - } - during_pass_node_set.nodes_re_pass_immediately.clear(); -} - -void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { - for (auto &node : during_pass_node_set.nodes_resume) { - const auto &it = during_pass_node_set.nodes_suspend.find(node); - if (it != during_pass_node_set.nodes_suspend.end()) { - during_pass_node_set.nodes_suspend.erase(node); - GELOGD("The node %s resumed by pass.", node->GetName().c_str()); - nodes.push_back(node); - } else { - GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); - } - } - during_pass_node_set.nodes_resume.clear(); -} - -void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, - const std::unordered_set &nodes_suspend, - const std::unordered_set &nodes_resume) { - for (const auto &node : nodes_suspend) { - GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_suspend.emplace(node); - } - - for (const auto &node : nodes_resume) { - GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_resume.emplace(node); - } -} - -void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, - std::unordered_set &nodes_seen, const std::unordered_set &nodes_to_re_pass, - std::unordered_set &nodes_re_pass) { - for (const auto &node_to_re_pass : nodes_to_re_pass) { - if (node_to_re_pass == nullptr) { - GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - continue; - } - if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { - GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str()); - nodes_re_pass.insert(node_to_re_pass); - } else { - GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); - } - } -} - -Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) { - if (node == nullptr) { - REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); - GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); - return FAILED; - } - GELOGD("Begin to run pass for node %s", node->GetName().c_str()); - for (const auto &name_to_pass : names_to_passes) { - if (name_to_pass.second == nullptr) { - GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); - continue; - } - - GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); - name_to_pass.second->init(); - auto result = name_to_pass.second->Run(node); - if (result != SUCCESS) { - 
REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " - "%u, the passes will be terminated immediately.", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - return result; - } - - const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, - during_pass_node_set.nodes_re_pass); - - const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, - during_pass_node_set.nodes_re_pass_immediately); - - PushToSuspendNodes(during_pass_node_set, name_to_pass.first, - name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); - - const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); - during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); - if (nodes_deleted_by_pass.count(node) > 0) { - GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), - name_to_pass.first.c_str()); - break; - } - } - - return SUCCESS; -} - -void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { - for (auto &name_to_pass : names_to_pass) { - name_to_pass.second->SetOption(option, ""); - } -} - -void ClearOption(NamesToPass names_to_pass) { - for (auto &name_to_pass : names_to_pass) { - name_to_pass.second->ClearOptions(); - } -} - -bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { - if (node == nullptr) { - GELOGW("node is null"); - return false; - } - if (during_pass_node_set.nodes_deleted.count(node) > 0) { - GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); - return false; - } - if (during_pass_node_set.nodes_suspend.count(node) > 0) { - GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", - node->GetName().c_str()); - return false; - } - - return true; -} -} // namespace - -Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map) { - if (node == nullptr) { - REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); - GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); - return FAILED; - } - GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(), - node->GetType().c_str()); - ComputeGraphPtr graph = node->GetOwnerComputeGraph(); - if (graph == nullptr) { - REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str()); - GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.", - node->GetName().c_str()); - return FAILED; - } - - AddRePassNodesWithInOut(node); - - if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); - GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str()); - return FAILED; - } - - if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) { - REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str()); - GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str()); - return FAILED; - } - - 
AddNodeDeleted(node); - return SUCCESS; -} - -Status GEPass::Run(const NamesToPass &names_to_passes) { - if (graph_ == nullptr) { - REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid."); - GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr"); - return INTERNAL_ERROR; - } - if (names_to_passes.empty()) { - GELOGW("No passes input, the GEPass will do nothing"); - return INTERNAL_ERROR; - } - - if (depth_ > kMaxRecursiveDepth) { - GELOGE(PARAM_INVALID, - "[Check][Param] The pass for root graph %s will be terminated because too many nesting" - " levels(%d) of subgraphs, last subgraph is %s", - root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str()); - return PARAM_INVALID; - } - - return RunPassesOneGraph(names_to_passes); -} - -Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { - GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); - std::deque nodes; - DuringPassNodeSets during_pass_node_set; - GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); - GELOGD("Start points count %zu", nodes.size()); - int re_pass_times = 0; - - do { - for (auto &node : during_pass_node_set.nodes_re_pass) { - nodes.push_back(node); - during_pass_node_set.nodes_seen.insert(node.get()); - } - during_pass_node_set.nodes_re_pass.clear(); - - while (!nodes.empty()) { - NodePtr node = nodes.front(); - nodes.pop_front(); - - (void)during_pass_node_set.nodes_re_pass.erase(node); - if (!CheckNode(node, during_pass_node_set)) { - continue; - } - AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); - - auto ret = RunPasses(node, names_to_passes, during_pass_node_set); - if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", - node->GetName().c_str(), node->GetType().c_str(), ret); - return ret; - } - - bool has_sub_graph = false; - ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); - if (ret != SUCCESS) { - GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); - return ret; - } - - if (has_sub_graph) { - GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); - SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, during_pass_node_set); - if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", - node->GetName().c_str(), node->GetType().c_str(), ret); - return ret; - } - - // There is only one option scene, so set and clear options around the `RunPasses` func. 
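The deleted driver above (and its replacement later in this file) runs the pass list a second time on nodes that own subgraphs, bracketing that rerun with `SetFlagOption(kOptimizeAfterSubGraph, ...)` and `ClearOption(...)`, exactly as the deleted comment demands. A small RAII sketch of that set-then-clear bracket (the `Pass`/`OptionScope` types are hypothetical, chosen to illustrate the pattern):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical pass interface: just enough to show the option bracket.
struct Pass {
  std::string option;
  void SetOption(const std::string &opt) { option = opt; }
  void ClearOptions() { option.clear(); }
};

// RAII guard: options are set for exactly one RunPasses call and cleared on exit.
class OptionScope {
 public:
  OptionScope(std::vector<Pass> &passes, const std::string &opt) : passes_(passes) {
    for (auto &p : passes_) p.SetOption(opt);
  }
  ~OptionScope() {
    for (auto &p : passes_) p.ClearOptions();
  }
 private:
  std::vector<Pass> &passes_;
};

int main() {
  std::vector<Pass> passes(2);
  {
    OptionScope scope(passes, "OptimizeAfterSubGraph");
    // RunPasses(node, passes, ...) would execute here with the flag set.
    std::printf("option in scope: %s\n", passes[0].option.c_str());
  }
  std::printf("option after scope: '%s'\n", passes[0].option.c_str());  // cleared
  return 0;
}
```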
- // if there are more than one scene to set options, the `ClearOption` function - // should be called each time at the begin of the iteration - ClearOption(names_to_passes); - } - - AddRepassNodes(during_pass_node_set, nodes); - AddResumeNodes(during_pass_node_set, nodes); - } - - for (auto &node : during_pass_node_set.nodes_last) { - bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); - if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { - nodes.push_back(node); - } - } - during_pass_node_set.nodes_last.clear(); - } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); - - if (re_pass_times == kMaxRePassTimes) { - GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); - } - GELOGD("All passes runs end"); - - return SUCCESS; -} -Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { - auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); - has_sub_graph = false; - for (const auto &name : sub_graph_names) { - auto graph = root_graph_->GetSubgraph(name); - if (graph == nullptr) { - GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it", - name.c_str(), node->GetName().c_str()); - continue; - } - has_sub_graph = true; - GELOGI("Begin to run passes on the sub graph %s of node %s", name.c_str(), node->GetName().c_str()); - GEPass pass(graph, root_graph_, depth_ + 1); - auto ret = pass.Run(names_to_passes); - if (ret != SUCCESS) { - GELOGE(ret, "[Run][Passes] for sub graph:%s from node:%s failed", name.c_str(), node->GetName().c_str()); - return ret; - } - } - return SUCCESS; -} -} // namespace ge +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/passes/base_pass.h" + +#include +#include + +#include "common/debug/log.h" +#include "graph/utils/graph_utils.h" + +namespace ge { +namespace { +constexpr int kMaxRePassTimes = 10000; +constexpr size_t kMaxOneInNodes = 1000; +// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later +constexpr int kMaxRecursiveDepth = 20; + +void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, + GEPass::GraphLevelState &g_state) { + for (auto &node : graph->GetDirectNode()) { + if (node == nullptr) { + continue; + } + size_t in_nums = node->GetInNodes().size(); + if (in_nums == 0) { + g_state.AddNodeToQueueIfNotSeen(node); + } else if (in_nums > kMaxOneInNodes) { + g_state.nodes_last.insert(node); + } + } +} + +bool AnyNodesIn(const Node::Vistor &nodes, const std::unordered_set &nodes_set) { + return std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { + return nodes_set.count(n) > 0; + }); +} + + +bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) { + if (node == nullptr) { + GELOGW("node is null"); + return false; + } + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); + return false; + } + + if (g_state.nodes_last.count(node) != 0) { + return false; + } + + // all in_node seen && all in_node not suspend + if (!node->IsAllInNodesSeen(g_state.nodes_seen)) { + return false; + } + + if (g_state.nodes_suspend.count(node) > 0) { + GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + return false; + } + + if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { + GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + return false; + } + return true; +} + +void AddNextIterNodes(const NodePtr &cur_node, + std::unordered_set &out_nodes_before_pass, + GEPass::GraphLevelState &g_state) { + for (auto &node : cur_node->GetOutNodes()) { + if (node == nullptr) { + continue; + } + if(out_nodes_before_pass.erase(node) == 0) { + // after pass node, new output node come up + GELOGI("New output node %s come up after pass %s.", + node->GetName().c_str(), cur_node->GetName().c_str()); + } + + // all in_node seen && all in_node not suspend + if (IsNodeReadyToQueue(node, g_state)) { + g_state.AddNodeToQueueIfNotSeen(node); + } + } + + // + for (const auto &node : out_nodes_before_pass) { + // A-->B-->C if B was + // unlink edge may happend, add these node to queue if needed + if (node->GetInAllNodes().empty() && IsNodeReadyToQueue(node, g_state)) { + GELOGI("Node %s may lost from cur node, add to queue if not seen.", + node->GetName().c_str(), cur_node->GetName().c_str()); + g_state.AddNodeToQueueIfNotSeen(node); + } + } +} + +void AddImmediateRepassNodesToQueue(NodePtr &cur_node, + std::unordered_map re_pass_imm_nodes_to_pass_names, + GEPass::GraphLevelState &g_state) { + for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) { + auto imme_repass_node = node_2_pass_names.first; + if (imme_repass_node == nullptr) { + GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", + node_2_pass_names.second.c_str(), + cur_node->GetName().c_str(), cur_node->GetType().c_str()); + continue; + } + if (g_state.nodes_passed.count(imme_repass_node) > 0) { + GELOGD("The node %s specified by pass %s has been passed, it will repass immediately", 
+void AddImmediateRepassNodesToQueue(NodePtr &cur_node, + std::unordered_map re_pass_imm_nodes_to_pass_names, + GEPass::GraphLevelState &g_state) { + for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) { + auto imme_repass_node = node_2_pass_names.first; + if (imme_repass_node == nullptr) { + GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", + node_2_pass_names.second.c_str(), + cur_node->GetName().c_str(), cur_node->GetType().c_str()); + continue; + } + if (g_state.nodes_passed.count(imme_repass_node) > 0) { + GELOGD("The node %s specified by pass %s has been passed, it will repass immediately", + imme_repass_node->GetName().c_str(), node_2_pass_names.second.c_str()); + g_state.AddNodeToQueueFront(imme_repass_node); + continue; + } + GELOGW("The node %s specified by pass %s has not been passed yet, it will not repass immediately", + node_2_pass_names.first->GetName().c_str(), node_2_pass_names.second.c_str()); + } +} + +void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) { + for (auto &node : g_state.nodes_last) { + if (node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.AddNodeToQueueIfNotSeen(node); + } + } + g_state.nodes_last.clear(); +} + +void AddResumeNodesToQueue(const std::unordered_map resume_node_2_pass_names, + GEPass::GraphLevelState &g_state) { + // For now the base pass does not record the order of suspend & resume, so we do not know which one comes first within a node pass. + // Here, if one node pass both suspends and resumes a node, we consider it to resume that node. + // A better way is to record the order, and then apply suspend or resume in that order. + for (const auto &node_2_pass_names : resume_node_2_pass_names) { + auto node = node_2_pass_names.first; + if (g_state.nodes_suspend.erase(node) > 0) { + if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.nodes.push_back(node); + GELOGD("Node %s has been resumed by pass %s, and added to the pass queue", + node->GetName().c_str(), node_2_pass_names.second.c_str()); + } + } + } +} + +void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, + std::unordered_set &nodes_seen, const std::vector &nodes_to_re_pass, + GEPass::RepassLevelState &rp_state) { + for (const auto &node_to_re_pass : nodes_to_re_pass) { + if (node_to_re_pass == nullptr) { + GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { + if (rp_state.AddNodeToRepass(node_to_re_pass)) { + GELOGD("The node %s will be re-passed.", node_to_re_pass->GetName().c_str()); + continue; + } + GELOGD("Node %s has been added to the repass queue, no need to add it again.", node_to_re_pass->GetName().c_str()); + } else { + GELOGD("Node %s is not seen and its in-nodes are not all seen, don't set re-pass this time", node_to_re_pass->GetName().c_str()); + } + } +} + +void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->SetOption(option, ""); + } +} + +void ClearOption(NamesToPass names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->ClearOptions(); + } +} +} // namespace
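+ +// Depending on is_repass_io_immediately, the isolated node and its input/output nodes are recorded either for a +// normal re-pass (picked up in the next round of the re-pass loop) or for an immediate re-pass (pushed to the +// front of the current queue).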
+Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, + bool is_repass_io_immediately) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); + GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); + return FAILED; + } + GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + ComputeGraphPtr graph = node->GetOwnerComputeGraph(); + if (graph == nullptr) { + REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str()); + GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.", + node->GetName().c_str()); + return FAILED; + } + + is_repass_io_immediately ? AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node); + + if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); + GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str()); + return FAILED; + } + + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str()); + GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str()); + return FAILED; + } + + AddNodeDeleted(node); + return SUCCESS; +} + +Status GEPass::Run(const NamesToPass &names_to_passes) { + if (graph_ == nullptr) { + REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid."); + GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr"); + return INTERNAL_ERROR; + } + if (names_to_passes.empty()) { + GELOGW("No passes provided, the GEPass will do nothing"); + return INTERNAL_ERROR; + } + for (const auto &name_to_pass : names_to_passes) { + if (name_to_pass.second == nullptr) { + GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str()); + return INTERNAL_ERROR; + } + } + + if (depth_ > kMaxRecursiveDepth) { + GELOGE(PARAM_INVALID, + "[Check][Param] The pass for root graph %s will be terminated because too many nesting" + " levels(%d) of subgraphs, last subgraph is %s", + root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str()); + return PARAM_INVALID; + } + + return RunPassesOneGraph(names_to_passes); + // TODO: when debug mode is on, find the first node in topological order that was not passed, and give a warning +} + +void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->OnStartPassGraph(graph); + } +} + +Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + std::unordered_map resume_nodes_to_pass_names; + for (auto &name_to_pass : names_to_passes) { + name_to_pass.second->init(); + auto ret = name_to_pass.second->OnSuspendNodesLeaked(); + if (ret != SUCCESS) { + GELOGE(ret, "Internal error with OnSuspendNodesLeaked on pass %s.", name_to_pass.first.c_str()); + return ret; + } + for (const auto &resume_node : name_to_pass.second->GetNodesResume()) { + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } + } + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); + return SUCCESS; +} + +Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { + GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); + NotifyPassGraphStart(graph_, names_to_passes); + GraphLevelState g_state; + g_state.re_pass_times = 0; + GetAllNodesNoInputEdge(graph_, g_state); + GELOGD("Start points count %zu", g_state.nodes.size()); + + do { + if (!g_state.nodes_suspend.empty()) { + auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state); + if (ret != SUCCESS) { + // error is logged inside the callee + return ret; + } + if (g_state.nodes.empty()) { + GELOGE(INTERNAL_ERROR, "There are some suspended nodes leaked and no pass resumed them."); + return INTERNAL_ERROR; + } + } + auto ret = RunPassesGraphRepass(names_to_passes, g_state); + if (ret != SUCCESS) { + return ret; + } + } while (!g_state.nodes_suspend.empty()); + + return SUCCESS; +}
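+ +// One round of the re-pass loop: drain the node queue, running every pass once per node, then re-queue the nodes +// collected for re-pass at the start of the next do-while iteration; the loop is bounded by kMaxRePassTimes.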
+Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + RepassLevelState rp_state; + do { + for (auto &node : rp_state.nodes_re_pass) { + if (rp_state.nodes_re_pass_set.count(node) > 0) { + GELOGD("Add node %s to queue for re-pass", node->GetName().c_str()); + g_state.AddNodeToQueue(node); + } + } + rp_state.ClearRepass(); + + while (!g_state.nodes.empty()) { + auto node = g_state.PopFront(); + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); + continue; + } + rp_state.EraseNodeFromRepass(node); + g_state.nodes_seen.insert(node.get()); + + // collect out nodes before pass + std::unordered_set out_nodes_before_pass; + for (const auto &out_node : node->GetOutNodes()) { + out_nodes_before_pass.insert(out_node); + } + auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + AddNextIterNodes(node, out_nodes_before_pass, g_state); + } + AddLastNodesToQueue(g_state); + } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes); + + if (g_state.re_pass_times == kMaxRePassTimes) { + GELOGW("re_pass_times should not reach %d", kMaxRePassTimes); + } + GELOGD("All pass runs ended"); + return SUCCESS; +} + +Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { + auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); + has_sub_graph = false; + for (const auto &name : sub_graph_names) { + auto graph = root_graph_->GetSubgraph(name); + if (graph == nullptr) { + GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it", + name.c_str(), node->GetName().c_str()); + continue; + } + has_sub_graph = true; + GELOGI("Begin to run passes on the sub graph %s of node %s", name.c_str(), node->GetName().c_str()); + GEPass pass(graph, root_graph_, depth_ + 1); + auto ret = pass.Run(names_to_passes); + if (ret != SUCCESS) { + GELOGE(ret, "[Run][Passes] for sub graph:%s from node:%s failed", name.c_str(), node->GetName().c_str()); + return ret; + } + } + return SUCCESS; +}
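+ +// Runs all passes on one node, then recurses into the node's subgraphs (see the depth_ guard in GEPass::Run), +// and, if any subgraph exists, runs the passes on the node a second time with the kOptimizeAfterSubGraph option set.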
+Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state) { + auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + bool has_sub_graph = false; + ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); + return ret; + } + + if (has_sub_graph) { + GELOGD("There are subgraphs on node %s, run passes for the second time", node->GetName().c_str()); + SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); + ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + // There is only one option scene, so set and clear options around the `RunPasses` func. + // if there is more than one scene to set options, the `ClearOption` function + // should be called each time at the beginning of the iteration + ClearOption(names_to_passes); + } + return SUCCESS; +} + +Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); + GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); + return FAILED; + } + GELOGD("Begin to run pass for node %s", node->GetName().c_str()); + for (const auto &name_to_pass : names_to_passes) { + GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); + name_to_pass.second->init(); + auto result = name_to_pass.second->Run(node); + if (result != SUCCESS) { + REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", + name_to_pass.first.c_str(), node->GetName().c_str(), result); + GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " + "%u, the passes will be terminated immediately.", + name_to_pass.first.c_str(), node->GetName().c_str(), result); + return result; + } + if (name_to_pass.second->GetNodesDeleted().count(node) > 0) { + GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), + name_to_pass.first.c_str()); + break; + } + } + + g_state.nodes_passed.insert(node); + + std::unordered_map re_pass_imm_nodes_to_pass_names; + std::unordered_map resume_nodes_to_pass_names; + // if multiple passes re-pass the same node, it would be added to the queue many times, so collect and deduplicate + for (const auto &name_to_pass : names_to_passes) { + PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, + name_to_pass.second->GetNodesNeedRePass(), + rp_state); + // collect imm_node && resume_node among these passes + for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()) { + re_pass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ","); + } + for (const auto &resume_node : name_to_pass.second->GetNodesResume()) { + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } + + for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) { + GELOGD("The iteration suspension of node %s has been set by pass %s", suspend_node->GetName().c_str(), + name_to_pass.first.c_str()); + g_state.nodes_suspend.insert(suspend_node); + } + const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); + g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); + } + + AddImmediateRepassNodesToQueue(node, re_pass_imm_nodes_to_pass_names, g_state); + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); + + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index d0f125b2..093e2dce 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -22,7 +22,6 @@ #include #include #include - #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/compute_graph.h" @@ -40,6 +39,7 @@ enum NodePassOption { }; class BaseNodePass { + // todo comments
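+ // Lifecycle per visited node: GEPass calls init() to clear the result sets below, then Run(node); afterwards it + // collects the results via GetNodesNeedRePass, GetNodesNeedRePassImmediately, GetNodesDeleted, GetNodesSuspend + // and GetNodesResume.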
public: /// /// Optimize on one node. the function can add nodes to the graph, change @@ -51,7 +51,7 @@ class BaseNodePass { virtual ~BaseNodePass() = default; - const std::unordered_set &GetNodesNeedRePass() { return nodes_need_re_pass_; } + const std::vector &GetNodesNeedRePass() { return nodes_need_re_pass_; } const std::unordered_set &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } @@ -61,23 +61,32 @@ class BaseNodePass { const std::unordered_set &GetNodesResume() { return nodes_resume_; } + virtual Status OnSuspendNodesLeaked() { return SUCCESS; } + void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void ClearOptions() { options_.clear(); } void init() { nodes_need_re_pass_.clear(); - nodes_deleted_.clear(); nodes_need_re_pass_immediately_.clear(); + nodes_deleted_.clear(); nodes_suspend_.clear(); nodes_resume_.clear(); } + virtual void OnStartPassGraph(const ComputeGraphPtr &graph) { + current_graph_name_ = graph->GetName(); + } + protected: - Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map); + const string &GetCurrentGraphName() const { + return current_graph_name_; + } + Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, bool is_repass_io_immediately = false); - Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map) { - return IsolateAndDeleteNode(node, std::vector(io_map)); + Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map, bool is_repass_io_immediately = false) { + return IsolateAndDeleteNode(node, std::vector(io_map), is_repass_io_immediately); } /// @@ -86,7 +95,7 @@ class BaseNodePass { /// optimized by other passes, call this function. /// @param node /// - void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } + void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.emplace_back(node); } /// /// Add a node to be optimized immediately again. If you add a new node to the graph, or @@ -101,14 +110,30 @@ class BaseNodePass { /// @param node /// void AddRePassNodesWithInOut(const NodePtr &node) { + auto in_nodes = node->GetInNodes(); + for (auto &in_node : in_nodes) { + AddRePassNode(in_node); + } AddRePassNode(node); auto out_nodes = node->GetOutNodes(); for (auto &out_node : out_nodes) { + AddRePassNode(out_node); + } + } + + /// + /// Add a node and its input/output data nodes to be optimized immediately again. + /// @param node + /// + void AddImmediateRePassNodesWithInOut(const NodePtr &node) { auto in_nodes = node->GetInNodes(); for (auto &in_node : in_nodes) { - AddRePassNode(in_node); + AddImmediateRePassNode(in_node); + } + AddImmediateRePassNode(node); + auto out_nodes = node->GetOutNodes(); + for (auto &out_node : out_nodes) { + AddImmediateRePassNode(out_node); } } @@ -123,34 +148,27 @@ class BaseNodePass { void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } /// - /// If you suspend a node from the graph, especially following node. The remain - /// iterate passes will stop process on the suspend node(if it can be + /// If you postpone a node of the graph, especially a following node, the remaining + /// iterating passes will stop processing the postponed node (if it can be /// reached by edge connections) till the last one. Obviously it is a waste of - /// time. You can add the suspend nodes by calling this function, to stop the + /// time. You can add the postponed nodes by calling this function, to stop the /// next iterations. /// @param node /// void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); }
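+ // Note that suspended nodes also block their downstream nodes: IsNodeReadyToQueue in base_pass.cc skips any node + // that has a suspended input node, and suspended nodes left over when the queue drains are offered back to the + // passes through OnSuspendNodesLeaked().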
- /// - /// If you resume a node from the graph, especially following node. The remain - /// iterate passes will continue process on the resume node(if it can be - /// reached by edge connections) till the last one. - /// You can add the resume nodes by calling this function, to resume the - /// next iterations. - /// @param node - /// void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } private: - std::unordered_set nodes_need_re_pass_; + std::vector nodes_need_re_pass_; std::unordered_set nodes_need_re_pass_immediately_; std::unordered_set nodes_deleted_; std::unordered_set nodes_suspend_; std::unordered_set nodes_resume_; std::map options_; + std::string current_graph_name_; }; using NamesToPass = std::vector>; @@ -160,12 +178,75 @@ class GEPass { explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} virtual ~GEPass() = default; Status Run(const NamesToPass &names_to_passes); + /* + * TODO: group the states by scope: + * OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended + * RePass: nodes_re_pass + * GraphOneTime: nodes_last + * NodeOneTime: nodes_re_pass_immediately, nodes_resume + */ + struct GraphLevelState { + std::unordered_set nodes_deleted; + std::unordered_set nodes_seen; + std::unordered_set nodes_passed; + std::unordered_set nodes_suspend; + std::unordered_set nodes_last; + std::deque nodes; + int re_pass_times; + + void AddNodeToQueueFront(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_front(std::move(node)); + } + + void AddNodeToQueue(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_back(std::move(node)); + } + void AddNodeToQueueIfNotSeen(NodePtr node) { + if (nodes_seen.insert(node.get()).second) { + nodes.emplace_back(std::move(node)); + } + } + NodePtr PopFront() { + NodePtr node = nodes.front(); + nodes.pop_front(); + return node; + } + }; + struct RepassLevelState { + std::vector nodes_re_pass; + std::unordered_set nodes_re_pass_set; + bool AddNodeToRepass(NodePtr node) { + if (!nodes_re_pass_set.insert(node).second) { + return false; + } + nodes_re_pass.emplace_back(node); + return true; + } + void EraseNodeFromRepass(NodePtr node) { + nodes_re_pass_set.erase(node); + } + void ClearRepass() { + nodes_re_pass_set.clear(); + nodes_re_pass.clear(); + } + }; + struct GraphOneTimeLevelState { + std::unordered_set nodes_last; + }; private: GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) : graph_(graph), root_graph_(root_graph), depth_(depth) {} + Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state); + Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state); Status RunPassesOneGraph(const NamesToPass &names_to_passes); Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); + Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state); + Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state); ComputeGraphPtr graph_; ComputeGraphPtr root_graph_; int depth_; diff --git a/ge/graph/passes/bitcast_pass.h b/ge/graph/passes/bitcast_pass.h index 34acaf57..60990dea 100644 --- a/ge/graph/passes/bitcast_pass.h +++ b/ge/graph/passes/bitcast_pass.h @@
-17,10 +17,10 @@ #ifndef GE_GRAPH_PASSES_BITCAST_PASS_H_ #define GE_GRAPH_PASSES_BITCAST_PASS_H_ -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "framework/common/types.h" -#include "graph/graph.h" +#include "external/graph/graph.h" #include "graph/op_desc.h" #include "graph/passes/base_pass.h" #include "graph/passes/pass_utils.h" diff --git a/ge/graph/passes/buffer_pool_memory_pass.cc b/ge/graph/passes/buffer_pool_memory_pass.cc index 8a64da59..deb25325 100644 --- a/ge/graph/passes/buffer_pool_memory_pass.cc +++ b/ge/graph/passes/buffer_pool_memory_pass.cc @@ -18,7 +18,7 @@ #include #include -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/op_desc_utils.h" diff --git a/ge/graph/passes/buffer_pool_memory_pass.h b/ge/graph/passes/buffer_pool_memory_pass.h index e3d1c159..89fc5363 100644 --- a/ge/graph/passes/buffer_pool_memory_pass.h +++ b/ge/graph/passes/buffer_pool_memory_pass.h @@ -18,7 +18,7 @@ #define GE_GRAPH_PASSES_BUFFER_POOL_MEMORY_PASS_H_ #include -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/cast_remove_pass.cc b/ge/graph/passes/cast_remove_pass.cc index 564b311d..1e1f4eb4 100644 --- a/ge/graph/passes/cast_remove_pass.cc +++ b/ge/graph/passes/cast_remove_pass.cc @@ -18,7 +18,7 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/cast_translate_pass.cc b/ge/graph/passes/cast_translate_pass.cc index d49424c8..704faeda 100644 --- a/ge/graph/passes/cast_translate_pass.cc +++ b/ge/graph/passes/cast_translate_pass.cc @@ -22,7 +22,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" #include "graph/utils/node_utils.h" diff --git a/ge/graph/passes/common_subexpression_elimination_pass.cc b/ge/graph/passes/common_subexpression_elimination_pass.cc index 852ed98a..c41a5cf5 100644 --- a/ge/graph/passes/common_subexpression_elimination_pass.cc +++ b/ge/graph/passes/common_subexpression_elimination_pass.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "common_subexpression_elimination_pass.h" +#include "graph/passes/common_subexpression_elimination_pass.h" #include #include diff --git a/ge/graph/passes/common_subexpression_elimination_pass.h b/ge/graph/passes/common_subexpression_elimination_pass.h index 83bfbace..b182f8b9 100644 --- a/ge/graph/passes/common_subexpression_elimination_pass.h +++ b/ge/graph/passes/common_subexpression_elimination_pass.h @@ -16,7 +16,7 @@ #ifndef GE_COMMON_SUBEXPRESSION_ELIMINATION_H_ #define GE_COMMON_SUBEXPRESSION_ELIMINATION_H_ -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/compile_nodes_pass.cc b/ge/graph/passes/compile_nodes_pass.cc index d0dcec16..c5976f11 100755 --- a/ge/graph/passes/compile_nodes_pass.cc +++ b/ge/graph/passes/compile_nodes_pass.cc @@ -19,10 +19,10 @@ #include #include "common/ge/ge_util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/op_desc.h" using domi::ImplyType; diff --git a/ge/graph/passes/cond_pass.cc b/ge/graph/passes/cond_pass.cc index 116e4f89..47a75cd8 100644 --- a/ge/graph/passes/cond_pass.cc +++ b/ge/graph/passes/cond_pass.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "graph/passes/cond_pass.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" #include "graph/utils/node_utils.h" diff --git a/ge/graph/passes/cond_remove_pass.cc b/ge/graph/passes/cond_remove_pass.cc index 478858a9..91e44458 100644 --- a/ge/graph/passes/cond_remove_pass.cc +++ b/ge/graph/passes/cond_remove_pass.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "graph/passes/cond_remove_pass.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/constant_folding_pass.cc b/ge/graph/passes/constant_folding_pass.cc index 25fe26da..53b14fd5 100644 --- a/ge/graph/passes/constant_folding_pass.cc +++ b/ge/graph/passes/constant_folding_pass.cc @@ -17,20 +17,26 @@ #include "graph/passes/constant_folding_pass.h" #include -#include "graph/operator_factory.h" +#include "external/graph/operator_factory.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" +#include "ge_local_engine/engine/host_cpu_engine.h" #include "init/gelib.h" namespace ge { const int64_t kStartCallNum = 1; const std::string kKernelLibName = "aicpu_tf_kernel"; -// tf_kernel.json opsFlag config const std::string kOpsFlagClose = "0"; -Status RunOpKernelWithCheck(NodePtr &node, - const vector &inputs, - std::vector &outputs) { +const map> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { + return statistic_of_ge_constant_folding_; +} +const map> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { + return statistic_of_op_constant_folding_; +} + +Status ConstantFoldingPass::RunOpKernelWithCheck(NodePtr &node, const vector &inputs, + std::vector &outputs) { std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized or is finalized."); @@ -47,15 +53,13 @@ Status RunOpKernelWithCheck(NodePtr &node, if (ops_flag == kOpsFlagClose) { return UNSUPPORTED; } - return FoldingPass::RunOpKernel(node, inputs, outputs); + return RunOpKernel(node, inputs, outputs); } -const map> &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { - return statistic_of_ge_constant_folding_; -} - -const map> &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { - return statistic_of_op_constant_folding_; +Status ConstantFoldingPass::RunOpKernel(NodePtr &node, + const vector &inputs, + std::vector &outputs) { + return HostCpuEngine::GetInstance().Run(node, inputs, outputs); } Status ConstantFoldingPass::Run(ge::NodePtr &node) { diff --git a/ge/graph/passes/constant_folding_pass.h b/ge/graph/passes/constant_folding_pass.h index 703e6edd..7de48a17 100644 --- a/ge/graph/passes/constant_folding_pass.h +++ b/ge/graph/passes/constant_folding_pass.h @@ -28,6 +28,11 @@ class ConstantFoldingPass : public FoldingPass { Status Run(ge::NodePtr &node) override; const std::map> &GetGeConstantFoldingPerfStatistic() const; const std::map> &GetOpConstantFoldingPerfStatistic() const; + + static Status RunOpKernel(NodePtr &node, const vector &inputs, vector &outputs); + static Status RunOpKernelWithCheck(NodePtr &node, const vector &inputs, + std::vector &outputs); + private: std::map> statistic_of_op_constant_folding_; std::map> statistic_of_ge_constant_folding_; diff --git a/ge/graph/passes/constant_fuse_same_pass.h b/ge/graph/passes/constant_fuse_same_pass.h index 3ff2d6b7..a7326c32 100755 --- a/ge/graph/passes/constant_fuse_same_pass.h +++ b/ge/graph/passes/constant_fuse_same_pass.h @@ -22,7 +22,7 @@ #include #include #include "graph/aligned_ptr.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/control_trigger_pass.cc b/ge/graph/passes/control_trigger_pass.cc index 
85505dc5..d81edefd 100644 --- a/ge/graph/passes/control_trigger_pass.cc +++ b/ge/graph/passes/control_trigger_pass.cc @@ -17,7 +17,7 @@ #include "graph/passes/control_trigger_pass.h" #include #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/type_utils.h" namespace ge { diff --git a/ge/graph/passes/data_pass.h b/ge/graph/passes/data_pass.h index 519ae046..6f841139 100644 --- a/ge/graph/passes/data_pass.h +++ b/ge/graph/passes/data_pass.h @@ -17,7 +17,7 @@ #ifndef GE_GRAPH_PASSES_DATA_PASS_H_ #define GE_GRAPH_PASSES_DATA_PASS_H_ -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/dimension_adjust_pass.h b/ge/graph/passes/dimension_adjust_pass.h index 7766f140..cba283ed 100755 --- a/ge/graph/passes/dimension_adjust_pass.h +++ b/ge/graph/passes/dimension_adjust_pass.h @@ -17,11 +17,11 @@ #ifndef GE_GRAPH_PASSES_DIMENSION_ADJUST_PASS_H_ #define GE_GRAPH_PASSES_DIMENSION_ADJUST_PASS_H_ -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "graph/common/omg_util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "common/omg_util.h" #include "graph/passes/base_pass.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/dimension_compute_pass.cc b/ge/graph/passes/dimension_compute_pass.cc index 350faf71..a24a6bd4 100755 --- a/ge/graph/passes/dimension_compute_pass.cc +++ b/ge/graph/passes/dimension_compute_pass.cc @@ -20,7 +20,7 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/attr_utils.h" diff --git a/ge/graph/passes/end_of_sequence_add_control_pass.h b/ge/graph/passes/end_of_sequence_add_control_pass.h index dcc65848..32ee0b25 100644 --- a/ge/graph/passes/end_of_sequence_add_control_pass.h +++ b/ge/graph/passes/end_of_sequence_add_control_pass.h @@ -17,7 +17,7 @@ #ifndef GE_GRAPH_PASSES_END_OF_SEQUENCE_ADD_CONTROL_EDGE_PASS_H_ #define GE_GRAPH_PASSES_END_OF_SEQUENCE_ADD_CONTROL_EDGE_PASS_H_ -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/flow_ctrl_pass.cc b/ge/graph/passes/flow_ctrl_pass.cc index 87896dc3..e75a4592 100755 --- a/ge/graph/passes/flow_ctrl_pass.cc +++ b/ge/graph/passes/flow_ctrl_pass.cc @@ -22,7 +22,7 @@ #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "common/ge/ge_util.h" #include "graph/manager/graph_var_manager.h" #include "graph/passes/pass_utils.h" diff --git a/ge/graph/passes/flow_ctrl_pass.h b/ge/graph/passes/flow_ctrl_pass.h index 74f3cce0..cf1af97a 100755 --- a/ge/graph/passes/flow_ctrl_pass.h +++ b/ge/graph/passes/flow_ctrl_pass.h @@ -20,7 +20,7 @@ #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc index c0a0f2a2..819c3b40 100755 --- a/ge/graph/passes/folding_pass.cc +++ b/ge/graph/passes/folding_pass.cc @@ -28,8 +28,6 @@ #include "inc/kernel.h" #include 
"inc/kernel_factory.h" #include "graph/debug/ge_attr_define.h" -#include "ge_local_engine/engine/host_cpu_engine.h" - namespace ge { namespace folding_pass { @@ -123,12 +121,6 @@ NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tens } } // namespace -Status FoldingPass::RunOpKernel(NodePtr &node, - const vector &inputs, - std::vector &outputs) { - return HostCpuEngine::GetInstance().Run(node, inputs, outputs); -} - Status FoldingPass::Folding(NodePtr &node, vector &outputs) { GE_CHECK_NOTNULL(node); GELOGD("begin folding node:%s", node->GetName().c_str()); diff --git a/ge/graph/passes/folding_pass.h b/ge/graph/passes/folding_pass.h index 745cffd7..c461ff5c 100755 --- a/ge/graph/passes/folding_pass.h +++ b/ge/graph/passes/folding_pass.h @@ -34,8 +34,6 @@ bool IsNoNeedConstantFolding(const NodePtr &node); using IndexsToAnchors = std::map>; class FoldingPass : public BaseNodePass { - public: - static Status RunOpKernel(NodePtr &node, const vector &inputs, vector &outputs); protected: Status Folding(NodePtr &node, vector &outputs); private: diff --git a/ge/graph/passes/for_pass.cc b/ge/graph/passes/for_pass.cc index 7d09f370..260e6ea0 100644 --- a/ge/graph/passes/for_pass.cc +++ b/ge/graph/passes/for_pass.cc @@ -16,7 +16,7 @@ #include "graph/passes/for_pass.h" #include "common/ge/ge_util.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/ge_inner_error_codes.h" diff --git a/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc index ec7b2388..280afb6f 100644 --- a/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc +++ b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc @@ -21,7 +21,7 @@ #include #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" #include "graph/utils/node_utils.h" diff --git a/ge/graph/passes/fuse_data_nodes_with_common_input_pass.h b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.h index 9ff6ab89..33543ded 100755 --- a/ge/graph/passes/fuse_data_nodes_with_common_input_pass.h +++ b/ge/graph/passes/fuse_data_nodes_with_common_input_pass.h @@ -20,7 +20,7 @@ #include #include #include -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/get_original_format_pass.cc b/ge/graph/passes/get_original_format_pass.cc index 670cd50c..0da4c5cc 100644 --- a/ge/graph/passes/get_original_format_pass.cc +++ b/ge/graph/passes/get_original_format_pass.cc @@ -18,14 +18,14 @@ #include -#include "common/debug/log.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/debug/log.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/omg/omg_inner_types.h" #include "graph/utils/attr_utils.h" #include "graph/utils/op_desc_utils.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" using domi::DOMI_TENSOR_NCHW; using domi::DOMI_TENSOR_NHWC; diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc index f27641fc..ada4e12a 100755 --- a/ge/graph/passes/global_step_insert_pass.cc +++ b/ge/graph/passes/global_step_insert_pass.cc @@ -24,14 +24,9 @@ #include 
"framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "common/ge/ge_util.h" -#include "graph/manager/graph_var_manager.h" #include "graph/passes/pass_utils.h" #include "graph/ge_context.h" -namespace { -const char *const kFlagOff = "0"; -} // namespace - namespace ge { NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, @@ -80,13 +75,6 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, } Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { - // run_flag off means offline, no need insert global step node which type is variable - std::string run_flag; - if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) { - GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), - compute_graph->GetName().c_str()); - return SUCCESS; - } NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); if (output_node == nullptr) { GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); diff --git a/ge/graph/passes/global_step_insert_pass.h b/ge/graph/passes/global_step_insert_pass.h index da83e93a..16be3d4a 100755 --- a/ge/graph/passes/global_step_insert_pass.h +++ b/ge/graph/passes/global_step_insert_pass.h @@ -20,7 +20,7 @@ #include #include -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/guarantee_const_pass.cc b/ge/graph/passes/guarantee_const_pass.cc index 1d369f38..06bc821c 100644 --- a/ge/graph/passes/guarantee_const_pass.cc +++ b/ge/graph/passes/guarantee_const_pass.cc @@ -19,9 +19,9 @@ #include #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "graph/common/omg_util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "common/omg_util.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/hccl_continuous_memcpy_pass.cc b/ge/graph/passes/hccl_continuous_memcpy_pass.cc index 56cbb005..7f4597b3 100644 --- a/ge/graph/passes/hccl_continuous_memcpy_pass.cc +++ b/ge/graph/passes/hccl_continuous_memcpy_pass.cc @@ -18,9 +18,9 @@ #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "framework/common/types.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/hccl_continuous_memcpy_pass.h b/ge/graph/passes/hccl_continuous_memcpy_pass.h index 5fbb6fd0..d710531d 100644 --- a/ge/graph/passes/hccl_continuous_memcpy_pass.h +++ b/ge/graph/passes/hccl_continuous_memcpy_pass.h @@ -20,7 +20,7 @@ #include #include -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/hccl_group_pass.cc b/ge/graph/passes/hccl_group_pass.cc index bbfd9b56..35baade6 100644 --- a/ge/graph/passes/hccl_group_pass.cc +++ b/ge/graph/passes/hccl_group_pass.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "hccl_group_pass.h" +#include "graph/passes/hccl_group_pass.h" #include #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/graph/passes/hccl_memcpy_pass.cc b/ge/graph/passes/hccl_memcpy_pass.cc index d56ee342..2d83bf51 100755 --- a/ge/graph/passes/hccl_memcpy_pass.cc +++ b/ge/graph/passes/hccl_memcpy_pass.cc @@ -18,9 +18,9 @@ #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "framework/common/types.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/hccl_memcpy_pass.h b/ge/graph/passes/hccl_memcpy_pass.h index b75b27d1..e6ee519b 100755 --- a/ge/graph/passes/hccl_memcpy_pass.h +++ b/ge/graph/passes/hccl_memcpy_pass.h @@ -20,7 +20,7 @@ #include #include -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/hccl_tailing_optimization_pass.cc b/ge/graph/passes/hccl_tailing_optimization_pass.cc index e1e2f276..fe606067 100644 --- a/ge/graph/passes/hccl_tailing_optimization_pass.cc +++ b/ge/graph/passes/hccl_tailing_optimization_pass.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hccl_tailing_optimization_pass.h" -#include "graph/common/transop_util.h" +#include "graph/passes/hccl_tailing_optimization_pass.h" +#include "common/transop_util.h" namespace ge { Status HcclTailingOptimizationPass::Run(ComputeGraphPtr graph) { diff --git a/ge/graph/passes/identity_pass.cc b/ge/graph/passes/identity_pass.cc index f0653983..0a346bb1 100755 --- a/ge/graph/passes/identity_pass.cc +++ b/ge/graph/passes/identity_pass.cc @@ -19,7 +19,7 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/node_utils.h" #include "graph/utils/attr_utils.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/graph/passes/infer_base_pass.cc b/ge/graph/passes/infer_base_pass.cc new file mode 100644 index 00000000..636cf2ab --- /dev/null +++ b/ge/graph/passes/infer_base_pass.cc @@ -0,0 +1,388 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "infer_base_pass.h" +#include "common/ge/ge_util.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" + +namespace ge { +namespace { +graphStatus FindValidSubgraphNetoutput(const ConstNodePtr &node, const ComputeGraphPtr &sub_graph, NodePtr &netoutput) { + auto sub_nodes = sub_graph->GetDirectNode(); + for (size_t i = sub_nodes.size(); i > 0; --i) { + auto sub_node = sub_nodes.at(i - 1); + if (sub_node->GetType() == NETOUTPUT) { + if (sub_node == nullptr) { + REPORT_INNER_ERROR("E19999", "NetOutput node is null in subgraph %s, parent node %s.", + sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] NetOutput node is null on sub graph %s, parent node %s", + sub_graph->GetName().c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + auto sub_node_opdesc = sub_node->GetOpDesc(); + if (sub_node_opdesc == nullptr) { + REPORT_INNER_ERROR("E19999", "Invalid NetOutput node in subgraph %s, parent node %s, no OpDesc on it", + sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", + sub_graph->GetName().c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + + netoutput = sub_node; + return GRAPH_SUCCESS; + } + } + + REPORT_INNER_ERROR("E19999", "Can not find the NetOutput node in subgraph %s, parent node %s", + sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] Can not find the NetOutput node in subgraph %s, parent node %s", + sub_graph->GetName().c_str(), node->GetName().c_str()); + return GRAPH_FAILED; +} +} // namespace + +Status InferBasePass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + bool need_infer = NeedInfer(node); + if (!need_infer) { + GELOGD("Node %s does not need to infer.", node->GetName().c_str()); + return SUCCESS; + } + + std::set changed_nodes; + auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Infer and update for node %s failed! ret: %u", node->GetName().c_str(), ret); + return GRAPH_FAILED; + } + + AddChangedNodesImmediateRepass(changed_nodes); + return SUCCESS; +} + +bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } +void InferBasePass::AddChangedNodesImmediateRepass(const std::set &changed_nodes) { +// need passed_nodes set to solve the problem that multi-input operators do repass in advance. +// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. + for (const auto &node_ele : changed_nodes) { + AddImmediateRePassNode(node_ele); + } +} + +graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set &changed_nodes) { + graphStatus ret; + if (ContainsSubgraph(node)) { + if (before_subgraph) { + ret = UpdateTensorDescToSubgraphData(node); + } else { + ret = UpdateTensorDescToParentNodeOutput(node); + } + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Update tensor desc failed between parent node %s and subgraphs. 
ret: %u", node->GetName().c_str(), + ret); + return ret; + } + } + + PrintInOutTensors(node, "before_infer"); + ret = Infer(node); + PrintInOutTensors(node, "after_infer"); + if (ret == GRAPH_NODE_NEED_REPASS) { + // if a node need re_pass, it is not necessary to update peer node input. + changed_nodes.insert(node); + return GRAPH_SUCCESS; + } else if (ret != GRAPH_SUCCESS && ret != GRAPH_NOT_CHANGED) { + GELOGE(ret, "Infer failed for node %s, ret: %u", node->GetName().c_str(), ret); + return ret; + } + + ret = UpdateTensorDescToPeerInputs(node, changed_nodes); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Node %s updates tensor desc to peer input nodes failed! ret: %u", node->GetName().c_str(), ret); + } + GELOGD("Node %s infer and update succeeded .", node->GetName().c_str()); + return ret; +} + +bool InferBasePass::ContainsSubgraph(const NodePtr &node) { + auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); + return !sub_graph_names.empty(); +} + +graphStatus InferBasePass::UpdateTensorDescToPeerInputs(NodePtr &node, std::set &changed_nodes) { + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + auto peer_anchor_opdesc = peer_anchor->GetOwnerNode()->GetOpDesc(); + if (peer_anchor_opdesc == nullptr) { + continue; + } + auto peer_input_desc = peer_anchor_opdesc->MutableInputDesc(peer_anchor->GetIdx()); + if (peer_input_desc == nullptr) { + continue; + } + + bool changed = false; + auto ret = UpdateTensorDesc(output_tensor, peer_input_desc, changed); + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update peer input desc failed, node %s.", node->GetName().c_str()); + GELOGE(ret, "Update peer input desc failed, node %s.", node->GetName().c_str()); + return ret; + } + if (changed) { + changed_nodes.insert(peer_anchor->GetOwnerNode()); + GELOGD("Node %s update peer node succeeded, peer node %s is changed.", node->GetName().c_str(), + peer_anchor->GetOwnerNode()->GetName().c_str()); + } + } + } + return GRAPH_SUCCESS; +} + +std::vector InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) { + std::vector cur_node_subgraph; + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return cur_node_subgraph; + } + + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } + auto sub_graph = root_graph->GetSubgraph(name); + if (sub_graph == nullptr) { + GELOGW("The subgrpah %s for node %s is null.", name.c_str(), node->GetName().c_str()); + continue; + } + cur_node_subgraph.emplace_back(sub_graph); + } + return cur_node_subgraph; +} + +graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node) { + auto op_desc = node->GetOpDesc(); + for (const auto &sub_graph : GetCurNodeSubgraphs(node)) { + for (const auto &node_sub : sub_graph->GetDirectNode()) { + if (node_sub->GetType() != DATA) { + continue; + } + + auto data_opdesc = node_sub->GetOpDesc(); + if (data_opdesc == nullptr) { + REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", + sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] Invalid data node on the sub graph %s 
+graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node) { + auto op_desc = node->GetOpDesc(); + for (const auto &sub_graph : GetCurNodeSubgraphs(node)) { + for (const auto &node_sub : sub_graph->GetDirectNode()) { + if (node_sub->GetType() != DATA) { + continue; + } + + auto data_opdesc = node_sub->GetOpDesc(); + if (data_opdesc == nullptr) { + REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", + sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", + sub_graph->GetName().c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + int ref_i; + if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute", + sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", + sub_graph->GetName().c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + GELOGD("Subgraph Data node ref_index is %d, parent node is %s.", ref_i, node->GetName().c_str()); + + // In multi-batch scenes, the data shape of each subgraph is different, no need to refresh. + if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + GELOGD("While updating subgraph data node, ignore node %s which is created by multi-dims", + data_opdesc->GetName().c_str()); + continue; + } + auto input_desc = op_desc->MutableInputDesc(ref_i); + if (input_desc == nullptr) { + REPORT_INNER_ERROR("E19999", + "The ref index(%d) on the data %s on the sub graph %s " + "parent node %s is incompatible, inputs num %u", + ref_i, node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str(), + node->GetAllInDataAnchorsSize()); + GELOGE(GRAPH_FAILED, + "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s " + "parent node %s is incompatible, inputs num %u", + ref_i, node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str(), + node->GetAllInDataAnchorsSize()); + return GRAPH_FAILED; + } + GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), + node->GetName().c_str()); + + bool has_tensor_desc_changed = false; + auto data_input_td = data_opdesc->MutableInputDesc(0); + auto ret = UpdateTensorDesc(input_desc, data_input_td, has_tensor_desc_changed); + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", + node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str()); + return ret; + } + + auto data_output_td = data_opdesc->MutableOutputDesc(0); + ret = UpdateTensorDesc(input_desc, data_output_td, has_tensor_desc_changed); + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed", + node_sub->GetName().c_str(), sub_graph->GetName().c_str(), node->GetName().c_str()); + return ret; + } + GELOGD("Parent node %s update subgraph data %s input and output succeeded.", node->GetName().c_str(), + data_opdesc->GetName().c_str()); + } + } + return GRAPH_SUCCESS; +}
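+ +// After the subgraph passes ran, gather the input tensor descs of every subgraph NetOutput, grouped by the +// ATTR_NAME_PARENT_NODE_INDEX recorded on them, so each parent output desc can be merged from the matching +// outputs of all subgraph instances.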
netoutput->GetAllInDataAnchors()) { + auto netoutput_in_desc = netoutput_opdesc->MutableInputDesc(netoutput_in_anchor->GetIdx()); + if (netoutput_in_desc == nullptr) { + REPORT_INNER_ERROR("E19999", + "Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", + sub_graph->GetName().c_str(), node->GetName().c_str(), netoutput_in_anchor->GetIdx()); + GELOGE(GRAPH_FAILED, + "[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", + sub_graph->GetName().c_str(), node->GetName().c_str(), netoutput_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", netoutput_in_anchor->GetIdx(), + netoutput_in_desc->GetShape().GetDimNum()); + int ref_i; + if (!AttrUtils::GetInt(netoutput_in_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. + continue; + } + GELOGI("Parent node index of edge desc is %d", ref_i); + if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) { + REPORT_INNER_ERROR("E19999", + "Invalid ref_index %d of parent node %s, ref_index should be less than %u.", ref_i, + node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); + GELOGE(GRAPH_FAILED, + "[Get][Ref_index] Invalid ref_index %d of parent node %s, ref_index should be less than %u.", ref_i, + node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); + return GRAPH_FAILED; + } + ref_out_tensors[ref_i].emplace_back(netoutput_in_desc); + } + } + + return UpdateParentNodeContainsSubgraphs(node, ref_out_tensors); +} + +graphStatus InferBasePass::UpdateParentNodeContainsSubgraphs( + NodePtr &node, const std::vector<std::vector<GeTensorDescPtr>> &ref_out_tensors) { + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + REPORT_CALL_ERROR("E19999", "Parent node %s ref_index %zu subgraph output tensor list is empty.", + node->GetName().c_str(), i); + GELOGE(GRAPH_FAILED, "[Param][check] Parent node %s ref_index %zu subgraph output tensor list is empty.", + node->GetName().c_str(), i); + return GRAPH_FAILED; + } + auto node_op_desc = node->GetOpDesc(); + auto node_output_td = node_op_desc->MutableOutputDesc(i); + if (node_output_td == nullptr) { + REPORT_CALL_ERROR("E19999", "Node %s output %zu tensor desc is null.", node->GetName().c_str(), i); + GELOGE(GRAPH_FAILED, "[Param][check] Node %s output %zu tensor desc is null.", node->GetName().c_str(), i); + return GRAPH_FAILED; + } + + graphStatus ret; + if (node_op_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { + ret = UpdateOutputFromSubgraphsForMultiDims(ref_out_tensors[i], node_output_td); + } else { + ret = UpdateOutputFromSubgraphs(ref_out_tensors[i], node_output_td); + } + if (ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Node %s update output %zu tensor desc failed. ret: %u", node->GetName().c_str(), i, + ret); + GELOGE(GRAPH_FAILED, "[Param][check] Node %s update output %zu tensor desc failed. 
ret: %u", + node->GetName().c_str(), i, ret); + return ret; + } + GELOGD("Parent node %s successfully updated the output tensors from subgraphs.", node->GetName().c_str()); + } + return GRAPH_SUCCESS; +} + +void InferBasePass::PrintInOutTensors(const NodePtr &node, const std::string &phase) { + if (!IsLogEnable(GE, DLOG_DEBUG)) { + return; + } + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); + GELOGE(GRAPH_FAILED, "[Check][Param] node is null"); + return; + } + ge::OpDescPtr op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "Node has no opdesc, check invalid"); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return ); + std::stringstream ss; + ss << "{"; + int32_t in_idx = 0; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + if (input_desc == nullptr) { + in_idx++; + continue; + } + if (in_idx > 0) { + ss << " "; + } + ss << "input_" << in_idx << " tensor: "; + ss << SerialTensorInfo(input_desc); + in_idx++; + } + int32_t out_idx = 0; + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + if (output_desc == nullptr) { + out_idx++; + continue; + } + ss << " "; + ss << "output_" << out_idx << " tensor: "; + ss << SerialTensorInfo(output_desc); + out_idx++; + } + ss << "}"; + GELOGD("Infer tensor dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str()); +} +} // namespace ge diff --git a/ge/graph/passes/infer_base_pass.h b/ge/graph/passes/infer_base_pass.h new file mode 100644 index 00000000..3900b5db --- /dev/null +++ b/ge/graph/passes/infer_base_pass.h @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GE_GRAPH_PASSES_INFER_BASE_PASS_H_ +#define GE_GRAPH_PASSES_INFER_BASE_PASS_H_ + +#include "graph/passes/base_pass.h" + +namespace ge { +class InferBasePass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + graphStatus InferAndUpdate(NodePtr &node, bool before_subgraph, std::set &changed_nodes); + void PrintInOutTensors(const NodePtr &node, const std::string &phase); + + protected: + virtual std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const = 0; + virtual bool NeedInfer(const NodePtr &node) const; + virtual graphStatus Infer(NodePtr &node) = 0; + + /** + * Update the output TensorDesc by src TensorDesc. This will be called when updating peer node input desc. + * @param src, input TensorDesc + * @param dst, output TensorDesc to be updated + * @return + */ + virtual graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) = 0; + + /** + * Update the output TensorDesc for nodes which contain subgraphs. + * In dynamic multi-dims/batch/images size scene, the update process maybe different, + * in which case, the `InferBasePass` will call method `UpdateOutputFromSubgraphsForMultiDims` instead. 
+ * @param src, input TensorDesc from NetOutput nodes in all subgraphs + * @param dst, output TensorDesc to be updated + * @return + */ + virtual graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, + GeTensorDescPtr &dst) = 0; + virtual graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, + GeTensorDescPtr &dst) = 0; + + private: + void AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes); + bool ContainsSubgraph(const NodePtr &node); + std::vector<ComputeGraphPtr> GetCurNodeSubgraphs(const NodePtr &node); + graphStatus UpdateTensorDescToSubgraphData(NodePtr &node); + graphStatus UpdateTensorDescToParentNodeOutput(NodePtr &node); + graphStatus UpdateParentNodeContainsSubgraphs(NodePtr &node, + const std::vector<std::vector<GeTensorDescPtr>> &ref_out_tensors); + graphStatus UpdateTensorDescToPeerInputs(NodePtr &node, std::set<NodePtr> &changed_nodes); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INFER_BASE_PASS_H_ diff --git a/ge/graph/passes/infer_value_range_pass.cc b/ge/graph/passes/infer_value_range_pass.cc new file mode 100644 index 00000000..c183a599 --- /dev/null +++ b/ge/graph/passes/infer_value_range_pass.cc @@ -0,0 +1,537 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/infer_value_range_pass.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/operator_factory_impl.h" +#include "graph/passes/constant_folding_pass.h" +#include "graph/utils/type_utils.h" +#include "common/ge/ge_util.h" + +using std::unique_ptr; +namespace ge { +namespace { +#define GET_DATA_BY_DTYPE(DTYPE, TYPE) \ + case (DTYPE): \ + ConstructValueRange<TYPE>(lower_boundary_tensor, upper_boundary_tensor, output_tensor_value_range); \ + break; + +void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { + std::vector<std::pair<int64_t, int64_t>> shape_range; + (void)desc->GetShapeRange(shape_range); + desc_str += formats::RangeToString(shape_range); + shape_range.clear(); + (void)desc->GetOriginShapeRange(shape_range); + desc_str += ","; + desc_str += formats::RangeToString(shape_range); + shape_range.clear(); +} + +Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, + std::vector<GeTensorPtr> &outputs) { + // RunOpKernelWithCheck, RunOpKernel for test + auto ret = ConstantFoldingPass::RunOpKernel(node, inputs, outputs); + if (ret != SUCCESS) { + auto op_kernel = folding_pass::GetKernelByType(node); + if (op_kernel == nullptr) { + GELOGW("Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(), + node->GetType().c_str()); + return NOT_CHANGED; + } + + ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs); + if (ret != SUCCESS) { + GELOGW("Calculate value range failed, node %s run cpu kernel failed.", node->GetName().c_str()); + return NOT_CHANGED; + } + } + GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), 
node->GetType().c_str()); + return SUCCESS; +} +} // namespace + +graphStatus InferValueRangePass::Infer(NodePtr &node) { + auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType()); + + // Use registered func to calculate value range + if (!infer_value_range_param.use_cpu_kernel) { + if (infer_value_range_param.infer_value_func == nullptr) { + GELOGW("The registered func of node %s to infer value range is nullptr.", node->GetName().c_str()); + return GRAPH_NOT_CHANGED; + } + Operator op = OpDescUtils::CreateOperatorFromNode(node); + auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op); + if (ret != GRAPH_SUCCESS) { + GELOGW("Node %s call infer value range func failed, ret: %u.", node->GetName().c_str(), ret); + return GRAPH_NOT_CHANGED; + } + GELOGD("Node %s infer value range func succeed by registered func.", node->GetName().c_str()); + return GRAPH_SUCCESS; + } + + // Deal with scenes with unknown value range + bool has_unknown_value_range = false; + bool has_zero_in_value_range = false; + CheckInputValueRange(node, has_unknown_value_range, has_zero_in_value_range); + if (has_unknown_value_range) { + if (has_zero_in_value_range) { + // When there is zero in input value range, it is unreasonable to always set output value range {1:-1}. + GELOGW("Node %s has -1 and 0 in value range, skip setting value range.", node->GetName().c_str()); + return GRAPH_NOT_CHANGED; + } + GELOGI("Node %s has unknown value range in input tensors, set value range {1:-1}, and skip cpu kernel.", + node->GetName().c_str()); + return GenerateWorstValueRange(node); + } + + // Use CPU kernel func to calculate value range + auto ret = ConstructInputAndInferValueRange(node); + if (ret != GRAPH_SUCCESS) { + GELOGW("Use CPU kernel to calculate value range failed. 
node: %s, ret: %u", node->GetName().c_str(), ret); + return GRAPH_NOT_CHANGED; + } + GELOGD("Node %s infer value range func succeed by running cpu kernel.", node->GetName().c_str()); + return GRAPH_SUCCESS; +} + +std::string InferValueRangePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { + std::stringstream ss; + ss << "["; + ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; + string range_str; + SerialShapeRange(tensor_desc, range_str); + ss << "(shape_range:" << range_str << "),"; + std::vector> value_range; + (void)tensor_desc->GetValueRange(value_range); + string value_range_str = formats::RangeToString(value_range); + ss << "(value_range:" << value_range_str << ")]"; + return ss.str(); +} + +bool InferValueRangePass::NeedInfer(const NodePtr &node) const { + auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType()); + if (!infer_value_range_param.is_initialized) { + GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.", + node->GetName().c_str()); + return false; + } + + if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) { + // Only do infer for node that all inputs are dynamic, such as shape + if (InputIsDynamic(node)) { + return true; + } + GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input failed.", + node->GetName().c_str()); + } else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) { + // Only do infer for node that all inputs have value_range or node type of inputs is constant/const + if (InputIsConstOrHasValueRange(node)) { + return true; + } + GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input failed.", + node->GetName().c_str()); + } + GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str()); + return false; +} + +bool InferValueRangePass::InputIsDynamic(const NodePtr &node) const{ + bool input_is_dynamic = false; + auto cur_op_desc = node->GetOpDesc(); + for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { + auto dims = input_desc->GetShape().GetDims(); + for (auto dim : dims) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + input_is_dynamic = true; + break; + } + } + } + return input_is_dynamic; +} + +bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) const { + bool input_is_const_or_has_value_range = true; + auto cur_op_desc = node->GetOpDesc(); + auto in_data_anchors = node->GetAllInDataAnchors(); + for (size_t i = 0; i < in_data_anchors.size(); ++i) { + auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + auto peer_node = peer_out_anchor->GetOwnerNode(); + if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) { + continue; + } + if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { + continue; + } + + const auto &input_desc = cur_op_desc->GetInputDesc(i); + std::vector> value_range; + (void)input_desc.GetValueRange(value_range); + if (value_range.empty()) { + GELOGD("Node %s input %zu does not have value range, skip infer_value_range_pass for current node.", + node->GetName().c_str(), i); + input_is_const_or_has_value_range = false; + break; + } + } + return input_is_const_or_has_value_range; +} + +void InferValueRangePass::CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range, + bool &has_zero_in_value_range) const { + 
has_unknown_value_range = false; + has_zero_in_value_range = false; + auto cur_op_desc = node->GetOpDesc(); + for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) { + std::vector<std::pair<int64_t, int64_t>> input_desc_value_range; + input_desc->GetValueRange(input_desc_value_range); + if (!input_desc_value_range.empty()) { + for (const auto &range : input_desc_value_range) { + if (range.first == 0 || range.second == 0) { + GELOGD("Node %s input tensors have zero in value range %s.", node->GetName().c_str(), + formats::RangeToString(input_desc_value_range).c_str()); + has_zero_in_value_range = true; + } + if (range.first == -1 || range.second == -1) { + GELOGD("Node %s input tensors have unknown value range, value range is %s.", node->GetName().c_str(), + formats::RangeToString(input_desc_value_range).c_str()); + has_unknown_value_range = true; + } + } + } + } +} + +graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { + if (src == nullptr || dst == nullptr) { + REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null."); + GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null."); + return GRAPH_FAILED; + } + + changed = false; + std::vector<std::pair<int64_t, int64_t>> src_value_range; + std::vector<std::pair<int64_t, int64_t>> dst_value_range; + (void)src->GetValueRange(src_value_range); + (void)dst->GetValueRange(dst_value_range); + if (src_value_range != dst_value_range) { + GELOGD("While updating tensor desc, value range has been changed, src value range: %s, dst value range: %s.", + formats::RangeToString(src_value_range).c_str(), formats::RangeToString(dst_value_range).c_str()); + changed = true; + } + + dst->SetValueRange(src_value_range); + return GRAPH_SUCCESS; +} + +graphStatus InferValueRangePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, + GeTensorDescPtr &dst) { + std::vector<std::pair<int64_t, int64_t>> ref_out_tensor_value_range; + auto ref_out_tensor = src.at(0); + (void)ref_out_tensor->GetValueRange(ref_out_tensor_value_range); + for (auto &ref_tensor : src) { + std::vector<std::pair<int64_t, int64_t>> ref_tensor_value_range; + (void)ref_tensor->GetValueRange(ref_tensor_value_range); + + if (ref_tensor_value_range.size() != ref_out_tensor_value_range.size()) { + GELOGD("Update TensorDesc %s failed, rank of value ranges %s and %s are not the same, skip value range refresh.", + dst->GetName().c_str(), formats::RangeToString(ref_out_tensor_value_range).c_str(), + formats::RangeToString(ref_tensor_value_range).c_str()); + return GRAPH_SUCCESS; + } + + for (size_t j = 0; j < ref_out_tensor_value_range.size(); j++) { + if ((ref_out_tensor_value_range.at(j).first != ref_tensor_value_range.at(j).first) || + (ref_out_tensor_value_range.at(j).second != ref_tensor_value_range.at(j).second)) { + ref_out_tensor_value_range[j] = std::make_pair(1, -1); + } + } + } + GELOGD("While updating output desc from subgraphs, set parent node desc value range %s.", + formats::RangeToString(ref_out_tensor_value_range).c_str()); + dst->SetValueRange(ref_out_tensor_value_range); + return GRAPH_SUCCESS; +} + +graphStatus InferValueRangePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, + GeTensorDescPtr &dst) { + REPORT_INNER_ERROR("E19999", + "Update TensorDesc %s failed. In dynamic multi-dims size scene, there should be no value range.", + dst->GetName().c_str()); + GELOGE(GRAPH_FAILED, + "[Update][TensorDesc] %s failed. 
In dynamic multi-dims size scene, there should be no value range.", + dst->GetName().c_str()); + return GRAPH_FAILED; +} + +graphStatus InferValueRangePass::GenerateWorstValueRange(NodePtr &node) { + GELOGI("Node %s does not run cpu kernel, because input value range has -1.", node->GetName().c_str()); + OpDescPtr op_desc = node->GetOpDesc(); + for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { + auto output_desc = op_desc->MutableOutputDesc(i); + if (output_desc == nullptr) { + continue; + } + auto output_i_shape = output_desc->GetShape(); + auto output_i_shape_size = output_i_shape.GetShapeSize(); + if (output_i_shape_size < 0) { + GELOGD("Node %s output shape is unknown, cannot infer value range, shape is %s.", node->GetName().c_str(), + formats::ShapeToString(output_i_shape).c_str()); + return GRAPH_NOT_CHANGED; + } + + std::vector<std::pair<int64_t, int64_t>> output_i_value_range(output_i_shape_size, {1, -1}); + if (output_i_shape.IsScalar()) { + output_i_value_range.emplace_back(1, -1); + } + output_desc->SetValueRange(output_i_value_range); + GELOGD("Node %s output %zu shape is %s, the generated worst value range is %s.", node->GetName().c_str(), i, + formats::ShapeToString(output_i_shape).c_str(), formats::RangeToString(output_i_value_range).c_str()); + } + return GRAPH_SUCCESS; +} + +template <typename T> +graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, + GeTensorPtr &output_ptr) { + std::vector<std::pair<int64_t, int64_t>> value_range; + (void)tensor_desc.GetValueRange(value_range); + size_t value_range_data_num = value_range.size(); + auto tensor_shape = tensor_desc.GetShape(); + bool value_range_and_tensor_shape_matched = true; + if (tensor_shape.IsScalar()) { + // scalar tensor has only one value_range pair + if (value_range_data_num != 1) { + value_range_and_tensor_shape_matched = false; + } + } else { + // normal tensor, value_range size is equal to tensor shape size. + if (static_cast<int64_t>(value_range_data_num) != tensor_shape.GetShapeSize()) { + value_range_and_tensor_shape_matched = false; + } + } + if (!value_range_and_tensor_shape_matched) { + GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.", + tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str()); + return GRAPH_PARAM_INVALID; + } + + unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]()); + if (buf == nullptr) { + REPORT_INNER_ERROR("E19999", "New buf failed"); + GELOGE(MEMALLOC_FAILED, "New buf failed"); + return GRAPH_FAILED; + } + for (size_t j = 0; j < value_range_data_num; ++j) { + auto value_range_j = use_floor_value ? 
value_range[j].first : value_range[j].second; + buf[j] = static_cast<T>(value_range_j); + } + + if (output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) { + GELOGW("Set data failed while constructing value range input tensor."); + return GRAPH_NOT_CHANGED; + } + return GRAPH_SUCCESS; +} + +graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, + GeTensorPtr &output_ptr) { + graphStatus ret = GRAPH_SUCCESS; + auto data_type = tensor_desc.GetDataType(); + output_ptr->MutableTensorDesc().SetDataType(data_type); + switch (data_type) { + case DT_FLOAT: + ret = ConstructData<float>(tensor_desc, use_floor_value, output_ptr); + break; + case DT_DOUBLE: + ret = ConstructData<double>(tensor_desc, use_floor_value, output_ptr); + break; + case DT_UINT8: + ret = ConstructData<uint8_t>(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT8: + ret = ConstructData<int8_t>(tensor_desc, use_floor_value, output_ptr); + break; + case DT_UINT16: + ret = ConstructData<uint16_t>(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT16: + ret = ConstructData<int16_t>(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT32: + ret = ConstructData<int32_t>(tensor_desc, use_floor_value, output_ptr); + break; + case DT_INT64: + ret = ConstructData<int64_t>(tensor_desc, use_floor_value, output_ptr); + break; + default: + GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); + ret = GRAPH_PARAM_INVALID; + } + return ret; +} + +vector<ConstGeTensorPtr> InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) { + vector<ConstGeTensorPtr> input_tensors; + auto cur_op_desc = node->GetOpDesc(); + auto in_data_anchors = node->GetAllInDataAnchors(); + for (size_t i = 0; i < in_data_anchors.size(); ++i) { + auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + auto peer_node = peer_out_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; + } + + // construct input tensor by constant node + if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { + vector<GeTensorPtr> const_weight = OpDescUtils::MutableWeights(peer_node); + if (const_weight.empty()) { + GELOGW("MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(), + peer_node->GetType().c_str()); + return vector<ConstGeTensorPtr>(); + } + // const/constant op has only one weight + if (const_weight.at(0) == nullptr) { + GELOGW("MutableWeights failed, weight of constant is null, node name: %s(%s)", + peer_node->GetName().c_str(), peer_node->GetType().c_str()); + return vector<ConstGeTensorPtr>(); + } + input_tensors.push_back(const_weight.at(0)); + GELOGD("Node %s construct input tensor %zu by constant node.", node->GetName().c_str(), input_tensors.size()); + continue; + } + + // construct input tensor by boundary of value range + const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i); + GeTensorPtr tmp_tensor_ptr = MakeShared<GeTensor>(input_tensor_desc); + if (tmp_tensor_ptr == nullptr) { + REPORT_INNER_ERROR("E19999", "Make shared failed"); + GELOGE(MEMALLOC_FAILED, "Make shared failed"); + return vector<ConstGeTensorPtr>(); + } + + auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr); + if (ret != GRAPH_SUCCESS) { + GELOGW("Construct input tensor by boundary of value range failed for input %s.", + input_tensor_desc.GetName().c_str()); + return vector<ConstGeTensorPtr>(); + } + input_tensors.push_back(tmp_tensor_ptr); + GELOGD("Node %s construct input tensor %zu by input desc value range.", 
node->GetName().c_str(), + input_tensors.size()); + } + + return input_tensors; +} + +graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) { + auto inputs = ConstructInputTensors(node, true); + if (inputs.empty()) { + return GRAPH_PARAM_INVALID; + } + vector<GeTensorPtr> lower_boundary_outputs; + auto ret = RunCpuKernelForValueRange(node, inputs, lower_boundary_outputs); + if (ret != SUCCESS) { + GELOGW("Node %s run cpu kernel failed while calculating value range.", node->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + + inputs = ConstructInputTensors(node, false); + if (inputs.empty()) { + return GRAPH_PARAM_INVALID; + } + vector<GeTensorPtr> upper_boundary_outputs; + ret = RunCpuKernelForValueRange(node, inputs, upper_boundary_outputs); + if (ret != SUCCESS) { + GELOGW("Node %s run cpu kernel failed while calculating value range.", node->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + + // construct value range from output tensor + OpDescPtr node_desc = node->GetOpDesc(); + std::vector<std::pair<int64_t, int64_t>> output_tensor_value_range; + size_t node_output_desc_size = node_desc->GetOutputsSize(); + for (size_t i = 0; i < node_output_desc_size; ++i) { + output_tensor_value_range.clear(); + auto output_tensor_desc = node_desc->MutableOutputDesc(i); + auto output_shape_size = output_tensor_desc->GetShape().GetShapeSize(); + auto lower_boundary_tensor = lower_boundary_outputs[i]; + auto lower_boundary_shape = lower_boundary_tensor->GetTensorDesc().GetShape(); + auto upper_boundary_tensor = upper_boundary_outputs[i]; + auto upper_boundary_shape = upper_boundary_tensor->GetTensorDesc().GetShape(); + if (lower_boundary_shape.GetShapeSize() != output_shape_size || + upper_boundary_shape.GetShapeSize() != output_shape_size) { + GELOGD( + "Cpu kernel result shapes %s, %s and output shape %s do not match, can not infer value range for output %s.", + formats::ShapeToString(lower_boundary_shape).c_str(), formats::ShapeToString(upper_boundary_shape).c_str(), + formats::ShapeToString(output_tensor_desc->GetShape()).c_str(), output_tensor_desc->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + + auto data_type = output_tensor_desc->GetDataType(); + switch (data_type) { + GET_DATA_BY_DTYPE(DT_INT8, int8_t) + GET_DATA_BY_DTYPE(DT_INT16, int16_t) + GET_DATA_BY_DTYPE(DT_INT32, int32_t) + GET_DATA_BY_DTYPE(DT_INT64, int64_t) + GET_DATA_BY_DTYPE(DT_UINT8, uint8_t) + GET_DATA_BY_DTYPE(DT_UINT16, uint16_t) + GET_DATA_BY_DTYPE(DT_UINT32, uint32_t) + GET_DATA_BY_DTYPE(DT_UINT64, uint64_t) + GET_DATA_BY_DTYPE(DT_FLOAT, float) + GET_DATA_BY_DTYPE(DT_DOUBLE, double) + default: + GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); + return GRAPH_PARAM_INVALID; + } + output_tensor_desc->SetValueRange(output_tensor_value_range); + GELOGD("Node %s calculates output %zu value range %s by running cpu kernel.", node->GetName().c_str(), i, + formats::RangeToString(output_tensor_value_range).c_str()); + } + return GRAPH_SUCCESS; +} + +template <typename T> +void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, + std::vector<std::pair<int64_t, int64_t>> &value_range) { + auto x = reinterpret_cast<const T *>(left_tensor->GetData().GetData()); + auto y = reinterpret_cast<const T *>(right_tensor->GetData().GetData()); + if (x == nullptr || y == nullptr) { + GELOGI("Output tensor of cpu kernel does not have data, no way to set value range."); + return; + } + auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape(); + for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) { + 
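// The kernel run on floor-value inputs supplies the j-th lower bound and the run on ceiling-value inputs the j-th upper bound; the two outputs are paired element-wise. + 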
auto left = static_cast<int64_t>(*(x + j)); + auto right = static_cast<int64_t>(*(y + j)); + value_range.emplace_back(left, right); + } + + if (left_tensor_shape.IsScalar()) { + GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors."); + value_range.emplace_back(static_cast<int64_t>(*x), static_cast<int64_t>(*y)); + } +} +} // namespace ge diff --git a/ge/graph/passes/infer_value_range_pass.h b/ge/graph/passes/infer_value_range_pass.h new file mode 100644 index 00000000..503b5a9f --- /dev/null +++ b/ge/graph/passes/infer_value_range_pass.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ +#define GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ + +#include "graph/passes/infer_base_pass.h" + +namespace ge { +class InferValueRangePass : public InferBasePass { + public: + graphStatus Infer(NodePtr &node) override; + + private: + std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; + graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override; + graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, + GeTensorDescPtr &dst) override; + bool NeedInfer(const NodePtr &node) const override; + + bool InputIsDynamic(const NodePtr &node) const; + bool InputIsConstOrHasValueRange(const NodePtr &node) const; + void CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range, bool &has_zero_in_value_range) const; + graphStatus GenerateWorstValueRange(NodePtr &node); + template <typename T> + graphStatus ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); + graphStatus ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr); + vector<ConstGeTensorPtr> ConstructInputTensors(const NodePtr &node, bool use_floor_value); + template <typename T> + void ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, + std::vector<std::pair<int64_t, int64_t>> &value_range); + graphStatus ConstructInputAndInferValueRange(NodePtr &node); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INFER_VALUE_RANGE_PASS_H_ diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index b74d1c97..0555929d 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -1,175 +1,369 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/infershape_pass.h" -#include "common/util/error_manager/error_manager.h" -#include "framework/common/debug/ge_log.h" -#include "analyzer/analyzer.h" -#include "framework/common/util.h" -#include "graph/shape_refiner.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace ge { - -void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { - desc_str += "["; - std::vector<std::pair<int64_t, int64_t>> shape_range; - (void)desc->GetShapeRange(shape_range); - for (const auto &pair : shape_range) { - desc_str += "{"; - desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); - desc_str += "},"; - } - desc_str += "]"; - shape_range.clear(); - (void)desc->GetOriginShapeRange(shape_range); - for (const auto &pair : shape_range) { - desc_str += ",{"; - desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); - desc_str += "},"; - } -} - -std::string GetInTensorInfoWithString(const ge::NodePtr &node) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - std::stringstream ss; - ss << "{"; - int32_t in_idx = 0; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - if (input_desc == nullptr) { - in_idx++; - continue; - } - if (in_idx > 0) { - ss << " "; - } - ss << "input_" << in_idx << " " << "tensor: ["; - ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; - ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; - ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; - ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; - ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; - ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; - string range_str; - SerialShapeRange(input_desc, range_str); - ss << "(shape_range:" << range_str << ")]"; - in_idx++; - } - return ss.str(); -} - -Status InferShapePass::Run(NodePtr &node) { - // kOptimizeAfterSubGraph exist means after subgraph - auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); - if (ret != GRAPH_SUCCESS) { - // select INFERSHAPE failed info - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto root_graph = ge::GraphUtils::FindRootGraph(graph); - GE_CHECK_NOTNULL(root_graph); - analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), - analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; - (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); - (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), - root_graph->GetGraphID()); - - REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - return GE_GRAPH_INFERSHAPE_FAILED; - } - - GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); - bool need_repass = false; - auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), 
ATTR_NAME_NEED_INFER_AGAIN, need_repass); - if (has_attr) { - if (!OptionExists(kOptimizeAfterSubGraph)) { - return SUCCESS; - } - if (need_repass) { - AddImmediateRePassNode(node); - GELOGD("Node %s need repass immediately.", node->GetName().c_str()); - } else { - // clear attr on while - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - } - } - return SUCCESS; -} - -Status InferShapePass::RePassLoopNode(const NodePtr &node) { - const auto RePassNode = [&](const std::set<std::string> &re_pass_types) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (re_pass_types.count(node_type) > 0) { - AddImmediateRePassNode(n); - (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); - GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); - } - } - return SUCCESS; - }; - - const auto ExProcNode = [&](const std::set<std::string> &proc_types, - const std::function<void(InferShapePass *, NodePtr &)> &proc_func, - const std::string &info) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (proc_types.count(node_type) > 0) { - proc_func(this, n); - GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); - } - } - return SUCCESS; - }; - - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), - "[Get][OriginalType] of node:%s failed.", node->GetName().c_str()); - if (kNextIterationOpTypes.count(node_type) > 0) { - return RePassNode(kMergeOpTypes); // Re-Pass Merge - } - - if (kMergeOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return RePassNode(kSwitchOpTypes); // Re-Pass Switch - } - return SUCCESS; - } - - if (kSwitchOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit - } else { - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit - } - } - - return SUCCESS; -} -} // namespace ge +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/passes/infershape_pass.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "analyzer/analyzer.h" +#include "framework/common/util.h" +#include "graph/shape_refiner.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "common/omg_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" + +#include "external/graph/operator_factory.h" + +namespace ge { +namespace { +constexpr int kSwitchExitAnchorIndex = 0; +constexpr int kSwitchPredAnchorIndex = 1; +void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { + desc_str += "["; + std::vector> shape_range; + (void)desc->GetShapeRange(shape_range); + for (const auto &pair : shape_range) { + desc_str += "{"; + desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); + desc_str += "},"; + } + desc_str += "]"; + shape_range.clear(); + (void)desc->GetOriginShapeRange(shape_range); + for (const auto &pair : shape_range) { + desc_str += ",{"; + desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); + desc_str += "},"; + } +} +void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) { + dst->SetOriginShape(src->GetOriginShape()); + dst->SetShape(src->GetShape()); + dst->SetDataType(src->GetDataType()); + dst->SetOriginDataType(src->GetOriginDataType()); + vector> src_shape_range; + src->GetShapeRange(src_shape_range); + dst->SetShapeRange(src_shape_range); + dst->SetOriginShapeRange(src_shape_range); + ge::TensorUtils::SetRealDimCnt(*dst, static_cast(src->GetShape().GetDims().size())); +} +} // namespace + +std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { + std::stringstream ss; + ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; + ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),"; + ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),"; + ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),"; + ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),"; + ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),"; + string range_str; + SerialShapeRange(tensor_desc, range_str); + ss << "(shape_range:" << range_str << ")"; + return ss.str(); +} +Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) { + if (node->GetType() != SWITCH) { + return SUCCESS; + } + auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex); + GE_CHECK_NOTNULL(pred_node); + if (pred_node->GetType() != LOOPCOND) { + return SUCCESS; + } + + for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) { + GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(), + anchor_2_node.second->GetType().c_str()); + auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()]; + if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) { + suspend_nodes.nodes.push(anchor_2_node.second); + AddNodeSuspend(anchor_2_node.second); + } + } + return SUCCESS; +} + +Status InferShapePass::Infer(NodePtr &node) { + auto ret = InferShapeAndType(node); + if (ret != GRAPH_SUCCESS) { + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + auto 
root_graph = ge::GraphUtils::FindRootGraph(graph); + GE_CHECK_NOTNULL(root_graph); + analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), + analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; + (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); + (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), + root_graph->GetGraphID()); + REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed", node->GetName().c_str(), + node->GetType().c_str()); + GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed", node->GetName().c_str(), + node->GetType().c_str()); + return GE_GRAPH_INFERSHAPE_FAILED; + } + return SUCCESS; +} + +graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { + auto ret = SuspendV1LoopExitNodes(node); + if (ret != SUCCESS) { + GELOGE(ret, "Suspend V1 loop exit nodes failed."); + return ret; + } + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + auto opdesc = node->GetOpDesc(); + if (node->Verify() != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + Operator op = OpDescUtils::CreateOperatorFromNode(node); + + if (!is_unknown_graph) { + auto inference_context = ShapeRefiner::CreateInferenceContext(node); + GE_CHECK_NOTNULL(inference_context); + std::vector<std::string> marks; + inference_context->GetMarks(marks); + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size()); + op.SetInferenceContext(inference_context); + } + + graphStatus status = CallInferShapeFunc(node, op); + if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) { + // nodes like netoutput return param_invalid, but are they valid? + return GE_GRAPH_INFERSHAPE_FAILED; + } + UpdateCurNodeOutputDesc(node); + if (!is_unknown_graph) { + auto ctx_after_infer = op.GetInferenceContext(); + if (ctx_after_infer != nullptr) { + std::vector<std::string> marks; + ctx_after_infer->GetMarks(marks); + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) { + GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), + marks.size()); + ShapeRefiner::PushToContextMap(node, ctx_after_infer); + } + } + } + + return (status == GRAPH_NODE_NEED_REPASS) ? 
GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; +} + +void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) { + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + GE_IF_BOOL_EXEC(output_tensor == nullptr, continue); + GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(), + output_tensor->SetOriginShape(output_tensor->GetShape())); + + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims() + .size())); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + // set output origin shape range + std::vector<std::pair<int64_t, int64_t>> range; + (void)output_tensor->GetShapeRange(range); + output_tensor->SetOriginShapeRange(range); + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + } +} + +bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { + // check shape range + vector<std::pair<int64_t, int64_t>> src_shape_range; + vector<std::pair<int64_t, int64_t>> dst_shape_range; + src->GetShapeRange(src_shape_range); + dst->GetShapeRange(dst_shape_range); + if (src_shape_range.size() != dst_shape_range.size()) { + GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(), + dst_shape_range.size()); + return false; + } + for (size_t i = 0; i < src_shape_range.size(); ++i) { + if (src_shape_range[i].first != dst_shape_range[i].first || + src_shape_range[i].second != dst_shape_range[i].second) { + GELOGI("Current dim %zu. Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.", + i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first, dst_shape_range[i].second); + return false; + } + } + + // check shape + auto src_shape = src->GetShape(); + auto dst_shape = dst->GetShape(); + if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() || + src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) { + GELOGD( + "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; " + "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.", + src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(), + dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); + return false; + } + return true; +} + +graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { + changed = false; + if (SameTensorDesc(src, dst)) { + GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); + return SUCCESS; + } + + changed = true; + UpdateShapeAndDType(src, dst); + GELOGD( + "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." 
+ "To dst Node: shape: [%s], datatype: %s, original datatype is %s.", + src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); + return SUCCESS; +} + +graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { + auto op_desc = node->GetOpDesc(); + const auto &op_type = op_desc->GetType(); + auto ret = op_desc->CallInferFunc(op); + if (ret == GRAPH_PARAM_INVALID) { + // Op ir no infer func, try to get infer func from operator factory + + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); + return ret; + } + + GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + if (temp_op_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("InferShapeAndType UpdateInputName failed"); + for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { + if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { + break; + } + return GRAPH_SUCCESS; + } + } + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("InferShapeAndType UpdateOutputName failed"); + } + op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); + ret = op_desc->CallInferFunc(op); + GELOGI("op CallInferFunc second. ret: %u", ret); + } + return ret; +} + +graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) { + GELOGD("Enter update parent node shape for class branch op process"); + // check sub_graph shape.If not same ,do unknown shape process + auto ref_out_tensor = src.at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape(); + for (auto &tensor : src) { + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s", + ref_out_tensor_shape.ToString().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output"); + return GRAPH_FAILED; + } + auto shape = tensor->MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("Shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; + } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("j: %zu ,shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", j, shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + UpdateShapeAndDType(ref_out_tensor, dst); + return GRAPH_SUCCESS; +} +graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) { + // check sub_graph shape. Get max for update. 
+ if (src.empty()) { + GELOGI("Src subgraph shape is empty."); + return SUCCESS; + } + + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = src.at(0); + for (size_t j = 0; j < src.size(); ++j) { + auto &tensor = src.at(j); + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output"); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output"); + return GRAPH_FAILED; + } + + auto shape = tensor->MutableShape(); + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (dim != 0 && INT64_MAX / dim < size) { + REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str()); + GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; + } + + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + UpdateShapeAndDType(src.at(max_shape_index), dst); + return GRAPH_SUCCESS; +} +Status InferShapePass::OnSuspendNodesLeaked() { + auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); + if (iter == graphs_2_suspend_nodes_.end()) { + GELOGI("Current graph %s no suspend node.", GetCurrentGraphName().c_str()); + return SUCCESS; + } + if (!iter->second.nodes.empty()) { + AddNodeResume(iter->second.PopSuspendedNode()); + } + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/infershape_pass.h b/ge/graph/passes/infershape_pass.h index 9c5d432d..00d90775 100644 --- a/ge/graph/passes/infershape_pass.h +++ b/ge/graph/passes/infershape_pass.h @@ -1,38 +1,56 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ -#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ - -#include "graph/passes/base_pass.h" - -namespace ge { -class InferShapePass : public BaseNodePass { - public: - /// - /// Entry of the InferShapePass optimizer - /// @param [in] graph: Input ComputeGraph - /// @return SUCCESS: Execution succeed - /// @return OTHERS: Execution failed - /// @author - /// - Status Run(ge::NodePtr &node) override; - - private: - Status RePassLoopNode(const NodePtr &node); -}; -} // namespace ge -#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ +#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ + +#include "graph/passes/infer_base_pass.h" +#include <stack> + +namespace ge { +class InferShapePass : public InferBasePass { + public: + std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; + graphStatus Infer(NodePtr &node) override; + + graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override; + graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src, + GeTensorDescPtr &dst) override; + + Status OnSuspendNodesLeaked() override; + + private: + graphStatus InferShapeAndType(NodePtr &node); + graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); + bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); + void UpdateCurNodeOutputDesc(NodePtr &node); + Status SuspendV1LoopExitNodes(const NodePtr &node); + struct SuspendNodes { + std::stack<NodePtr> nodes; + std::unordered_set<NodePtr> nodes_set; + + NodePtr PopSuspendedNode() { + auto top_node = nodes.top(); + nodes.pop(); + nodes_set.erase(top_node); + return top_node; + } + }; + std::map<std::string, SuspendNodes> graphs_2_suspend_nodes_; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ diff --git a/ge/graph/passes/input_output_connection_identify_pass.cc b/ge/graph/passes/input_output_connection_identify_pass.cc index 5779fb41..d5551bdc 100644 --- a/ge/graph/passes/input_output_connection_identify_pass.cc +++ b/ge/graph/passes/input_output_connection_identify_pass.cc @@ -23,7 +23,7 @@ #include #include "common/ge/ge_util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/input_output_connection_identify_pass.h b/ge/graph/passes/input_output_connection_identify_pass.h index 97ed315d..c4a4653e 100755 --- a/ge/graph/passes/input_output_connection_identify_pass.h +++ b/ge/graph/passes/input_output_connection_identify_pass.h @@ -19,7 +19,7 @@ #include #include -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/iterator_op_pass.cc b/ge/graph/passes/iterator_op_pass.cc index 3e85887b..57416017 100644 --- a/ge/graph/passes/iterator_op_pass.cc +++ b/ge/graph/passes/iterator_op_pass.cc @@ -21,13 +21,13 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" -#include "graph/common/omg_util.h" -#include "graph/graph.h" +#include "common/omg_util.h" +#include "external/graph/graph.h" #include "graph/node.h" #include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/iterator_op_pass.h b/ge/graph/passes/iterator_op_pass.h index d9303358..611109dc 100644 --- a/ge/graph/passes/iterator_op_pass.h +++ b/ge/graph/passes/iterator_op_pass.h @@ -17,7 +17,7 @@ #ifndef GE_GRAPH_PASSES_ITERATOR_OP_PASS_H_ #define GE_GRAPH_PASSES_ITERATOR_OP_PASS_H_ -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 522c20ad..9ff3bfd7 100755 --- 
+++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc
@@ -18,7 +18,7 @@
 #include
 
-#include "common/ge_inner_error_codes.h"
+#include "framework/common/ge_inner_error_codes.h"
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/types.h"
 #include "init/gelib.h"
diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.h b/ge/graph/passes/link_gen_mask_nodes_pass.h
index 12d68f1b..c6c1e703 100644
--- a/ge/graph/passes/link_gen_mask_nodes_pass.h
+++ b/ge/graph/passes/link_gen_mask_nodes_pass.h
@@ -21,7 +21,7 @@
 #include
 #include
 
-#include "graph/graph.h"
+#include "external/graph/graph.h"
 #include "inc/graph_pass.h"
 
 namespace ge {
diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc
index 08b358ee..3989e54f 100644
--- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc
+++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc
@@ -14,29 +14,17 @@
  * limitations under the License.
  */
 
-#include "mark_force_unknown_for_cond_pass.h"
-
-#include <queue>
+#include "graph/passes/mark_force_unknown_for_cond_pass.h"
 
 #include "graph/utils/node_utils.h"
-#include "graph/common/omg_util.h"
+#include "common/omg_util.h"
 
 namespace ge {
 namespace {
 inline bool IsMergeInLoop(const NodePtr &node) {
   const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };
 
-  std::string node_type;
-  (void)GetOriginalType(node, node_type);
-  return kLoopMergeInputs.count(node_type) > 0;
-}
-
-inline bool IsSwitchInLoop(const NodePtr &node) {
-  const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND };
-
-  std::string node_type;
-  (void)GetOriginalType(node, node_type);
-  return kLoopSwitchInputs.count(node_type) > 0;
+  return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0;
 }
 }
 
@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
   GELOGD("MarkForceUnknownForCondPass Enter");
   std::map<NodePtr, std::vector<NodePtr>> switch_groups;
   for (const auto &node : graph->GetDirectNode()) {
-    std::string node_type;
-    GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
-                      "[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str());
-    if (kMergeOpTypes.count(node_type) == 0) {
+    if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) {
       continue;
     }
 
@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
 }
 
 ///
+/// @brief Deal with Switch node for LoopCond
+/// @param [in] Switch node
+/// @param [in] dest span
+/// @param [out] Search queue
+/// @return true: Switch in while loop / false: not in while loop.
+///
+bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span,
+                                                   std::queue<std::pair<NodePtr, uint32_t>> &search_queue) {
+  /// LoopCond --->\.
+  ///               \.
+  /// Enter-----------+ \.
+  ///                   +--> Merge --> Switch --> Exit
+  /// NextIteration---+
+  const auto is_loop_op = [](const NodePtr &n) {
+    return NodeUtils::GetNodeType(n) == LOOPCOND;
+  };
+  const auto is_exit_op = [](const NodePtr &n) {
+    return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0;
+  };
+
+  const auto src_nodes = node->GetInAllNodes();
+  const auto dst_nodes = node->GetOutAllNodes();
+  if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) &&
+      std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) {
+    return false;
+  }
+
+  for (const auto &m : src_nodes) {
+    if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) {
+      for (const auto &n : m->GetInAllNodes()) {
+        if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) {
+          continue;
+        }
+
+        search_queue.push({n, dst_span});
+        GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(),
+               n->GetName().c_str(), dst_span);
+      }
+    }
+  }
+
+  return true;
+}
+
+///
 /// @brief Mark force unknown shape for Switch node
 /// @param [in] merge node
 /// @param [out] switch group
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
 ///
 void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) {
   // Switch --> {Switch --> Merge} --> Merge
+  GELOGD("Search Switch node for Merge: %s", node->GetName().c_str());
   std::unordered_set<NodePtr> nodes_seen;
   std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}});
   while (!search_queue.empty()) {
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
     const auto dst_span = search_queue.front().second;
     search_queue.pop();
 
-    // Switch --> Identity --> Constant
-    for (const auto &in_node : dst_node->GetInControlNodes()) {
+    for (const auto &in_node : dst_node->GetInAllNodes()) {
       if (nodes_seen.count(in_node) > 0) {
         GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
         continue;
       }
       nodes_seen.insert(in_node);
 
-      if (in_node->GetType() == IDENTITY) {
-        GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(),
-               in_node->GetName().c_str(), dst_span);
-        search_queue.push({in_node, dst_span});
-      }
-    }
-
-    for (const auto &in_node : dst_node->GetInDataNodes()) {
-      if (nodes_seen.count(in_node) > 0) {
-        GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
-        continue;
-      }
-      nodes_seen.insert(in_node);
-
-      std::string node_type;
-      (void)GetOriginalType(in_node, node_type);
+      const std::string node_type = NodeUtils::GetNodeType(in_node);
       GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(),
              in_node->GetName().c_str(), dst_span);
       if (kSwitchOpTypes.count(node_type) > 0) {  // Switch input node.
+        if (DealAsLoopSwitch(in_node, dst_span, search_queue)) {
+          continue;
+        }
+
         if (dst_span > 0) {
           search_queue.push({in_node, dst_span - 1});
         } else {
-          const auto &all_in_nodes = in_node->GetInDataNodes();
-          if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) {
-            GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(),
-                   in_node->GetName().c_str());
-          } else {
-            switch_group.emplace_back(in_node);
-          }
+          switch_group.emplace_back(in_node);
         }
       } else if (kMergeOpTypes.count(node_type) > 0) {  // Merge input node.
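        // (Editor's note, illustrative, not part of the original patch) dst_span tracks how deep
        // the backward walk is inside nested cond scopes: meeting a Merge pushes one level
        // (dst_span + 1 below), meeting a Switch pops one, and only a Switch reached at span 0
        // joins the switch_group of the Merge this search started from.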
        search_queue.push({in_node, dst_span + 1});
@@ -132,39 +145,63 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
 /// @return
 ///
 void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
-  std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
-    return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP);
-  };
+  // Step 0: no group assigned. such as:
+  // Merge1{id=0, group=} => {Switch1{id=1, group=}, Switch2{id=2, group=}}
+  // Merge2{id=3, group=} => {Switch1{id=1, group=}, Switch3{id=4, group=}}
+  // Merge3{id=5, group=} => {Switch4{id=6, group=}, Switch5{id=7, group=}}
+  // Merge4{id=8, group=} => {Switch1{id=1, group=}, Switch5{id=7, group=}}
+  std::map<int64_t, int64_t> unique_groups;
+  const auto get_group_index = [&unique_groups](const NodePtr &merge, const std::vector<NodePtr> &switch_group) {
+    int64_t group_index = merge->GetOpDesc()->GetId();
+    std::set<int64_t> group_ids{group_index};
+    for (const auto &node : switch_group) {
+      if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
+        GELOGI("[%s] Get group from [%s], index[%ld]", merge->GetName().c_str(), node->GetName().c_str(), group_index);
+        group_ids.insert(group_index);
+      }
+    }
 
-  for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) {
-    const auto &op_node1 = it1->first;
-    const auto &op_desc1 = op_node1->GetOpDesc();
-    if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
-      continue;
+    const auto it = unique_groups.find(group_index);
+    if (it != unique_groups.end()) {
+      group_index = it->second;
     }
 
-    if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) {
-      int64_t group_index = op_desc1->GetId();
-      GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
-      MarkForceUnknownShape(op_node1, true, group_index);
-      for (const auto &n : it1->second) {
-        MarkForceUnknownShape(n, true, group_index);
-      }
+    for (auto id : group_ids) {
+      unique_groups[id] = group_index;
+    }
 
-      for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
-        const auto &op_node2 = it2->first;
-        const auto &op_desc2 = op_node2->GetOpDesc();
-        if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
-          continue;
-        }
+    return group_index;
+  };
 
-        if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
-          MarkForceUnknownShape(op_node2, true, group_index);
-          for (const auto &n : it2->second) {
-            MarkForceUnknownShape(n, true, group_index);
-          }
-        }
-      }
+  const auto set_group_index = [](const NodePtr &merge, const std::vector<NodePtr> &switch_group, int64_t group_index) {
+    SetControlFlowGroup(merge, group_index);
+    for (const auto &node : switch_group) {
+      SetControlFlowGroup(node, group_index);
+    }
+  };
+
+  // Step 1: Set group index to merge, if switch already has group, use assigned group.
+  // Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}}
+  // Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}}
+  // Merge3{id=5, group=5} => {Switch4{id=6, group=5}, Switch5{id=7, group=5}}
+  // Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}}
+  for (const auto &group : switch_groups) {
+    int64_t group_index = get_group_index(group.first, group.second);
+    set_group_index(group.first, group.second, group_index);
+  }
+
+  // Step 2: Adjust crossed merge group for unique group.
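+  // (Editor's note) The concrete ids below are illustrative: unique_groups maps every group id
+  // seen on a crossing Merge to one canonical id, so cond groups that share a Switch converge
+  // on a single ATTR_NAME_CONTROL_FLOW_GROUP value, e.g.: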
+  // Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}}
+  // Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}}
+  // Merge3{id=5, group=0} => {Switch4{id=6, group=0}, Switch5{id=7, group=0}}
+  // Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}}
+  for (const auto &group : switch_groups) {
+    int64_t group_index = -1;
+    (void)AttrUtils::GetInt(group.first->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
+
+    const auto it = unique_groups.find(group_index);
+    if (it != unique_groups.end() && it->first != it->second) {
+      set_group_index(group.first, group.second, it->second);
+    }
+  }
 }
diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.h b/ge/graph/passes/mark_force_unknown_for_cond_pass.h
index 528a8fdc..030b55ee 100644
--- a/ge/graph/passes/mark_force_unknown_for_cond_pass.h
+++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.h
@@ -19,6 +19,8 @@
 #include "inc/graph_pass.h"
 
+#include <queue>
+
 namespace ge {
 class MarkForceUnknownForCondPass : public GraphPass {
  public:
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass {
 
  private:
   ///
+  /// @brief Deal with Switch node for LoopCond
+  /// @param [in] Switch node
+  /// @param [in] dest span
+  /// @param [out] Search queue
+  /// @return true: Switch in while loop / false: not in while loop.
+  ///
+  bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue);
+
+  ///
   /// @brief Mark force unknown shape for Switch node
   /// @param [in] merge node
   /// @param [out] switch group
diff --git a/ge/graph/passes/mark_graph_unknown_status_pass.cc b/ge/graph/passes/mark_graph_unknown_status_pass.cc
index 2d7b179b..9e460fc7 100644
--- a/ge/graph/passes/mark_graph_unknown_status_pass.cc
+++ b/ge/graph/passes/mark_graph_unknown_status_pass.cc
@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) {
     }
   }
 
+  const auto &node = graph->GetParentNode();
+  if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) {
+    GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape),
+                            "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
+  }
+
   for (const auto &node : graph->GetDirectNode()) {
     GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str());
     (void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape);
diff --git a/ge/graph/passes/mark_graph_unknown_status_pass.h b/ge/graph/passes/mark_graph_unknown_status_pass.h
index a1148c6e..2cc86dbd 100644
--- a/ge/graph/passes/mark_graph_unknown_status_pass.h
+++ b/ge/graph/passes/mark_graph_unknown_status_pass.h
@@ -16,7 +16,7 @@
 #ifndef GE_GRAPH_PASSES_MARK_GRAPH_UNKNOWN_STATUS_PASS_H_
 #define GE_GRAPH_PASSES_MARK_GRAPH_UNKNOWN_STATUS_PASS_H_
 
-#include "graph/graph.h"
+#include "external/graph/graph.h"
 #include "inc/graph_pass.h"
 
 namespace ge {
diff --git a/ge/graph/passes/mark_node_unknown_shape_pass.cc b/ge/graph/passes/mark_node_unknown_shape_pass.cc
index c040e846..eadd3ca7 100644
--- a/ge/graph/passes/mark_node_unknown_shape_pass.cc
+++ b/ge/graph/passes/mark_node_unknown_shape_pass.cc
@@ -17,7 +17,7 @@
 #include "graph/passes/mark_node_unknown_shape_pass.h"
 #include "graph/utils/node_utils.h"
 #include "graph/debug/ge_attr_define.h"
-#include "graph/common/local_context.h"
+#include "common/local_context.h"
 
 namespace ge {
 namespace {
diff --git a/ge/graph/passes/mark_node_unknown_shape_pass.h b/ge/graph/passes/mark_node_unknown_shape_pass.h
index b78b7826..acd12582
100644 --- a/ge/graph/passes/mark_node_unknown_shape_pass.h +++ b/ge/graph/passes/mark_node_unknown_shape_pass.h @@ -16,7 +16,7 @@ #ifndef GE_GRAPH_PASSES_MARK_NODE_UNKNOWN_SHAPE_PASS_H_ #define GE_GRAPH_PASSES_MARK_NODE_UNKNOWN_SHAPE_PASS_H_ -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/mark_same_addr_pass.h b/ge/graph/passes/mark_same_addr_pass.h index 518fe418..adf971a2 100644 --- a/ge/graph/passes/mark_same_addr_pass.h +++ b/ge/graph/passes/mark_same_addr_pass.h @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" #ifndef GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ diff --git a/ge/graph/passes/merge_input_memcpy_pass.cc b/ge/graph/passes/merge_input_memcpy_pass.cc index 00c04131..97a17d99 100644 --- a/ge/graph/passes/merge_input_memcpy_pass.cc +++ b/ge/graph/passes/merge_input_memcpy_pass.cc @@ -17,8 +17,8 @@ #include "graph/passes/merge_input_memcpy_pass.h" #include "common/ge/ge_util.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" +#include "external/ge/ge_api_types.h" +#include "common/omg_util.h" namespace ge { Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index fec9c6d0..2ddfcaab 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -22,7 +22,7 @@ #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc index 0b383911..e91410e1 100644 --- a/ge/graph/passes/merge_to_stream_merge_pass.cc +++ b/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -16,8 +16,8 @@ #include "graph/passes/merge_to_stream_merge_pass.h" #include "common/ge/ge_util.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" +#include "external/ge/ge_api_types.h" +#include "common/omg_util.h" namespace ge { Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { @@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); return FAILED, "[Check][Param] Param of pre node is nullptr."); int64_t group_index = -1; - bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - MarkForceUnknownShape(node, force_unknown, group_index); + (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); @@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); return FAILED; } - MarkForceUnknownShape(active_node, force_unknown, group_index); + SetControlFlowGroup(active_node, group_index); } return SUCCESS; diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index d36b4186..b25239b1 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ 
b/ge/graph/passes/multi_batch_clone_pass.cc @@ -18,14 +18,14 @@ #include "common/formats/utils/formats_trans_utils.h" #include "common/ge/ge_util.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/preprocess/multi_batch_options.h" #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "register/op_registry.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { namespace { diff --git a/ge/graph/passes/multi_batch_pass.cc b/ge/graph/passes/multi_batch_pass.cc index 25d629fa..9fba362c 100644 --- a/ge/graph/passes/multi_batch_pass.cc +++ b/ge/graph/passes/multi_batch_pass.cc @@ -19,7 +19,7 @@ #include #include #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/type_utils.h" #include "common/formats/utils/formats_trans_utils.h" diff --git a/ge/graph/passes/net_output_pass.cc b/ge/graph/passes/net_output_pass.cc index 30455fa0..9aea4863 100644 --- a/ge/graph/passes/net_output_pass.cc +++ b/ge/graph/passes/net_output_pass.cc @@ -27,7 +27,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/omg/omg_inner_types.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/pass_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/net_output_pass.h b/ge/graph/passes/net_output_pass.h index ab190169..fecccc35 100644 --- a/ge/graph/passes/net_output_pass.h +++ b/ge/graph/passes/net_output_pass.h @@ -22,7 +22,7 @@ #include #include -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index 67735b8b..af3e4d2d 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -17,14 +17,16 @@ #include "graph/passes/next_iteration_pass.h" #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/node_utils.h" using std::string; namespace ge { namespace { -const int64_t kLoopType = 1; +constexpr int64_t kLoopType = 1; +constexpr uint8_t kMaxTransOp = 3; +constexpr uint8_t kTransOpIoSize = 1; } Status NextIterationPass::Run(ComputeGraphPtr graph) { @@ -284,13 +286,28 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { /// @return void /// void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { + std::string node_type; for (const auto &switch_node : loop_group.switch_nodes) { SetControlFlowGroup(switch_node, group_index); - for (const auto &node : switch_node->GetOutDataNodes()) { - std::string node_type; - (void)GetOriginalType(node, node_type); - if (kExitOpTypes.count(node_type) > 0) { - SetControlFlowGroup(node, group_index); + for (auto node : switch_node->GetOutDataNodes()) { + // Switch --> Exit + // Switch --> Cast --> Exit + // Switch --> TransData --> Cast --> Exit + for (uint8_t i = 0; i < kMaxTransOp; ++i) { + if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { + break; + } + + if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { + SetControlFlowGroup(node, group_index); + break; + } + + const auto &all_nodes = node->GetOutAllNodes(); + 
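+        // (Editor's note, illustrative) keep walking only while the chain stays
+        // single-input/single-output (e.g. Cast or TransData); any fan-in or fan-out means
+        // this is not a plain Switch --> (trans op)* --> Exit chain.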
+        if (all_nodes.size() != kTransOpIoSize) {
+          break;
+        }
+        node = all_nodes.at(0);
+      }
     }
   }
 }
diff --git a/ge/graph/passes/no_use_reshape_remove_pass.cc b/ge/graph/passes/no_use_reshape_remove_pass.cc
index b3074565..e0a0ceb8 100644
--- a/ge/graph/passes/no_use_reshape_remove_pass.cc
+++ b/ge/graph/passes/no_use_reshape_remove_pass.cc
@@ -19,7 +19,7 @@
 #include
 #include
 
-#include "common/op/ge_op_utils.h"
+#include "framework/common/op/ge_op_utils.h"
 #include "external/graph/types.h"
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/ge_inner_error_codes.h"
diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc
index 9c93f6cf..795002f1 100644
--- a/ge/graph/passes/parallel_group_pass.cc
+++ b/ge/graph/passes/parallel_group_pass.cc
@@ -15,7 +15,7 @@
  */
 
 #include "graph/passes/parallel_group_pass.h"
-
+#include <queue>
 #include "framework/common/debug/ge_log.h"
 #include "common/ge/ge_util.h"
 #include "framework/common/ge_inner_error_codes.h"
@@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu
       for (const auto &switch_node : cur_itr->second.first) {
         int64_t pre_id = pre_node->GetOpDesc()->GetId();
         int64_t switch_id = switch_node->GetOpDesc()->GetId();
-        // avoid ring
-        if (pre_id > switch_id) {
-          auto merge_node = cur_itr->second.second;
-          if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) {
-            GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                   pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                              pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            return FAILED;
-          }
-        } else {
-          if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
-            GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                   pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                              pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            return FAILED;
-          }
+        NodePtr first_node = pre_node;
+        NodePtr second_node = switch_node;
+        if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) {
+          // avoid ring, merge->pre_node
+          first_node = cur_itr->second.second;
+          second_node = pre_node;
+        }
+        if (AddCtrlEdge(first_node, second_node) != SUCCESS) {
+          GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
+                 first_node->GetName().c_str(), second_node->GetName().c_str());
+          REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
+                            first_node->GetName().c_str(), second_node->GetName().c_str());
+          return FAILED;
         }
       }
     } else {
@@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
   return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
           stream_switch_type == kLoopType);
 }
+
+bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) {
+  if (node_a == nullptr || node_b == nullptr) {
+    GELOGW("node_a or node_b is nullptr.");
+    return false;
+  }
+  int64_t end_id = node_b->GetOpDesc()->GetId();
+  std::queue<NodePtr> nodes;
+  nodes.push(node_a);
+  while (!nodes.empty()) {
+    NodePtr tmp_node = nodes.front();
+    nodes.pop();
+    if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr ||
+        tmp_node->GetOpDesc()->GetId() > end_id) {
+      continue;
+    }
+    if (tmp_node == node_b) {
+      return true;
+    }
+    for (const auto &out_node : tmp_node->GetOutAllNodes()) {
+      nodes.push(out_node);
+    }
+  }
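+  // (Editor's note) the BFS above prunes any node whose topological id exceeds node_b's id,
+  // so falling through here means no forward path from node_a to node_b exists.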
+  return false;
+}
 }  // namespace ge
diff --git a/ge/graph/passes/parallel_group_pass.h b/ge/graph/passes/parallel_group_pass.h
index 9b895598..93b0b158 100644
--- a/ge/graph/passes/parallel_group_pass.h
+++ b/ge/graph/passes/parallel_group_pass.h
@@ -19,7 +19,7 @@
 #include
 #include
 
-#include "graph/graph.h"
+#include "external/graph/graph.h"
 #include "inc/graph_pass.h"
 
 namespace ge {
@@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass {
 
   bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc);
   bool IsWhileStreamSwitch(OpDescPtr switch_op_desc);
+  bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b);
 };
 }  // namespace ge
 #endif  // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H
diff --git a/ge/graph/passes/pass_manager.cc b/ge/graph/passes/pass_manager.cc
index fa2f1e17..afd2e4a7 100644
--- a/ge/graph/passes/pass_manager.cc
+++ b/ge/graph/passes/pass_manager.cc
@@ -15,12 +15,12 @@
  */
 
 #include "inc/pass_manager.h"
-#include "common/debug/log.h"
-#include "common/types.h"
-#include "common/util.h"
+#include "framework/common/debug/log.h"
+#include "framework/common/types.h"
+#include "framework/common/util.h"
 #include "graph/utils/node_utils.h"
-#include "graph/common/ge_call_wrapper.h"
-#include "omg/omg_inner_types.h"
+#include "common/ge_call_wrapper.h"
+#include "framework/omg/omg_inner_types.h"
 
 namespace ge {
 const vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses() const { return names_to_graph_passes_; }
diff --git a/ge/graph/passes/pass_utils.cc b/ge/graph/passes/pass_utils.cc
index c0ef7685..0e056a0f 100644
--- a/ge/graph/passes/pass_utils.cc
+++ b/ge/graph/passes/pass_utils.cc
@@ -23,11 +23,11 @@
 #include
 
 #include "framework/common/debug/ge_log.h"
-#include "common/ge_inner_error_codes.h"
+#include "framework/common/ge_inner_error_codes.h"
 #include "common/ge/ge_util.h"
-#include "common/op/ge_op_utils.h"
-#include "common/types.h"
-#include "graph/common/omg_util.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/types.h"
+#include "common/omg_util.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/ge_tensor.h"
 #include "graph/manager/graph_var_manager.h"
@@ -35,7 +35,7 @@
 #include "graph/utils/op_desc_utils.h"
 #include "graph/utils/tensor_utils.h"
 #include "graph/utils/type_utils.h"
-#include "utils/node_utils.h"
+#include "graph/utils/node_utils.h"
 #include "common/formats/utils/formats_trans_utils.h"
 
 namespace ge {
diff --git a/ge/graph/passes/pass_utils.h b/ge/graph/passes/pass_utils.h
index bd506d09..475c4e77 100755
--- a/ge/graph/passes/pass_utils.h
+++ b/ge/graph/passes/pass_utils.h
@@ -19,7 +19,7 @@
 #include
 
 #include "framework/common/debug/ge_log.h"
-#include "common/ge_inner_error_codes.h"
+#include "framework/common/ge_inner_error_codes.h"
 #include "graph/compute_graph.h"
 
 namespace ge {
diff --git a/ge/graph/passes/permute_pass.cc b/ge/graph/passes/permute_pass.cc
index 8254db72..f3045b1a 100644
--- a/ge/graph/passes/permute_pass.cc
+++ b/ge/graph/passes/permute_pass.cc
@@ -17,14 +17,14 @@
 #include "graph/passes/permute_pass.h"
 #include
 #include
-#include "common/debug/log.h"
-#include "common/types.h"
+#include "framework/common/debug/log.h"
+#include "framework/common/types.h"
 #include "graph/utils/attr_utils.h"
 #include "graph/utils/op_desc_utils.h"
 #include "inc/kernel.h"
 #include "inc/kernel_factory.h"
 #include "framework/omg/omg_inner_types.h"
-#include "graph/common/local_context.h"
+#include "common/local_context.h"
 
 using domi::DOMI_TENSOR_ND;
 using domi::DOMI_TENSOR_NHWC;
diff --git
a/ge/graph/passes/placeholder_with_default_pass.cc b/ge/graph/passes/placeholder_with_default_pass.cc index 893ee798..bc51b217 100644 --- a/ge/graph/passes/placeholder_with_default_pass.cc +++ b/ge/graph/passes/placeholder_with_default_pass.cc @@ -18,7 +18,7 @@ #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status PlaceholderWithDefaultPass::Run(NodePtr &node) { diff --git a/ge/graph/passes/prevent_gradient_pass.cc b/ge/graph/passes/prevent_gradient_pass.cc index c531fd2f..8b8b17bd 100644 --- a/ge/graph/passes/prevent_gradient_pass.cc +++ b/ge/graph/passes/prevent_gradient_pass.cc @@ -19,7 +19,7 @@ #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status PreventGradientPass::Run(NodePtr &node) { diff --git a/ge/graph/passes/print_op_pass.h b/ge/graph/passes/print_op_pass.h index deaf559b..96501dc5 100755 --- a/ge/graph/passes/print_op_pass.h +++ b/ge/graph/passes/print_op_pass.h @@ -20,8 +20,8 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" -#include "graph/graph.h" +#include "common/omg_util.h" +#include "external/graph/graph.h" #include "graph/passes/base_pass.h" #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" diff --git a/ge/graph/passes/prune_pass.cc b/ge/graph/passes/prune_pass.cc index 1e2ec4ab..cc6c7618 100644 --- a/ge/graph/passes/prune_pass.cc +++ b/ge/graph/passes/prune_pass.cc @@ -19,8 +19,8 @@ #include #include #include -#include "common/debug/log.h" -#include "common/types.h" +#include "framework/common/debug/log.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/node_utils.h" diff --git a/ge/graph/passes/ref_identity_delete_op_pass.cc b/ge/graph/passes/ref_identity_delete_op_pass.cc index 39794cff..46bc7467 100644 --- a/ge/graph/passes/ref_identity_delete_op_pass.cc +++ b/ge/graph/passes/ref_identity_delete_op_pass.cc @@ -14,10 +14,10 @@ * limitations under the License. */ -#include "ref_identity_delete_op_pass.h" +#include "graph/passes/ref_identity_delete_op_pass.h" #include #include -#include "graph/common/transop_util.h" +#include "common/transop_util.h" namespace ge { Status RefIdentityDeleteOpPass::Run(ComputeGraphPtr graph) { diff --git a/ge/graph/passes/remove_same_const_pass.cc b/ge/graph/passes/remove_same_const_pass.cc index a06eea43..947ff3f3 100644 --- a/ge/graph/passes/remove_same_const_pass.cc +++ b/ge/graph/passes/remove_same_const_pass.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "remove_same_const_pass.h" +#include "graph/passes/remove_same_const_pass.h" #include #include diff --git a/ge/graph/passes/remove_same_const_pass.h b/ge/graph/passes/remove_same_const_pass.h index 08905bd2..6934a472 100644 --- a/ge/graph/passes/remove_same_const_pass.h +++ b/ge/graph/passes/remove_same_const_pass.h @@ -16,7 +16,7 @@ #ifndef GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ #define GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/replace_transshape_pass.cc b/ge/graph/passes/replace_transshape_pass.cc index 28957a61..0e1701ab 100644 --- a/ge/graph/passes/replace_transshape_pass.cc +++ b/ge/graph/passes/replace_transshape_pass.cc @@ -19,9 +19,9 @@ #include #include "common/ge/ge_util.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/graph_utils.h" namespace ge { diff --git a/ge/graph/passes/replace_with_empty_const_pass.cc b/ge/graph/passes/replace_with_empty_const_pass.cc index 9459c852..6cb31627 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.cc +++ b/ge/graph/passes/replace_with_empty_const_pass.cc @@ -21,7 +21,23 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +namespace { +const std::unordered_set kControlFlowOps = { + ge::SWITCH, + ge::REFSWITCH, + ge::MERGE, + ge::REFMERGE, + ge::ENTER, + ge::REFENTER, + ge::NEXTITERATION, + ge::REFNEXTITERATION, + ge::EXIT, + ge::REFEXIT, + ge::LOOPCOND +}; +} namespace ge { Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { GELOGD("ReplaceWithEmptyConstPass in."); @@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str()); return SUCCESS; } + if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) { + GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str()); + return SUCCESS; + } // Node like no op, it has no output if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) { GELOGI("Node %s has no output desc. 
Ignore current pass.", node->GetName().c_str()); @@ -51,7 +71,7 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { GELOGI("Node %s Got empty output_desc_ptr, ignore current pass.", node->GetName().c_str()); return SUCCESS; } - if (!IsEmptyTenor(output_desc_ptr->GetShape())) { + if (!IsKnownEmptyTenor(output_desc_ptr->GetShape())) { is_all_output_empty = false; break; } @@ -87,12 +107,16 @@ Status ReplaceWithEmptyConstPass::GetOutputsOfCurrNode(const NodePtr &node_to_re return SUCCESS; } -bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const { +bool ReplaceWithEmptyConstPass::IsKnownEmptyTenor(const GeShape &shape) const { + bool is_known_empty_tensor = false; for (auto dim : shape.GetDims()) { - if (dim == 0) { - return true; + if (dim < 0) { + // current dim is unknown dim, skip replace + return false; + } else if (dim == 0) { + is_known_empty_tensor = true; } } - return false; + return is_known_empty_tensor; } } // namespace ge diff --git a/ge/graph/passes/replace_with_empty_const_pass.h b/ge/graph/passes/replace_with_empty_const_pass.h index fde75358..90103432 100644 --- a/ge/graph/passes/replace_with_empty_const_pass.h +++ b/ge/graph/passes/replace_with_empty_const_pass.h @@ -26,7 +26,7 @@ class ReplaceWithEmptyConstPass : public FoldingPass { private: Status GetOutputsOfCurrNode(const NodePtr &node_to_replace, vector &outputs); - bool IsEmptyTenor(const GeShape &shape) const; + bool IsKnownEmptyTenor(const GeShape &shape) const; }; } // namespace ge #endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ diff --git a/ge/graph/passes/resource_pair_add_control_pass.cc b/ge/graph/passes/resource_pair_add_control_pass.cc index a104a95e..14f04fe0 100755 --- a/ge/graph/passes/resource_pair_add_control_pass.cc +++ b/ge/graph/passes/resource_pair_add_control_pass.cc @@ -21,9 +21,9 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/utils/attr_utils.h" #include "graph/utils/tensor_adapter.h" diff --git a/ge/graph/passes/resource_pair_remove_control_pass.cc b/ge/graph/passes/resource_pair_remove_control_pass.cc index 73b96008..138efb43 100755 --- a/ge/graph/passes/resource_pair_remove_control_pass.cc +++ b/ge/graph/passes/resource_pair_remove_control_pass.cc @@ -21,9 +21,9 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/utils/attr_utils.h" #include "graph/utils/tensor_adapter.h" diff --git a/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/ge/graph/passes/same_transdata_breadth_fusion_pass.cc index 60f5c7c9..afd78a4d 100644 --- a/ge/graph/passes/same_transdata_breadth_fusion_pass.cc +++ b/ge/graph/passes/same_transdata_breadth_fusion_pass.cc @@ -20,8 +20,8 @@ #include #include #include -#include "common/ge_inner_error_codes.h" -#include "common/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" diff --git a/ge/graph/passes/save_pass.cc b/ge/graph/passes/save_pass.cc index 1181461b..6fec3a3b 100755 --- 
a/ge/graph/passes/save_pass.cc +++ b/ge/graph/passes/save_pass.cc @@ -20,7 +20,7 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "graph/utils/graph_utils.h" namespace ge { diff --git a/ge/graph/passes/save_pass.h b/ge/graph/passes/save_pass.h index 512dfa62..8efcc46e 100755 --- a/ge/graph/passes/save_pass.h +++ b/ge/graph/passes/save_pass.h @@ -17,7 +17,7 @@ #ifndef GE_GRAPH_PASSES_SAVE_PASS_H_ #define GE_GRAPH_PASSES_SAVE_PASS_H_ -#include "graph/graph.h" +#include "external/graph/graph.h" #include "inc/graph_pass.h" namespace ge { diff --git a/ge/graph/passes/shape_operate_op_remove_pass.cc b/ge/graph/passes/shape_operate_op_remove_pass.cc index a703f1c9..f6ce0ec1 100755 --- a/ge/graph/passes/shape_operate_op_remove_pass.cc +++ b/ge/graph/passes/shape_operate_op_remove_pass.cc @@ -15,9 +15,9 @@ */ #include "graph/passes/shape_operate_op_remove_pass.h" -#include "common/debug/log.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/debug/log.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/utils/attr_utils.h" using domi::SUCCESS; diff --git a/ge/graph/passes/snapshot_pass.cc b/ge/graph/passes/snapshot_pass.cc index 95733e67..a6cd79a3 100644 --- a/ge/graph/passes/snapshot_pass.cc +++ b/ge/graph/passes/snapshot_pass.cc @@ -18,7 +18,7 @@ #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status SnapshotPass::Run(NodePtr &node) { diff --git a/ge/graph/passes/stop_gradient_pass.h b/ge/graph/passes/stop_gradient_pass.h index 808174bc..5f022200 100755 --- a/ge/graph/passes/stop_gradient_pass.h +++ b/ge/graph/passes/stop_gradient_pass.h @@ -18,9 +18,9 @@ #define GE_GRAPH_PASSES_STOP_GRADIENT_PASS_H_ #include "framework/common/debug/ge_log.h" -#include "common/types.h" +#include "framework/common/types.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/passes/base_pass.h" namespace ge { diff --git a/ge/graph/passes/subexpression_migration_pass.cc b/ge/graph/passes/subexpression_migration_pass.cc index 6265851a..f39e02e5 100755 --- a/ge/graph/passes/subexpression_migration_pass.cc +++ b/ge/graph/passes/subexpression_migration_pass.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "subexpression_migration_pass.h" +#include "graph/passes/subexpression_migration_pass.h" #include "graph/utils/node_utils.h" #include "ge_local_engine/engine/host_cpu_engine.h" diff --git a/ge/graph/passes/subexpression_migration_pass.h b/ge/graph/passes/subexpression_migration_pass.h index d2733fcf..52326798 100755 --- a/ge/graph/passes/subexpression_migration_pass.h +++ b/ge/graph/passes/subexpression_migration_pass.h @@ -17,7 +17,7 @@ #ifndef GE_COMMON_SUBEXPRESSION_MIGRATION_H_ #define GE_COMMON_SUBEXPRESSION_MIGRATION_H_ -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" #include diff --git a/ge/graph/passes/subgraph_const_migration_pass.cc b/ge/graph/passes/subgraph_const_migration_pass.cc index d15e60cf..eac0c84b 100644 --- a/ge/graph/passes/subgraph_const_migration_pass.cc +++ b/ge/graph/passes/subgraph_const_migration_pass.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "subgraph_const_migration_pass.h" +#include "graph/passes/subgraph_const_migration_pass.h" #include "graph/utils/node_utils.h" #include "ge_local_engine/engine/host_cpu_engine.h" diff --git a/ge/graph/passes/subgraph_const_migration_pass.h b/ge/graph/passes/subgraph_const_migration_pass.h index 2834fd66..e43a3049 100755 --- a/ge/graph/passes/subgraph_const_migration_pass.h +++ b/ge/graph/passes/subgraph_const_migration_pass.h @@ -17,7 +17,7 @@ #ifndef GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ #define GE_COMMON_SUBGRAPH_CONST_MIGRATION_H_ -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" #include diff --git a/ge/graph/passes/switch_data_edges_bypass.cc b/ge/graph/passes/switch_data_edges_bypass.cc index 5f66a0ca..c7b46b7c 100644 --- a/ge/graph/passes/switch_data_edges_bypass.cc +++ b/ge/graph/passes/switch_data_edges_bypass.cc @@ -14,13 +14,13 @@ * limitations under the License. 
*/ -#include "switch_data_edges_bypass.h" +#include "graph/passes/switch_data_edges_bypass.h" #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/ge/ge_util.h" -#include "common/op/ge_op_utils.h" -#include "common/util.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/util.h" #include "graph/utils/node_utils.h" namespace ge { diff --git a/ge/graph/passes/switch_dead_branch_elimination.cc b/ge/graph/passes/switch_dead_branch_elimination.cc index 3c6c57d0..284111ba 100644 --- a/ge/graph/passes/switch_dead_branch_elimination.cc +++ b/ge/graph/passes/switch_dead_branch_elimination.cc @@ -19,7 +19,7 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/switch_logic_remove_pass.cc b/ge/graph/passes/switch_logic_remove_pass.cc index 13b409c5..0d6bc2ce 100644 --- a/ge/graph/passes/switch_logic_remove_pass.cc +++ b/ge/graph/passes/switch_logic_remove_pass.cc @@ -21,7 +21,7 @@ #include "framework/common/debug/ge_log.h" #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" -#include "common/util.h" +#include "framework/common/util.h" namespace ge { namespace { diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index e7743130..acbf27e3 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -17,8 +17,8 @@ #include "graph/passes/switch_to_stream_switch_pass.h" #include #include "common/ge/ge_util.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" +#include "external/ge/ge_api_types.h" +#include "common/omg_util.h" #include "graph/ge_context.h" #include "graph/utils/type_utils.h" @@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); int64_t group_index = -1; - bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - MarkForceUnknownShape(stream_switch, force_unknown, group_index); + if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { + SetControlFlowGroup(stream_switch, group_index); + } return stream_switch; } @@ -491,8 +492,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { - std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; - std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; + const std::list &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; + const std::list &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; std::set same_cond_switch; same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); @@ -524,13 +525,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) std::function callback = [&group_index](const NodePtr &n) { return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); }; - 
-      bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
-      MarkForceUnknownShape(active_node, is_unknown_shape, group_index);
+      (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
+      SetControlFlowGroup(active_node, group_index);
 
       const std::string &cond_group = cond_node->GetName();
       for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
         bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
-        std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
+        const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
         GE_IF_BOOL_EXEC(switch_list.empty(), continue);
 
         // select first stream_switch
@@ -559,7 +560,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
                           "[Add][Edge] between %s and %s failed.",
                           cast_node->GetName().c_str(), stream_switch->GetName().c_str());
 
-        MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index);
+        SetControlFlowGroup(stream_switch, group_index);
         for (const NodePtr &node : switch_list) {
           GE_IF_BOOL_EXEC(node != stream_switch, {
             GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),
diff --git a/ge/graph/passes/transop_breadth_fusion_pass.cc b/ge/graph/passes/transop_breadth_fusion_pass.cc
index 58b40a5f..88db9501 100644
--- a/ge/graph/passes/transop_breadth_fusion_pass.cc
+++ b/ge/graph/passes/transop_breadth_fusion_pass.cc
@@ -19,8 +19,8 @@
 #include
 #include
 
-#include "common/types.h"
-#include "graph/common/transop_util.h"
+#include "framework/common/types.h"
+#include "common/transop_util.h"
 #include "graph/utils/node_utils.h"
 
 namespace ge {
diff --git a/ge/graph/passes/transop_depth_fusion_pass.cc b/ge/graph/passes/transop_depth_fusion_pass.cc
index ea4add35..3ce54e50 100755
--- a/ge/graph/passes/transop_depth_fusion_pass.cc
+++ b/ge/graph/passes/transop_depth_fusion_pass.cc
@@ -17,13 +17,13 @@
 #include "graph/passes/transop_depth_fusion_pass.h"
 
 #include
-#include "common/ge_inner_error_codes.h"
-#include "common/types.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "framework/common/types.h"
 #include "graph/compute_graph.h"
 #include "graph/ge_tensor.h"
 #include "graph/op_desc.h"
 #include "graph/utils/graph_utils.h"
-#include "graph/common/transop_util.h"
+#include "common/transop_util.h"
 #include "graph/utils/node_utils.h"
 
 namespace ge {
diff --git a/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc b/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc
index 76233f53..437926ef 100644
--- a/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc
+++ b/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc
@@ -16,10 +16,10 @@
 
 #include "graph/passes/transop_nearby_allreduce_fusion_pass.h"
 #include "framework/common/debug/ge_log.h"
-#include "common/debug/log.h"
-#include "common/types.h"
+#include "framework/common/debug/log.h"
+#include "framework/common/types.h"
 #include "graph/utils/graph_utils.h"
-#include "graph/common/transop_util.h"
+#include "common/transop_util.h"
 
 namespace ge {
 Status TransOpNearbyAllreduceFusionPass::Run(NodePtr &node) {
diff --git a/ge/graph/passes/transop_symmetry_elimination_pass.cc b/ge/graph/passes/transop_symmetry_elimination_pass.cc
index 665f4bd8..2bd00206 100644
--- a/ge/graph/passes/transop_symmetry_elimination_pass.cc
+++ b/ge/graph/passes/transop_symmetry_elimination_pass.cc
@@ -14,16 +14,16 @@
  * limitations under the License.
*/ -#include "transop_symmetry_elimination_pass.h" +#include "graph/passes/transop_symmetry_elimination_pass.h" #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" -#include "types.h" +#include "framework/common/types.h" namespace { const std::set white_list_op{ge::TRANSPOSED, ge::RESHAPE, ge::REFORMAT, ge::CAST, ge::TRANSDATA}; diff --git a/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/ge/graph/passes/transop_without_reshape_fusion_pass.cc index 7e80299b..58145fe7 100644 --- a/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -20,9 +20,9 @@ #include #include #include "common/ge/ge_util.h" -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "graph/common/transop_util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "common/transop_util.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" diff --git a/ge/graph/passes/unused_args_clean_pass.cc b/ge/graph/passes/unused_args_clean_pass.cc index 33250311..bc338b86 100755 --- a/ge/graph/passes/unused_args_clean_pass.cc +++ b/ge/graph/passes/unused_args_clean_pass.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "unused_args_clean_pass.h" +#include "graph/passes/unused_args_clean_pass.h" #include "graph/utils/node_utils.h" diff --git a/ge/graph/passes/unused_args_clean_pass.h b/ge/graph/passes/unused_args_clean_pass.h index 90a146b2..400cc802 100644 --- a/ge/graph/passes/unused_args_clean_pass.h +++ b/ge/graph/passes/unused_args_clean_pass.h @@ -16,7 +16,7 @@ #ifndef GE_COMMON_CASE_ARGS_CLEAN_H_ #define GE_COMMON_CASE_ARGS_CLEAN_H_ -#include "graph/types.h" +#include "external/graph/types.h" #include "inc/graph_pass.h" #include diff --git a/ge/graph/passes/variable_op_pass.cc b/ge/graph/passes/variable_op_pass.cc index 862b7016..e803949e 100644 --- a/ge/graph/passes/variable_op_pass.cc +++ b/ge/graph/passes/variable_op_pass.cc @@ -21,7 +21,7 @@ #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" #include "graph/ge_context.h" -#include "graph/graph.h" +#include "external/graph/graph.h" #include "graph/manager/graph_var_manager.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" diff --git a/ge/graph/passes/variable_op_pass.h b/ge/graph/passes/variable_op_pass.h index 3b18882c..e314fd12 100755 --- a/ge/graph/passes/variable_op_pass.h +++ b/ge/graph/passes/variable_op_pass.h @@ -18,8 +18,8 @@ #define GE_GRAPH_PASSES_VARIABLE_OP_PASS_H_ #include #include -#include "graph/common/transop_util.h" -#include "graph/graph.h" +#include "common/transop_util.h" +#include "external/graph/graph.h" #include "graph/manager/graph_var_manager.h" #include "graph/manager/util/variable_accelerate_ctrl.h" #include "inc/graph_pass.h" diff --git a/ge/graph/passes/variable_prepare_op_pass.cc b/ge/graph/passes/variable_prepare_op_pass.cc index 3bb9a2fa..288ff185 100644 --- a/ge/graph/passes/variable_prepare_op_pass.cc +++ b/ge/graph/passes/variable_prepare_op_pass.cc @@ -21,7 +21,7 @@ #include "common/ge/ge_util.h" 
#include "external/graph/graph.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/node.h" #include "graph/utils/tensor_utils.h" diff --git a/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc b/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc index 1c8eb0ec..cac6bf75 100644 --- a/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc +++ b/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "variable_ref_useless_control_out_delete_pass.h" +#include "graph/passes/variable_ref_useless_control_out_delete_pass.h" namespace ge { Status VariableRefUselessControlOutDeletePass::Run(ge::ComputeGraphPtr graph) { diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 0c4adeea..446af9bf 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -24,13 +24,13 @@ #include "common/formats/format_transfers/format_transfer_transpose.h" #include "common/formats/utils/formats_trans_utils.h" #include "common/util/error_manager/error_manager.h" -#include "common/helper/model_helper.h" +#include "framework/common/helper/model_helper.h" #include "common/math/math_util.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "ir_build/option_utils.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" -#include "graph/common/transop_util.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" +#include "common/transop_util.h" #include "graph/ge_context.h" #include "graph/shape_refiner.h" #include "graph/manager/graph_var_manager.h" @@ -39,7 +39,7 @@ #include "graph/passes/addn_pass.h" #include "graph/passes/aicpu_constant_folding_pass.h" #include "graph/passes/assert_pass.h" -#include "ge/ge_api_types.h" +#include "external/ge/ge_api_types.h" #include "graph/passes/common_subexpression_elimination_pass.h" #include "graph/passes/cond_pass.h" #include "graph/passes/cond_remove_pass.h" @@ -54,6 +54,7 @@ #include "graph/passes/hccl_group_pass.h" #include "graph/passes/identity_pass.h" #include "graph/passes/infershape_pass.h" +#include "graph/passes/infer_value_range_pass.h" #include "graph/passes/merge_pass.h" #include "graph/passes/net_output_pass.h" #include "graph/passes/no_use_reshape_remove_pass.h" @@ -79,7 +80,7 @@ #include "graph/utils/type_utils.h" #include "inc/pass_manager.h" #include "init/gelib.h" -#include "multi_batch_copy_graph.h" +#include "graph/preprocess/multi_batch_copy_graph.h" #include "graph/passes/data_pass.h" #include "graph/passes/mark_agnostic_pass.h" @@ -414,16 +415,16 @@ Status UpdateVarFormats(const NodePtr &var, const GeTensorDesc &tensor_desc) { Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) { GE_CHECK_NOTNULL(var); - int index = 0; + static std::atomic_int index(0); NodePtr last_node = var; for (auto iter = road.rbegin(); iter != road.rend(); ++iter) { auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); auto ret = RecoverOneTransNodeForVar(trans_name, *iter, last_node, last_node); if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %d, type %s", - var->GetName().c_str(), index, iter->node_type.c_str()); - 
GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s, index %d, type %s", var->GetName().c_str(),
-             index, iter->node_type.c_str());
+      REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %s, type %s",
+                        var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str());
+      GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s, index %s, type %s", var->GetName().c_str(),
+             std::to_string(index).c_str(), iter->node_type.c_str());
       return INTERNAL_ERROR;
     }
     // set stream_label
@@ -459,17 +460,17 @@ Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) {
 Status RecoverTransRoadForVarRef(const std::set<NodePtr> &nodes, const VarTransRoad &road) {
   for (auto &var : nodes) {
     GE_CHECK_NOTNULL(var);
-    int index = 0;
+    static std::atomic_int index(0);
     NodePtr last_node = var;
     GELOGI("Recover trans nodes for variable ref %s", var->GetName().c_str());
     for (auto iter = road.rbegin(); iter != road.rend(); ++iter) {
       auto trans_name = var->GetName() + "_trans_" + std::to_string(index++);
       auto ret = RecoverOneTransNodeForVarRef(trans_name, *iter, last_node, last_node);
       if (ret != SUCCESS) {
-        REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %d, type %s",
-                          var->GetName().c_str(), index, iter->node_type.c_str());
-        GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s failed, index %d, type %s",
-               var->GetName().c_str(), index, iter->node_type.c_str());
+        REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %s, type %s",
+                          var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str());
+        GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s failed, index %s, type %s",
+               var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str());
         return INTERNAL_ERROR;
       }
       // set stream_label
@@ -1420,9 +1421,10 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) {
   return SUCCESS;
 }
 
-Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) {
+Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc) {
   auto format = desc.GetFormat();
   auto origin_format = desc.GetOriginFormat();
+  auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
   bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag);
   if (need_check_internal_format) {
     bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format);
@@ -1439,6 +1441,63 @@ Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTens
   return SUCCESS;
 }
 
+Status GraphPrepare::UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc) {
+  auto data_type = desc.GetDataType();
+  uint32_t length = 1;
+  bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);
+  if (!type_ret) {
+    std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" +
+                         std::to_string(index) + " input tensor is not support";
+    REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
+    GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.",
+           TypeUtils::DataTypeToSerialString(data_type).c_str());
+    return FAILED;
+  }
+  int64_t desc_shape = desc.GetShape().GetShapeSize();
+  FMK_INT64_UINT32_MULCHECK(desc_shape, length);
+  int64_t shape_size = desc_shape * length;
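+  // (Editor's note, worked example) a FLOAT tensor of shape [2, 3] gives desc_shape = 6 and
+  // length = 4, so shape_size = 24 bytes; a rank-0 scalar reports GetShapeSize() == 0, hence
+  // the fallback to `length` just below.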
GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast(length)); + int64_t size = 0; + GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, + REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); + GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); return FAILED); + bool size_check = (size != 0 && shape_size != size); + if (size_check) { + std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + + "] != shape_size[" + std::to_string(size) + "], check invalid"; + REPORT_INPUT_ERROR("E19025", std::vector({"reason"}), std::vector({reason})); + GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); + return FAILED; + } + ge::TensorUtils::SetSize(desc, shape_size); + + auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); + if (!tune_flag) { + graphStatus graph_ret = op->UpdateInputDesc(0, desc); + if (graph_ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + return graph_ret; + } + // Size will be recalculated in the build stage + ge::TensorUtils::SetSize(desc, 0); + graph_ret = op->UpdateOutputDesc(0, desc); + if (graph_ret != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", + op->GetName().c_str(), op->GetType().c_str()); + return graph_ret; + } + } else { + GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); + } + + return SUCCESS; +} + Status GraphPrepare::UpdateInput(const std::vector &user_input, const std::map &graph_option) { // Get shape range of input in dynamic_execute mode @@ -1471,63 +1530,18 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input, } GeTensorDesc desc(user_input[index].GetTensorDesc()); // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. 
- auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); - ret = CheckInternalFormat(input_node, desc, tune_flag); + ret = CheckInternalFormat(input_node, desc); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str()); return ret; } - auto data_type = desc.GetDataType(); - uint32_t length = 1; - bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); - if (!type_ret) { - std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" + - std::to_string(index) + " input tensor is not support"; - REPORT_INPUT_ERROR("E19025", std::vector({"reason"}), std::vector({reason})); - GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.", - TypeUtils::DataTypeToSerialString(data_type).c_str()); - return FAILED; - } - int64_t desc_shape = desc.GetShape().GetShapeSize(); - FMK_INT64_UINT32_MULCHECK(desc_shape, length); - int64_t shape_size = desc_shape * length; - GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast(length)); - int64_t size = 0; - GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, - REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index); - GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); - return FAILED); - bool size_check = (size != 0 && shape_size != size); - if (size_check) { - std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) + - "] != shape_size[" + std::to_string(size) + "], check invalid"; - REPORT_INPUT_ERROR("E19025", std::vector({"reason"}), std::vector({reason})); - GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size); - return FAILED; - } - ge::TensorUtils::SetSize(desc, shape_size); - if (!tune_flag) { - graphStatus graph_ret = op->UpdateInputDesc(0, desc); - if (graph_ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - return graph_ret; - } - // Size will be recalculated in the build stage - ge::TensorUtils::SetSize(desc, 0); - graph_ret = op->UpdateOutputDesc(0, desc); - if (graph_ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0", - op->GetName().c_str(), op->GetType().c_str()); - return graph_ret; - } - } else { - GELOGI("data %s skip update info in tune mode", op->GetName().c_str()); + + ret = UpdateDataInputOutputDesc(index, op, desc); + if (ret != SUCCESS) { + GELOGE(FAILED, "[Update][DataInputOutputDesc] on %s failed", op->GetName().c_str()); + return ret; } + if (!dynamic_shape_range_vec.empty()) { ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc); GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str()); @@ -1742,8 +1756,8 @@ Status GraphPrepare::CtrlFlowPreProcess() { PassManager graph_pass; // After InferShape Mark v1 control flow for unknown shape. 
- auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; - GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::MarkForceUnknownForCondPass", mark_force_unknown_pass)); + GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::MarkForceUnknownForCondPass", + new (std::nothrow) MarkForceUnknownForCondPass)); GE_CHK_STATUS_RET(graph_pass.Run(compute_graph_)); return SUCCESS; @@ -1985,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { Status GraphPrepare::InferShapeForPreprocess() { GELOGI("Start infershape for preprocess."); + // Prepare dummy_shape for v1 control_flow op before infershape + for (const auto &node : compute_graph_->GetAllNodes()) { + string type; + GetOriginalType(node, type); + if (type == MERGE || type == REFMERGE) { + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } else if (type == WHILE) { + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } + } GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; AssertPass assert_pass; @@ -2003,6 +2033,8 @@ Status GraphPrepare::InferShapeForPreprocess() { names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); ConstantFoldingPass constant_folding_pass; names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); + InferValueRangePass infer_value_pass; + names_to_passes.emplace_back("InferValuePass", &infer_value_pass); int32_t dev_count = 0; AicpuConstantFoldingPass aicpu_constant_folding_pass; diff --git a/ge/graph/preprocess/graph_preprocess.h b/ge/graph/preprocess/graph_preprocess.h index 584f4d16..3dfe1797 100755 --- a/ge/graph/preprocess/graph_preprocess.h +++ b/ge/graph/preprocess/graph_preprocess.h @@ -21,13 +21,13 @@ #include #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/debug/memory_dumper.h" #include "common/model_parser/model_parser.h" #include "common/properties_manager.h" -#include "common/string_util.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/string_util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "graph/compute_graph.h" #include "graph/manager/graph_manager_utils.h" #include "graph/manager/util/variable_accelerate_ctrl.h" @@ -35,7 +35,7 @@ #include "graph/node.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" -#include "omg/omg_inner_types.h" +#include "framework/omg/omg_inner_types.h" #include "runtime/context.h" namespace ge { @@ -63,7 +63,8 @@ class GraphPrepare { Status CheckRefOp(); Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); Status AdjustDataOpOutput(const NodePtr &node); - Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag); + Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc); + Status UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc); Status UpdateInput(const std::vector &user_input, const std::map &graph_option); Status CheckAndUpdateInput(const std::vector &user_input, const std::map &graph_option); Status CheckConstOp(); diff --git a/ge/graph/preprocess/insert_op/base_insert_op.h 
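Note on the RecoverTransRoadForVar/RecoverTransRoadForVarRef hunks above: replacing the per-call `int index = 0` with a function-local `static std::atomic_int` makes the generated `_trans_` suffixes unique across calls and threads. It also explains the log-format change from `%d` to `%s`: an `std::atomic_int` cannot be passed safely through printf-style varargs, so the value is rendered with `std::to_string` first. A minimal sketch of the naming behavior (illustrative only, not GE code):

```cpp
// Sketch: a function-local static std::atomic_int yields collision-free
// suffixes across calls and threads; a per-call `int index = 0` would
// restart at zero on every invocation and could reuse names.
#include <atomic>
#include <iostream>
#include <string>

std::string MakeTransName(const std::string &var_name) {
  static std::atomic_int index(0);  // shared by all calls; increment is atomic
  return var_name + "_trans_" + std::to_string(index++);
}

int main() {
  std::cout << MakeTransName("var_a") << '\n';  // var_a_trans_0
  std::cout << MakeTransName("var_a") << '\n';  // var_a_trans_1 -- 0 is never reused
  return 0;
}
```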
diff --git a/ge/graph/preprocess/insert_op/base_insert_op.h b/ge/graph/preprocess/insert_op/base_insert_op.h
index b0d7a7a6..6b1eb177 100644
--- a/ge/graph/preprocess/insert_op/base_insert_op.h
+++ b/ge/graph/preprocess/insert_op/base_insert_op.h
@@ -21,8 +21,8 @@
 #include
 #include
 #include
-#include "common/fmk_error_codes.h"
-#include "common/types.h"
+#include "framework/common/fmk_error_codes.h"
+#include "framework/common/types.h"
 #include "framework/common/ge_inner_error_codes.h"
 #include "graph/compute_graph.h"
 #include "proto/insert_op.pb.h"
diff --git a/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/ge/graph/preprocess/insert_op/ge_aipp_op.cc
index 5c191af7..7a89a1f4 100755
--- a/ge/graph/preprocess/insert_op/ge_aipp_op.cc
+++ b/ge/graph/preprocess/insert_op/ge_aipp_op.cc
@@ -20,10 +20,10 @@
 #include
 #include
 #include
-#include "base_insert_op.h"
+#include "graph/preprocess/insert_op/base_insert_op.h"
 #include "common/dynamic_aipp.h"
 #include "common/ge/ge_util.h"
-#include "common/util.h"
+#include "framework/common/util.h"
 #include "common/util/error_manager/error_manager.h"
 #include "external/graph/operator_factory.h"
 #include "framework/common/debug/ge_log.h"
@@ -39,7 +39,7 @@
 #include "graph/utils/tensor_utils.h"
 #include "graph/utils/type_utils.h"
 #include "proto/insert_op.pb.h"
-#include "graph/common/local_context.h"
+#include "common/local_context.h"
 
 #define SAVE_AIPP_ATTR(KEY, SAVE_TYPE) \
   do {                                 \
@@ -114,7 +114,7 @@ Status GetDataDimN(const ge::NodePtr &data_node, ge::Format format, int64_t &bat
       std::vector<std::string>({
         data_node->GetName() + " format",
         TypeUtils::FormatToSerialString(format),
-        "only format " + TypeUtils::FormatToSerialString(FORMAT_NCHW) + " and "+
+        "only format " + TypeUtils::FormatToSerialString(FORMAT_NCHW) + " and " +
         TypeUtils::FormatToSerialString(FORMAT_NHWC) + " supported which dynamic aipp is linked"}));
     GELOGE(PARAM_INVALID, "[Check][Param] Not support data format:%s, node:%s",
diff --git a/ge/graph/preprocess/insert_op/ge_aipp_op.h b/ge/graph/preprocess/insert_op/ge_aipp_op.h
index 5e509dda..87f80291 100755
--- a/ge/graph/preprocess/insert_op/ge_aipp_op.h
+++ b/ge/graph/preprocess/insert_op/ge_aipp_op.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "common/op/attr_value_util.h"
+#include "framework/common/op/attr_value_util.h"
 #include "graph/preprocess/insert_op/base_insert_op.h"
 #include "proto/insert_op.pb.h"
diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
index d76b79b9..cc7f276e 100755
--- a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
+++ b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
@@ -20,8 +20,8 @@
 #include "common/dynamic_aipp.h"
 #include "common/formats/utils/formats_trans_utils.h"
 #include "common/ge/ge_util.h"
-#include "common/op/ge_op_utils.h"
-#include "common/util.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/util.h"
 #include "common/util/error_manager/error_manager.h"
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/debug/log.h"
@@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map
   std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams());
+  GE_CHECK_NOTNULL(aipp_params);
   ge::GeAttrValue::NAMED_ATTRS aipp_attr;
   GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST,
                          "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str());
diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc
index 1634c8ce..d88cf6cd 100644
--- a/ge/graph/preprocess/multi_batch_copy_graph.cc
+++ b/ge/graph/preprocess/multi_batch_copy_graph.cc
@@ -38,8 +38,8 @@
 #include "graph/utils/tensor_utils.h"
 #include "graph/utils/type_utils.h"
 #include "inc/pass_manager.h"
-#include "graph/common/local_context.h"
-#include "graph/common/omg_util.h"
+#include "common/local_context.h"
+#include "common/omg_util.h"
 
 using std::set;
 using std::string;
@@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_
     auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
     if (!IsAllDimsPositive(dims)) {
       REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s",
-                         node->GetName().c_str(), formats::ShapeToString(dims).c_str());
+                        node->GetName().c_str(), formats::ShapeToString(dims).c_str());
       GELOGE(INTERNAL_ERROR,
              "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s",
              node->GetName().c_str(), formats::ShapeToString(dims).c_str());
       return INTERNAL_ERROR;
diff --git a/ge/graph/preprocess/multi_batch_options.cc b/ge/graph/preprocess/multi_batch_options.cc
index b3e5b616..9cda6194 100644
--- a/ge/graph/preprocess/multi_batch_options.cc
+++ b/ge/graph/preprocess/multi_batch_options.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "multi_batch_options.h"
+#include "graph/preprocess/multi_batch_options.h"
 
 #include "framework/common/debug/ge_log.h"
 #include "framework/omg/omg_inner_types.h"
@@ -25,11 +25,11 @@
 #include "graph/debug/ge_attr_define.h"
 #include "graph/utils/node_utils.h"
 #include "graph/ge_context.h"
-#include "graph/common/local_context.h"
+#include "common/local_context.h"
 #include "framework/common/types.h"
 #include "graph/compute_graph.h"
 #include "graph/utils/graph_utils.h"
-#include "graph/common/omg_util.h"
+#include "common/omg_util.h"
 
 namespace ge {
 namespace multibatch {
diff --git a/ge/host_kernels/add_kernel.cc b/ge/host_kernels/add_kernel.cc
index 1c206018..eb0ea86d 100644
--- a/ge/host_kernels/add_kernel.cc
+++ b/ge/host_kernels/add_kernel.cc
@@ -19,7 +19,7 @@
 #include
 
 #include "common/math/math_util.h"
-#include "graph/common/bcast.h"
+#include "common/bcast.h"
 #include "graph/utils/type_utils.h"
 #include "inc/kernel_factory.h"
diff --git a/ge/host_kernels/broadcast_args_kernel.cc b/ge/host_kernels/broadcast_args_kernel.cc
index d8880db9..660717ad 100644
--- a/ge/host_kernels/broadcast_args_kernel.cc
+++ b/ge/host_kernels/broadcast_args_kernel.cc
@@ -18,11 +18,11 @@
 #include
 
-#include "common/op/ge_op_utils.h"
-#include "common/types.h"
-#include "common/util.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/types.h"
+#include "framework/common/util.h"
 #include "framework/common/ge_inner_error_codes.h"
-#include "graph/common/bcast.h"
+#include "common/bcast.h"
 #include "graph/passes/pass_utils.h"
 #include "inc/kernel_factory.h"
diff --git a/ge/host_kernels/broadcast_gradient_args_kernel.cc b/ge/host_kernels/broadcast_gradient_args_kernel.cc
index 51ff4a4c..8b9e3fb5 100644
--- a/ge/host_kernels/broadcast_gradient_args_kernel.cc
+++ b/ge/host_kernels/broadcast_gradient_args_kernel.cc
@@ -17,12 +17,12 @@
 #include
 
-#include "common/op/ge_op_utils.h"
-#include "common/types.h"
-#include "common/util.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/types.h"
+#include "framework/common/util.h"
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/ge_inner_error_codes.h"
-#include "graph/common/bcast.h"
+#include "common/bcast.h"
 #include "graph/passes/pass_utils.h"
 #include "inc/kernel_factory.h"
diff --git a/ge/host_kernels/cast_kernel.cc b/ge/host_kernels/cast_kernel.cc
index 056081a1..2d2f463c 100644
--- a/ge/host_kernels/cast_kernel.cc
+++ b/ge/host_kernels/cast_kernel.cc
@@ -19,16 +19,16 @@
 #include
 #include
 
-#include "common/debug/log.h"
+#include "framework/common/debug/log.h"
 #include "common/formats/formats.h"
 #include "common/formats/utils/formats_trans_utils.h"
 #include "common/fp16_t.h"
-#include "common/op/ge_op_utils.h"
-#include "common/types.h"
-#include "common/util.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/types.h"
+#include "framework/common/util.h"
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/ge_inner_error_codes.h"
-#include "graph/common/bcast.h"
+#include "common/bcast.h"
 #include "host_kernels/kernel_utils.h"
 #include "graph/utils/type_utils.h"
 #include "inc/kernel_factory.h"
diff --git a/ge/host_kernels/concat_offset_kernel.cc b/ge/host_kernels/concat_offset_kernel.cc
index b6940eb4..79552183 100644
--- a/ge/host_kernels/concat_offset_kernel.cc
+++ b/ge/host_kernels/concat_offset_kernel.cc
@@ -18,9 +18,9 @@
 #include
 
-#include "common/ge_inner_error_codes.h"
-#include "common/op/ge_op_utils.h"
-#include "common/types.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/types.h"
 #include "framework/common/debug/ge_log.h"
 #include "graph/utils/type_utils.h"
 #include "inc/kernel_factory.h"
diff --git a/ge/host_kernels/concat_v2_kernel.cc b/ge/host_kernels/concat_v2_kernel.cc
index 234d8c8a..c5a7d889 100644
--- a/ge/host_kernels/concat_v2_kernel.cc
+++ b/ge/host_kernels/concat_v2_kernel.cc
@@ -19,9 +19,9 @@
 #include
 #include
 
-#include "common/debug/log.h"
+#include "framework/common/debug/log.h"
 #include "common/fp16_t.h"
-#include "common/op/ge_op_utils.h"
+#include "framework/common/op/ge_op_utils.h"
 #include "framework/common/debug/ge_log.h"
 #include "host_kernels/kernel_utils.h"
 #include "graph/utils/type_utils.h"
diff --git a/ge/host_kernels/dynamic_stitch_kernel.cc b/ge/host_kernels/dynamic_stitch_kernel.cc
index 52f6cdcf..0313c856 100644
--- a/ge/host_kernels/dynamic_stitch_kernel.cc
+++ b/ge/host_kernels/dynamic_stitch_kernel.cc
@@ -20,10 +20,10 @@
 #include
 
 #include "common/fp16_t.h"
-#include "common/ge_inner_error_codes.h"
+#include "framework/common/ge_inner_error_codes.h"
 #include "common/math/math_util.h"
-#include "common/op/ge_op_utils.h"
-#include "common/types.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/types.h"
 #include "framework/common/debug/ge_log.h"
 #include "graph/utils/type_utils.h"
 #include "inc/kernel_factory.h"
diff --git a/ge/host_kernels/empty_kernel.cc b/ge/host_kernels/empty_kernel.cc
index 61310abc..68ba7f9f 100644
--- a/ge/host_kernels/empty_kernel.cc
+++ b/ge/host_kernels/empty_kernel.cc
@@ -19,8 +19,8 @@
 #include
 
 #include "common/fp16_t.h"
-#include "common/op/ge_op_utils.h"
-#include "common/types.h"
+#include "framework/common/op/ge_op_utils.h"
+#include "framework/common/types.h"
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/ge_inner_error_codes.h"
 #include "host_kernels/kernel_utils.h"
-#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/fill_kernel.cc b/ge/host_kernels/fill_kernel.cc index 0022791c..ac46101b 100644 --- a/ge/host_kernels/fill_kernel.cc +++ b/ge/host_kernels/fill_kernel.cc @@ -20,8 +20,8 @@ #include #include "common/fp16_t.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" @@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetName().c_str()); GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex)); GE_CHECK_NOTNULL(input.at(kFillDataInputIndex)); @@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetOutputDescPtr(0); + GE_CHECK_NOTNULL(output_desc); + if (output_desc->GetShape().IsUnknownShape()) { + GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str()); + return NOT_CHANGED; + } + GeTensorPtr output_ptr; output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(0)); if (output_ptr == nullptr) { diff --git a/ge/host_kernels/floordiv_kernel.cc b/ge/host_kernels/floordiv_kernel.cc index df381212..566a45a3 100644 --- a/ge/host_kernels/floordiv_kernel.cc +++ b/ge/host_kernels/floordiv_kernel.cc @@ -21,8 +21,8 @@ #include #include -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/host_kernels/floormod_kernel.cc b/ge/host_kernels/floormod_kernel.cc index 31e4e19b..1d101667 100644 --- a/ge/host_kernels/floormod_kernel.cc +++ b/ge/host_kernels/floormod_kernel.cc @@ -19,11 +19,11 @@ #include #include -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/gather_v2_kernel.cc b/ge/host_kernels/gather_v2_kernel.cc index 5702954c..45445143 100644 --- a/ge/host_kernels/gather_v2_kernel.cc +++ b/ge/host_kernels/gather_v2_kernel.cc @@ -20,10 +20,10 @@ #include #include "common/fp16_t.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/host_kernels/greater_kernel.cc b/ge/host_kernels/greater_kernel.cc index a245ec8d..0cc895c4 100644 --- a/ge/host_kernels/greater_kernel.cc +++ b/ge/host_kernels/greater_kernel.cc @@ -19,13 +19,13 @@ #include #include -#include "common/debug/log.h" 
+#include "framework/common/debug/log.h" #include "common/fp16_t.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/identity_kernel.cc b/ge/host_kernels/identity_kernel.cc index ef1446a8..30f55027 100644 --- a/ge/host_kernels/identity_kernel.cc +++ b/ge/host_kernels/identity_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "identity_kernel.h" +#include "host_kernels/identity_kernel.h" #include "inc/kernel_factory.h" #include "framework/common/types.h" diff --git a/ge/host_kernels/kernel_utils.cc b/ge/host_kernels/kernel_utils.cc index 595f9517..6447fa43 100755 --- a/ge/host_kernels/kernel_utils.cc +++ b/ge/host_kernels/kernel_utils.cc @@ -18,8 +18,8 @@ #include -#include "common/ge_inner_error_codes.h" -#include "common/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" diff --git a/ge/host_kernels/kernel_utils.h b/ge/host_kernels/kernel_utils.h index c9c90634..7a7545ea 100755 --- a/ge/host_kernels/kernel_utils.h +++ b/ge/host_kernels/kernel_utils.h @@ -20,8 +20,8 @@ #include #include -#include "common/ge_inner_error_codes.h" -#include "common/util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/compute_graph.h" diff --git a/ge/host_kernels/maximum_kernel.cc b/ge/host_kernels/maximum_kernel.cc index 2ced113f..0e28fcdc 100644 --- a/ge/host_kernels/maximum_kernel.cc +++ b/ge/host_kernels/maximum_kernel.cc @@ -19,13 +19,13 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/fp16_t.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/mul_kernel.cc b/ge/host_kernels/mul_kernel.cc index b01a5c79..608f351d 100644 --- a/ge/host_kernels/mul_kernel.cc +++ b/ge/host_kernels/mul_kernel.cc @@ -19,13 +19,13 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/math/math_util.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/pack_kernel.cc b/ge/host_kernels/pack_kernel.cc index bf7a2a1f..103f4029 100644 --- a/ge/host_kernels/pack_kernel.cc +++ b/ge/host_kernels/pack_kernel.cc @@ -18,10 +18,10 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/formats/utils/formats_trans_utils.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" +#include 
"framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "host_kernels/kernel_utils.h" diff --git a/ge/host_kernels/permute_kernel.cc b/ge/host_kernels/permute_kernel.cc index 327c94f8..9e9462b6 100755 --- a/ge/host_kernels/permute_kernel.cc +++ b/ge/host_kernels/permute_kernel.cc @@ -19,12 +19,12 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "framework/common/debug/ge_log.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" -#include "graph/common/bcast.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" +#include "framework/common/util.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" #include "common/formats/formats.h" diff --git a/ge/host_kernels/range_kernel.cc b/ge/host_kernels/range_kernel.cc index 97254fff..d8f8200a 100644 --- a/ge/host_kernels/range_kernel.cc +++ b/ge/host_kernels/range_kernel.cc @@ -19,10 +19,10 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/fp16_t.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/type_utils.h" diff --git a/ge/host_kernels/rank_kernel.cc b/ge/host_kernels/rank_kernel.cc index b246b976..9bc404f3 100755 --- a/ge/host_kernels/rank_kernel.cc +++ b/ge/host_kernels/rank_kernel.cc @@ -19,12 +19,12 @@ #include #include -#include "graph/types.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" +#include "external/graph/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "inc/kernel_factory.h" -#include "omg/omg_inner_types.h" +#include "framework/omg/omg_inner_types.h" #include "framework/common/types.h" namespace { diff --git a/ge/host_kernels/reduce_prod_kernel.cc b/ge/host_kernels/reduce_prod_kernel.cc index 4837a921..efe48997 100644 --- a/ge/host_kernels/reduce_prod_kernel.cc +++ b/ge/host_kernels/reduce_prod_kernel.cc @@ -20,8 +20,8 @@ #include #include "common/math/math_util.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "host_kernels/kernel_utils.h" diff --git a/ge/host_kernels/reformat_kernel.cc b/ge/host_kernels/reformat_kernel.cc index c1942983..b841ed39 100644 --- a/ge/host_kernels/reformat_kernel.cc +++ b/ge/host_kernels/reformat_kernel.cc @@ -17,10 +17,10 @@ #include "host_kernels/reformat_kernel.h" #include "common/formats/utils/formats_trans_utils.h" #include "common/ge/ge_util.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/host_kernels/reshape_kernel.cc 
b/ge/host_kernels/reshape_kernel.cc index 7c4f58f6..bead8468 100644 --- a/ge/host_kernels/reshape_kernel.cc +++ b/ge/host_kernels/reshape_kernel.cc @@ -16,9 +16,9 @@ #include "host_kernels/reshape_kernel.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/rsqrt_kernel.cc b/ge/host_kernels/rsqrt_kernel.cc index 74c78787..f1f74a99 100755 --- a/ge/host_kernels/rsqrt_kernel.cc +++ b/ge/host_kernels/rsqrt_kernel.cc @@ -19,10 +19,10 @@ #include -#include "common/debug/ge_log.h" -#include "common/debug/log.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/size_kernel.cc b/ge/host_kernels/size_kernel.cc index caa5febc..9f7bc0ff 100644 --- a/ge/host_kernels/size_kernel.cc +++ b/ge/host_kernels/size_kernel.cc @@ -19,8 +19,8 @@ #include #include -#include "common/debug/log.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/debug/log.h" +#include "framework/common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" @@ -28,7 +28,7 @@ #include "host_kernels/kernel_utils.h" #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" -#include "omg/omg_inner_types.h" +#include "framework/omg/omg_inner_types.h" namespace ge { namespace { diff --git a/ge/host_kernels/slice_d_kernel.cc b/ge/host_kernels/slice_d_kernel.cc index b8572290..60caac38 100644 --- a/ge/host_kernels/slice_d_kernel.cc +++ b/ge/host_kernels/slice_d_kernel.cc @@ -19,8 +19,8 @@ #include #include "common/fp16_t.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/host_kernels/slice_kernel.cc b/ge/host_kernels/slice_kernel.cc index 6e398e96..0b1e7325 100644 --- a/ge/host_kernels/slice_kernel.cc +++ b/ge/host_kernels/slice_kernel.cc @@ -18,10 +18,10 @@ #include -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/type_utils.h" #include "host_kernels/kernel_utils.h" diff --git a/ge/host_kernels/squeeze_kernel.cc b/ge/host_kernels/squeeze_kernel.cc index 4a2c6725..852a46a1 100644 --- a/ge/host_kernels/squeeze_kernel.cc +++ b/ge/host_kernels/squeeze_kernel.cc @@ -16,9 +16,9 @@ #include "host_kernels/squeeze_kernel.h" -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include 
"framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/ssd_prior_box_kernel.cc b/ge/host_kernels/ssd_prior_box_kernel.cc index 3661fa9d..496cc185 100644 --- a/ge/host_kernels/ssd_prior_box_kernel.cc +++ b/ge/host_kernels/ssd_prior_box_kernel.cc @@ -23,7 +23,7 @@ #include "common/math/math_util.h" #include "common/math_util.h" -#include "common/types.h" +#include "framework/common/types.h" #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" diff --git a/ge/host_kernels/sub_kernel.cc b/ge/host_kernels/sub_kernel.cc index deb36cb3..0aebb946 100644 --- a/ge/host_kernels/sub_kernel.cc +++ b/ge/host_kernels/sub_kernel.cc @@ -20,10 +20,10 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/math/math_util.h" -#include "common/op/ge_op_utils.h" -#include "graph/common/bcast.h" +#include "framework/common/op/ge_op_utils.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/transdata_kernel.cc b/ge/host_kernels/transdata_kernel.cc index 2b16b075..7d44fdae 100644 --- a/ge/host_kernels/transdata_kernel.cc +++ b/ge/host_kernels/transdata_kernel.cc @@ -19,16 +19,16 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" #include "common/fp16_t.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/transpose_kernel.cc b/ge/host_kernels/transpose_kernel.cc index 03d112aa..9291ecf5 100755 --- a/ge/host_kernels/transpose_kernel.cc +++ b/ge/host_kernels/transpose_kernel.cc @@ -17,13 +17,13 @@ #include "host_kernels/transpose_kernel.h" #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/formats/format_transfers/format_transfer_transpose.h" #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "host_kernels/kernel_utils.h" diff --git a/ge/host_kernels/unpack_kernel.cc b/ge/host_kernels/unpack_kernel.cc index 1c28151f..a90e3616 100755 --- a/ge/host_kernels/unpack_kernel.cc +++ b/ge/host_kernels/unpack_kernel.cc @@ -15,10 +15,10 @@ */ #include "host_kernels/unpack_kernel.h" -#include "common/debug/ge_log.h" -#include "common/op/ge_op_utils.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include 
"graph/debug/ge_attr_define.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/unsqueeze_kernel.cc b/ge/host_kernels/unsqueeze_kernel.cc index 4ceaba3f..d2b0d05f 100644 --- a/ge/host_kernels/unsqueeze_kernel.cc +++ b/ge/host_kernels/unsqueeze_kernel.cc @@ -16,9 +16,9 @@ #include "host_kernels/unsqueeze_kernel.h" #include -#include "common/ge_inner_error_codes.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/hybrid/common/npu_memory_allocator.cc b/ge/hybrid/common/npu_memory_allocator.cc index b66038d9..8a9aa0cc 100644 --- a/ge/hybrid/common/npu_memory_allocator.cc +++ b/ge/hybrid/common/npu_memory_allocator.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "npu_memory_allocator.h" +#include "hybrid/common/npu_memory_allocator.h" #include #include "framework/common/debug/log.h" #include "graph/manager/graph_mem_manager.h" diff --git a/ge/hybrid/common/npu_memory_allocator.h b/ge/hybrid/common/npu_memory_allocator.h index 55cb13ad..8df89108 100644 --- a/ge/hybrid/common/npu_memory_allocator.h +++ b/ge/hybrid/common/npu_memory_allocator.h @@ -23,7 +23,7 @@ #include #include #include "external/ge/ge_api_error_codes.h" -#include "memory/memory_api.h" +#include "framework/memory/memory_api.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/common/tensor_value.h b/ge/hybrid/common/tensor_value.h index 348e4e6d..c041263b 100644 --- a/ge/hybrid/common/tensor_value.h +++ b/ge/hybrid/common/tensor_value.h @@ -20,7 +20,7 @@ #include #include #include -#include "memory/memory_api.h" +#include "framework/memory/memory_api.h" #include "framework/common/util.h" namespace ge { @@ -95,7 +95,8 @@ class TensorValue { name_ = name; } - MemStorageType GetMemType() const { + Status GetMemType(MemStorageType &mem_type) const { + GE_CHECK_NOTNULL(buffer_); return buffer_->GetMemType(); } diff --git a/ge/hybrid/executor/hybrid_execution_context.cc b/ge/hybrid/executor/hybrid_execution_context.cc index 0f978bf8..2de9b1ce 100644 --- a/ge/hybrid/executor/hybrid_execution_context.cc +++ b/ge/hybrid/executor/hybrid_execution_context.cc @@ -14,7 +14,7 @@ * limitations under the License. 
diff --git a/ge/hybrid/executor/hybrid_execution_context.cc b/ge/hybrid/executor/hybrid_execution_context.cc
index 0f978bf8..2de9b1ce 100644
--- a/ge/hybrid/executor/hybrid_execution_context.cc
+++ b/ge/hybrid/executor/hybrid_execution_context.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "hybrid_execution_context.h"
+#include "hybrid/executor/hybrid_execution_context.h"
 #include
 
 namespace ge {
diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc
index 930412e3..229cce84 100644
--- a/ge/hybrid/executor/hybrid_model_async_executor.cc
+++ b/ge/hybrid/executor/hybrid_model_async_executor.cc
@@ -19,7 +19,7 @@
 #include "graph/utils/tensor_utils.h"
 #include "graph/utils/type_utils.h"
 #include "graph/ge_context.h"
-#include "graph/types.h"
+#include "external/graph/types.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/graph_caching_allocator.h"
 #include "graph/manager/graph_mem_allocator.h"
@@ -78,8 +78,6 @@ Status HybridModelAsyncExecutor::Start(const std::shared_ptr<ModelListener> &lis
     GetThreadLocalContext() = *executor_->GetContext()->ge_context;
     GetContext().SetSessionId(executor_->GetContext()->session_id);
     GetContext().SetContextId(executor_->GetContext()->context_id);
-    GE_CHECK_NOTNULL(executor_->GetContext()->ge_context);
-    GetThreadLocalContext() = *executor_->GetContext()->ge_context;
     return RunInternal();
   });
@@ -297,13 +295,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, Hy
       }
     }
     tensor_desc->SetShape(shape);
-    args.input_desc[input_index] = tensor_desc;
-    GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str());
+    GELOGD("Update shape[%s] of input[%zu] to [%s]",
+           shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str());
     GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size),
                             "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size,"
                             "index = %zu, shape = [%s], model_id = %u.",
                             input_index, tensor_desc->GetShape().ToString().c_str(), model_id_);
-    GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size);
+    GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size);
+    TensorUtils::SetSize(*tensor_desc, tensor_size);
+    args.input_desc[input_index] = tensor_desc;
   }
 
   GE_CHECK_GE(tensor_size, 0);
@@ -460,7 +460,8 @@ Status HybridModelAsyncExecutor::CopyOutputs(HybridModelExecutor::ExecuteArgs &a
       auto tensor = TensorAdapter::AsTensor(ge_tensor);
       outputs.emplace_back(std::move(tensor));
     } else {
-      BuildDeviceTensor(output_tensor, ge_tensor_desc, output_size, outputs);
+      GE_CHK_STATUS_RET(BuildDeviceTensor(output_tensor, ge_tensor_desc, output_size, outputs),
+                        "Build device tensor failed");
       output_data->blobs.emplace_back(output_tensor.Release(), static_cast<uint32_t>(output_size), false,
                                       static_cast<uint32_t>(kPlacementDevice));
     }
@@ -480,13 +481,15 @@ Status HybridModelAsyncExecutor::CopyOutputs(HybridModelExecutor::ExecuteArgs &a
   return SUCCESS;
 }
 
-void HybridModelAsyncExecutor::BuildDeviceTensor(TensorValue &output_tensor, GeTensorDesc &ge_tensor_desc,
-                                                 int64_t output_size, std::vector<ge::Tensor> &outputs) {
+Status HybridModelAsyncExecutor::BuildDeviceTensor(TensorValue &output_tensor, GeTensorDesc &ge_tensor_desc,
+                                                   int64_t output_size, std::vector<ge::Tensor> &outputs) {
   GELOGD("Start to build device tensor");
-  auto mem_type = output_tensor.GetMemType();
+  MemStorageType mem_type = HBM;
+  GE_CHK_STATUS_RET(output_tensor.GetMemType(mem_type), "[Build][DeviceTensor] Get mem type failed");
   GELOGD("Mem type is %d", static_cast<int32_t>(mem_type));
   auto deleter = [=](uint8_t *device_data) {
     if (device_data != nullptr) {
+      GELOGD("Free device addr is %p", device_data);
       if (mem_type == RDMA_HBM) {
         MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Free(device_data, device_id_);
       } else if (mem_type == HOST_DDR) {
@@ -501,6 +504,7 @@ void HybridModelAsyncExecutor::BuildDeviceTensor(TensorValue &output_tensor, GeT
   auto tensor = TensorAdapter::AsTensor(ge_tensor);
   tensor.SetData(reinterpret_cast<uint8_t *>(output_tensor.Release()), static_cast<size_t>(output_size), deleter);
   outputs.emplace_back(std::move(tensor));
+  return SUCCESS;
 }
 
 Status HybridModelAsyncExecutor::Execute(const std::vector &inputs,
diff --git a/ge/hybrid/executor/hybrid_model_async_executor.h b/ge/hybrid/executor/hybrid_model_async_executor.h
index 5ae1a222..f94f6aa5 100644
--- a/ge/hybrid/executor/hybrid_model_async_executor.h
+++ b/ge/hybrid/executor/hybrid_model_async_executor.h
@@ -76,8 +76,8 @@ class HybridModelAsyncExecutor {
                       OutputData *output_data);
 
   Status CopyOutputs(HybridModelExecutor::ExecuteArgs &args, OutputData *output_data, std::vector<ge::Tensor> &outputs);
 
-  void BuildDeviceTensor(TensorValue &output_tensor, GeTensorDesc &ge_tensor_desc, int64_t output_size,
-                         std::vector<ge::Tensor> &outputs);
+  Status BuildDeviceTensor(TensorValue &output_tensor, GeTensorDesc &ge_tensor_desc, int64_t output_size,
+                           std::vector<ge::Tensor> &outputs);
 
   Status OnComputeDone(uint32_t data_index, uint32_t result_code, std::vector<ge::Tensor> &outputs);
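Note on the BuildDeviceTensor hunks above: the deleter handed to the output tensor captures `mem_type` by value, because the tensor can outlive this call and the eventual free must still route to the pool that allocated the buffer. A self-contained sketch of that idea (hypothetical free functions, not MemManager's API):

```cpp
// Sketch: a custom deleter that carries the memory type with it, so the
// buffer is released to the right pool long after the builder returned.
#include <cstdint>
#include <functional>
#include <iostream>

enum MemStorageType { HBM, RDMA_HBM, HOST_DDR };

// Hypothetical free routines standing in for the per-pool frees.
void FreeRdma(uint8_t *p) { std::cout << "rdma free\n"; delete[] p; }
void FreeHost(uint8_t *p) { std::cout << "host free\n"; delete[] p; }
void FreeHbm(uint8_t *p)  { std::cout << "hbm free\n";  delete[] p; }

std::function<void(uint8_t *)> MakeDeleter(MemStorageType mem_type) {
  // Capture by value: the deleter stays valid after mem_type goes out of scope.
  return [mem_type](uint8_t *device_data) {
    if (device_data == nullptr) return;
    if (mem_type == RDMA_HBM) {
      FreeRdma(device_data);
    } else if (mem_type == HOST_DDR) {
      FreeHost(device_data);
    } else {
      FreeHbm(device_data);
    }
  };
}

int main() {
  auto deleter = MakeDeleter(HOST_DDR);
  deleter(new uint8_t[16]);  // prints "host free"
  return 0;
}
```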
diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc
index d8939175..dd8aace6 100755
--- a/ge/hybrid/executor/hybrid_model_executor.cc
+++ b/ge/hybrid/executor/hybrid_model_executor.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "hybrid_model_executor.h"
+#include "hybrid/executor/hybrid_model_executor.h"
 #include "graph/ge_context.h"
 #include "graph/runtime_inference_context.h"
 #include "graph/utils/tensor_utils.h"
@@ -33,14 +33,14 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id,
 }
 
 HybridModelExecutor::~HybridModelExecutor() {
-  if (context_.rt_gen_context != nullptr) {
-    (void) rtCtxDestroy(context_.rt_gen_context);
-  }
 }
 
-Status HybridModelExecutor::Init() {
+Status HybridModelExecutor::Init(ThreadPool *thread_pool) {
   GELOGD("Start to init HybridGraphEngine.");
   GE_CHK_STATUS_RET_NOLOG(InitExecutionContext());
+  root_graph_executor_.reset(
+      new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_, false, thread_pool));
+  GE_CHECK_NOTNULL(root_graph_executor_);
   GELOGD("HybridGraphEngine initialized successfully.");
   return SUCCESS;
 }
@@ -60,8 +60,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
     GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
                                 sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
   }
-  SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
-  auto ret = ExecuteGraphInternal(executor, args);
+  auto ret = ExecuteGraphInternal(args);
   Cleanup();
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End");
   GELOGD("Model executed successfully.");
@@ -69,6 +68,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
     context_.profiler->Dump(std::cout);
     context_.profiler->Reset();
   }
+  root_graph_executor_->ReleaseContext();
 
   context_.iteration += 1;
   if (ret == END_OF_SEQUENCE) {
@@ -79,8 +79,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
   return SUCCESS;
 }
 
-Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
-                                                 HybridModelExecutor::ExecuteArgs &args) {
+Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) {
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start");
   GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_));
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End");
@@ -94,7 +93,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
     GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id));
   }
 
-  HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs),
+  HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs),
                         "Failed to execute partitioned call.");
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End");
 
@@ -103,7 +102,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
   }
 
   if (!model_->IsSingleOp()) {
-    Status ret = executor.Synchronize();
+    Status ret = root_graph_executor_->Synchronize();
     if (ret != ge::SUCCESS) {
       auto model_manager = ModelManager::GetInstance();
       GE_CHECK_NOTNULL(model_manager);
@@ -123,7 +122,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
   }
 
   args.outputs.clear();
-  HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
+  HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End");
   return SUCCESS;
 }
@@ -138,7 +137,6 @@ Status HybridModelExecutor::Cleanup() {
 
 Status HybridModelExecutor::InitExecutionContext() {
   GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
-  GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
   GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));
 
   context_.global_step = model_->GetGlobalStep();
diff --git a/ge/hybrid/executor/hybrid_model_executor.h b/ge/hybrid/executor/hybrid_model_executor.h
index 566043d9..dbec7adf 100644
--- a/ge/hybrid/executor/hybrid_model_executor.h
+++ b/ge/hybrid/executor/hybrid_model_executor.h
@@ -39,7 +39,7 @@ class HybridModelExecutor {
 
   ~HybridModelExecutor();
 
-  Status Init();
+  Status Init(ThreadPool *thread_pool = nullptr);
 
   const GraphExecutionContext* GetContext() const {
     return &context_;
@@ -48,7 +48,7 @@ class HybridModelExecutor {
   Status Execute(ExecuteArgs &args);
 
  private:
-  Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args);
+  Status ExecuteGraphInternal(ExecuteArgs &args);
   Status Cleanup();
   Status InitExecutionContext();
   static Status ResetExecutionContext(GraphExecutionContext &context);
@@ -58,6 +58,7 @@ class HybridModelExecutor {
   uint32_t device_id_;
   rtStream_t stream_;
   GraphExecutionContext context_;
+  std::unique_ptr<SubgraphExecutor> root_graph_executor_;
 };
 }  // namespace hybrid
 }  // namespace ge
diff --git a/ge/hybrid/executor/hybrid_model_pipeline_executor.cc b/ge/hybrid/executor/hybrid_model_pipeline_executor.cc
index c0bd5c7d..57ba20d4 100644
--- a/ge/hybrid/executor/hybrid_model_pipeline_executor.cc
+++ b/ge/hybrid/executor/hybrid_model_pipeline_executor.cc
@@ -1,4 +1,20 @@
-#include "hybrid_model_pipeline_executor.h"
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/executor/hybrid_model_pipeline_executor.h"
 
 #include "common/math/math_util.h"
 #include "common/dump/dump_manager.h"
@@ -172,10 +188,10 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin
   config_.num_executors = kNumExecutors;
   config_.num_stages = model_->GetRootGraphItem()->NumGroups();
   config_.device_id = device_id_;
+  config_.iteration_end = 0;
 }
 
 Status StageExecutor::InitExecutionContext() {
-  GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
   GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));
 
   context_.model = model_;
diff --git a/ge/hybrid/executor/hybrid_model_pipeline_executor.h b/ge/hybrid/executor/hybrid_model_pipeline_executor.h
index c59e1462..f694c4e4 100644
--- a/ge/hybrid/executor/hybrid_model_pipeline_executor.h
+++ b/ge/hybrid/executor/hybrid_model_pipeline_executor.h
@@ -1,3 +1,19 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 #ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
 #define GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
 
@@ -6,7 +22,7 @@
 #include "hybrid/executor/hybrid_execution_context.h"
 #include "hybrid/executor/rt_callback_manager.h"
 #include "hybrid/executor/subgraph_executor.h"
-#include "hybrid_model_executor.h"
+#include "hybrid/executor/hybrid_model_executor.h"
 
 namespace ge {
 namespace hybrid {
diff --git a/ge/hybrid/executor/hybrid_profiler.cc b/ge/hybrid/executor/hybrid_profiler.cc
index 384dc770..f9231a39 100644
--- a/ge/hybrid/executor/hybrid_profiler.cc
+++ b/ge/hybrid/executor/hybrid_profiler.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "hybrid_profiler.h"
+#include "hybrid/executor/hybrid_profiler.h"
 #include
 #include
 #include
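Note on the hybrid_model_executor.* hunks above: `Execute` no longer constructs a `SubgraphExecutor` on every call; `Init` builds one `root_graph_executor_` up front and can share an externally owned `ThreadPool` (passing nullptr keeps a private pool). A runnable sketch of that ownership split (dummy types, not GE code):

```cpp
// Sketch: build the root subgraph executor once at Init, optionally on a
// shared thread pool, and reuse it across Execute calls.
#include <cstdio>
#include <memory>
#include <new>

struct ThreadPool { int threads = 4; };

struct SubgraphExecutor {
  ThreadPool *pool;   // may be shared (not owned)
  bool owns_pool;
  explicit SubgraphExecutor(ThreadPool *p)
      : pool(p != nullptr ? p : new ThreadPool()), owns_pool(p == nullptr) {}
  ~SubgraphExecutor() {
    if (owns_pool) delete pool;
  }
  void ExecuteAsync() { std::printf("run on %d threads\n", pool->threads); }
};

struct Executor {
  std::unique_ptr<SubgraphExecutor> root;
  bool Init(ThreadPool *thread_pool = nullptr) {
    root.reset(new (std::nothrow) SubgraphExecutor(thread_pool));
    return root != nullptr;
  }
  void Execute() { root->ExecuteAsync(); }  // per-iteration state is reset, not rebuilt
};

int main() {
  ThreadPool shared;
  Executor a;
  a.Init(&shared);  // pipeline case: stages share one externally owned pool
  a.Execute();
  Executor b;
  b.Init();         // default case: the executor owns a private pool
  b.Execute();
  return 0;
}
```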
node_item_->enter_data_.size()); - ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); + auto unique_task_context = TaskContext::Create(this, subgraph_context_); + GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context); + task_context_ = std::shared_ptr(unique_task_context.release()); + + data_scheduled_ = static_cast(node_item_->root_data_.size()); + ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); + if (iteration > 0) { + data_scheduled_ += static_cast(node_item_->enter_data_.size()); + ctrl_scheduled_ += static_cast(node_item_->enter_ctrl_.size()); } iteration_count_ = iteration; diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 9dd29846..f1cec215 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -24,7 +24,7 @@ #include "common/blocking_queue.h" #include "external/ge/ge_api_error_codes.h" #include "hybrid/model/node_item.h" -#include "node_done_manager.h" +#include "hybrid/executor/node_done_manager.h" namespace ge { namespace hybrid { @@ -100,6 +100,8 @@ struct NodeState { NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); ~NodeState() = default; + Status Init(int group, const shared_ptr &frame_state); + OpDesc *GetOpDesc() const { return op_desc_.get(); } @@ -129,6 +131,9 @@ struct NodeState { void RunStreamActive(); void RunNextIteration(); + void SavePersistTensor(int input_idx, const TensorValue &tensor); + void UpdatePersistTensor(); + Status NodeScheduled(const std::function &ready) const; void SetScheduleFuture(std::future &&future); @@ -150,18 +155,10 @@ struct NodeState { return merge_index_; } - void SetGroup(int group) { - group_ = group; - } - int GetGroup() const { return group_; } - void SetFrameState(const shared_ptr &frame_state) { - frame_state_ = frame_state; - } - const shared_ptr &GetKernelTask() const { return kernel_task_; } @@ -181,12 +178,17 @@ struct NodeState { void SetTaskContext(std::shared_ptr &task_context); std::shared_ptr GetTaskContext(); + void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; } + + bool MaySkipShapeInference() const { return skip_infershape_; } + private: bool IsScheduleReady() const; void SetDataSchedule(const NodeState &node_state, const std::function &ready); void SetCtrlSchedule(const NodeState &node_state, const std::function &ready); void ResetContext(uint64_t iteration); void ScheduleContext(const NodeState &node_state); + void UpdatePersistTensor(int input_idx); const NodeItem *node_item_ = nullptr; std::shared_ptr kernel_task_ = nullptr; @@ -199,6 +201,7 @@ struct NodeState { std::future schedule_future_; std::shared_ptr frame_state_; + std::map root_tensor_values_; uint64_t active_count_ = 0; uint64_t iteration_count_ = 0; uint32_t ctrl_scheduled_ = 0; @@ -206,6 +209,7 @@ struct NodeState { int merge_index_ = -1; // Use for Execute (Reset after Executed). int switch_index_ = -1; // Use for Schedule (Reset after Prepared). 
int group_ = -1; + bool skip_infershape_ = false; }; } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/executor/rt_callback_manager.h b/ge/hybrid/executor/rt_callback_manager.h index 9c062134..15b0dede 100644 --- a/ge/hybrid/executor/rt_callback_manager.h +++ b/ge/hybrid/executor/rt_callback_manager.h @@ -23,7 +23,7 @@ #include #include "common/blocking_queue.h" -#include "ge/ge_api_error_codes.h" +#include "external/ge/ge_api_error_codes.h" #include "runtime/rt.h" namespace ge { diff --git a/ge/hybrid/executor/subgraph_context.cc b/ge/hybrid/executor/subgraph_context.cc index b6763ffd..4b748e3f 100644 --- a/ge/hybrid/executor/subgraph_context.cc +++ b/ge/hybrid/executor/subgraph_context.cc @@ -14,12 +14,12 @@ * limitations under the License. */ -#include "subgraph_context.h" +#include "hybrid/executor/subgraph_context.h" #include "hybrid/executor/hybrid_model_executor.h" namespace ge { namespace hybrid { -SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context) +SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context) : graph_item_(graph_item), execution_context_(execution_context) { } @@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { return nullptr; } + return CreateNodeState(node_item); +} + +NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) { GELOGD("[%s] lock for write", node_item->NodeName().c_str()); if (mmRWLockWRLock(&rw_lock_) != EN_OK) { REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); return nullptr; } + auto &node_state = node_states_[node_item]; - if (node_state == nullptr) { - const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); - node_state.reset(new(std::nothrow)NodeState(*node_item, this)); - node_state->SetFrameState(GetOrCreateFrameState(*node_item)); - node_state->SetGroup(group_); - (void)guard; - } + do { + if (node_state == nullptr) { + const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); + node_state.reset(new(std::nothrow)NodeState(*node_item, this)); + if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str()); + REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str()); + break; + } + (void)guard; + } + } while (0); + GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); if (mmWRLockUnLock(&rw_lock_) != EN_OK) { REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); diff --git a/ge/hybrid/executor/subgraph_context.h b/ge/hybrid/executor/subgraph_context.h index a43cd210..023be981 100644 --- a/ge/hybrid/executor/subgraph_context.h +++ b/ge/hybrid/executor/subgraph_context.h @@ -30,7 +30,7 @@ namespace ge { namespace hybrid { class SubgraphContext { public: - explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context); + explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context); ~SubgraphContext(); Status Init(); @@ -51,10 +51,11 @@ class SubgraphContext { void NodeDone(const NodePtr &node); private: + NodeStatePtr CreateNodeState(const NodeItem *node_item); FrameStatePtr GetOrCreateFrameState(const 
NodeItem &node_item); // no lock friend class TaskContext; const GraphItem *graph_item_; - const GraphExecutionContext *execution_context_; + GraphExecutionContext *execution_context_; mmRWLock_t rw_lock_; std::vector all_inputs_; std::vector all_outputs_; diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 612e7565..7fcdec5d 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -28,20 +28,30 @@ constexpr int kDefaultQueueSize = 16; constexpr int kDataInputIndex = 0; } -SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape) +SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape, + ThreadPool *pre_run_pool) : graph_item_(graph_item), context_(context), force_infer_shape_(force_infer_shape), - pre_run_pool_(kDefaultThreadNum), + pre_run_pool_(pre_run_pool), + own_thread_pool_(false), ready_queue_(kDefaultQueueSize) { } SubgraphExecutor::~SubgraphExecutor() { + if (own_thread_pool_ && pre_run_pool_ != nullptr) { + delete pre_run_pool_; + } GELOGD("[%s] SubgraphExecutor destroyed.", graph_item_->GetName().c_str()); } Status SubgraphExecutor::Init(const std::vector &inputs, const std::vector &input_desc) { + if (pre_run_pool_ == nullptr) { + pre_run_pool_ = new (std::nothrow) ThreadPool(kDefaultThreadNum); + GE_CHECK_NOTNULL(pre_run_pool_); + own_thread_pool_ = true; + } subgraph_context_.reset(new(std::nothrow)SubgraphContext(graph_item_, context_)); GE_CHECK_NOTNULL(subgraph_context_); GE_CHK_STATUS_RET(subgraph_context_->Init(), @@ -103,6 +113,13 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vectorGetOrCreateNodeState(input_node); GE_CHECK_NOTNULL(node_state); node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); + auto op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); + GE_CHECK_NOTNULL(output_desc); + output_desc->SetShape(tensor_desc->GetShape()); + output_desc->SetOriginShape(tensor_desc->GetOriginShape()); + node_state->SetSkipInferShape(true); } } @@ -175,16 +192,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vectorSetKernelTask(node_item->kernel_task); - known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); - GE_CHECK_NOTNULL(known_shape_task_context_); - node_state->SetTaskContext(known_shape_task_context_); - std::function callback; GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); - HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback), + HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback), "[%s] Failed to execute node [%s] for known subgraph.", graph_item_->GetName().c_str(), - known_shape_task_context_->GetNodeName()); + node_state->GetName().c_str()); GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); return SUCCESS; @@ -251,7 +264,8 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { // only do shape inference and compilation for nodes with dynamic shapes. 
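SubgraphContext::CreateNodeState above wraps the state construction in `do { ... } while (0)` so that a failed allocation or a failed Init can `break` to the single unlock at the bottom instead of returning with the mm write lock still held. A sketch of the idiom with std::mutex standing in for mmRWLock_t; for simplicity this sketch clears the map slot on failure, where the real code leaves the null check to the caller, and `fail_init` is a hypothetical knob standing in for `Init(...) != SUCCESS`:

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <new>
#include <string>

struct NodeState {
  explicit NodeState(std::string n) : name(std::move(n)) {}
  std::string name;
};

std::mutex g_lock;  // stand-in for the mmRWLock_t write lock
std::map<std::string, std::shared_ptr<NodeState>> g_states;

std::shared_ptr<NodeState> CreateNodeState(const std::string &name, bool fail_init) {
  g_lock.lock();
  auto &state = g_states[name];
  // Single-exit block: any failure breaks to the unlock below, so the lock
  // can never leak on an error path.
  do {
    if (state == nullptr) {
      state.reset(new (std::nothrow) NodeState(name));
      if (state == nullptr || fail_init) {  // mirrors the Init() != SUCCESS check
        state.reset();
        break;
      }
    }
  } while (0);
  g_lock.unlock();
  return state;
}

int main() {
  std::cout << (CreateNodeState("add", /*fail_init=*/false) != nullptr) << "\n";  // 1
  std::cout << (CreateNodeState("mul", /*fail_init=*/true) != nullptr) << "\n";   // 0
  return 0;
}
```

A scoped lock guard would express the same thing in idiomatic C++; the explicit lock/unlock pair follows from the C-style mmpa rwlock API this code has to use.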
if (node_item.is_dynamic) { - auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { + GE_CHECK_NOTNULL(pre_run_pool_); + auto prepare_future = pre_run_pool_->commit([this, p_node_state]() -> Status { GetContext().SetSessionId(context_->session_id); GetContext().SetContextId(context_->context_id); GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); @@ -271,16 +285,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { } else { node_state->SetKernelTask(node_item.kernel_task); } - auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); - GE_CHECK_NOTNULL(unique_task_context); const auto &task = node_state->GetKernelTask(); if (task == nullptr) { GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); return INTERNAL_ERROR; } - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); return AfterPrepared(p_node_state); } @@ -350,7 +360,8 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(), node_state->GetSwitchIndex(), node_state->GetMergeIndex()); - auto future = pre_run_pool_.commit([this, node_state]() -> Status { + GE_CHECK_NOTNULL(pre_run_pool_); + auto future = pre_run_pool_->commit([this, node_state]() -> Status { RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start"); std::function callback = [&](const NodeItem *node_item) { const auto &node_name = node_item->node_name; @@ -480,19 +491,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta } else { node_state.SetKernelTask(node_item.kernel_task); } - auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get()); - GE_CHECK_NOTNULL(unique_task_context); const auto &task = node_state.GetKernelTask(); if (task == nullptr) { GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str()); REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str()); return INTERNAL_ERROR; } - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state.SetTaskContext(shared_task_context); GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); - GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws + GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); return SUCCESS; } diff --git a/ge/hybrid/executor/subgraph_executor.h b/ge/hybrid/executor/subgraph_executor.h index 758bf426..be11ff59 100644 --- a/ge/hybrid/executor/subgraph_executor.h +++ b/ge/hybrid/executor/subgraph_executor.h @@ -33,7 +33,8 @@ namespace hybrid { // Executor for executing a subgraph class SubgraphExecutor { public: - SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); + SubgraphExecutor(const GraphItem *graph_item, 
GraphExecutionContext *context, bool force_infer_shape = false, + ThreadPool *pre_run_pool = nullptr); ~SubgraphExecutor(); Status InitForPartialExecution(const std::vector &inputs, @@ -41,6 +42,8 @@ class SubgraphExecutor { Status PartialExecuteAsync(int task_group); + void ReleaseContext() { subgraph_context_.reset(nullptr); } + /** * Execute subgraph async, output tensor address(not data) and output tensor descriptions are * valid after this method returned @@ -122,10 +125,10 @@ class SubgraphExecutor { GraphExecutionContext *context_; std::unique_ptr subgraph_context_; bool force_infer_shape_; - ThreadPool pre_run_pool_; + ThreadPool *pre_run_pool_; + bool own_thread_pool_; BlockingQueue ready_queue_; std::unique_ptr shape_inference_engine_; - std::shared_ptr known_shape_task_context_; std::mutex mu_; // Guard for prepare_queues_. std::map> prepare_queues_; diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc index 8eecbc80..4bd02193 100755 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, auto executor = node_item.node_executor; GE_CHECK_NOTNULL(executor); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); + node_state.UpdatePersistTensor(); GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", node_state.GetName().c_str()); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); @@ -427,7 +428,7 @@ Status ExecutionEngine::ValidateInputTensors(const NodeState &node_state, const continue; } - int64_t expected_size; + int64_t expected_size = 0; (void)TensorUtils::GetSize(*tensor_desc, expected_size); GELOGD("[%s] Input[%d] expects [%ld] bytes.", task_context.GetNodeName(), i, expected_size); auto size_diff = expected_size - static_cast(input_tensor->GetSize()); diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index a2efbb25..50dc389c 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -68,8 +68,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { } // Do shape inference + // Skipping infer shape of input node. 
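The executor above now accepts an optional externally owned ThreadPool and only allocates its own, tracked by `own_thread_pool_`, when none is injected: the pool is created lazily in Init and deleted in the destructor. A minimal sketch of this borrow-or-own pattern; the ThreadPool here is a hypothetical stand-in, not the GE class:

```cpp
#include <iostream>
#include <new>

// Hypothetical stand-in for the GE ThreadPool.
class ThreadPool {
 public:
  explicit ThreadPool(int threads) { std::cout << "pool(" << threads << ") created\n"; }
  ~ThreadPool() { std::cout << "pool destroyed\n"; }
};

class SubgraphExecutor {
 public:
  explicit SubgraphExecutor(ThreadPool *pool = nullptr) : pool_(pool), own_pool_(false) {}

  ~SubgraphExecutor() {
    if (own_pool_ && pool_ != nullptr) {
      delete pool_;  // only release what we allocated ourselves
    }
  }

  bool Init() {
    if (pool_ == nullptr) {  // lazy fallback when no pool was injected
      pool_ = new (std::nothrow) ThreadPool(4);
      if (pool_ == nullptr) return false;
      own_pool_ = true;
    }
    return true;
  }

 private:
  ThreadPool *pool_;
  bool own_pool_;
};

int main() {
  ThreadPool shared(8);
  {
    SubgraphExecutor borrows(&shared);  // caller keeps ownership
    borrows.Init();
  }                                     // no "pool destroyed" here
  {
    SubgraphExecutor owns;              // allocates and frees its own pool
    owns.Init();
  }
  return 0;
}
```

Modern C++ would usually express this with a smart pointer plus a non-owning observer; the raw-pointer-and-flag form matches the conventions of the surrounding codebase.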
GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); - { + if (!node_state.MaySkipShapeInference()) { RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); diff --git a/ge/hybrid/executor/worker/task_compile_engine.cc b/ge/hybrid/executor/worker/task_compile_engine.cc index f7da9acd..491e0997 100755 --- a/ge/hybrid/executor/worker/task_compile_engine.cc +++ b/ge/hybrid/executor/worker/task_compile_engine.cc @@ -21,10 +21,17 @@ namespace ge { namespace hybrid { Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { - const auto &node_item = *node_state.GetNodeItem(); GE_CHECK_NOTNULL(context); + rtContext_t rt_gen_context = nullptr; + GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0)); + std::function callback = [&]() { + (void) rtCtxDestroy(rt_gen_context); + GE_CHK_RT(rtCtxSetCurrent(context->rt_context)); + }; + GE_MAKE_GUARD(rt_gen_context, callback); + + const auto &node_item = *node_state.GetNodeItem(); RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); - GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); if (context->ge_context != nullptr) { GetThreadLocalContext() = *context->ge_context; diff --git a/ge/hybrid/hybrid_davinci_model.cc b/ge/hybrid/hybrid_davinci_model.cc index 7368784c..c4500b6d 100755 --- a/ge/hybrid/hybrid_davinci_model.cc +++ b/ge/hybrid/hybrid_davinci_model.cc @@ -15,7 +15,7 @@ */ #include -#include "hybrid_davinci_model.h" +#include "hybrid/hybrid_davinci_model.h" #include "hybrid/model/hybrid_model.h" #include "hybrid/executor/hybrid_model_async_executor.h" #include "hybrid/node_executor/node_executor.h" diff --git a/ge/hybrid/hybrid_davinci_model.h b/ge/hybrid/hybrid_davinci_model.h index 34503b01..abab74f6 100644 --- a/ge/hybrid/hybrid_davinci_model.h +++ b/ge/hybrid/hybrid_davinci_model.h @@ -20,7 +20,7 @@ #include #include "external/ge/ge_api_error_codes.h" #include "graph/load/model_manager/data_inputer.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/hybrid_davinci_model_stub.cc b/ge/hybrid/hybrid_davinci_model_stub.cc index 67cd29b8..b8a2f242 100644 --- a/ge/hybrid/hybrid_davinci_model_stub.cc +++ b/ge/hybrid/hybrid_davinci_model_stub.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "hybrid_davinci_model.h" +#include "hybrid/hybrid_davinci_model.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/model/graph_item.cc b/ge/hybrid/model/graph_item.cc index c38e0a0d..ca23108d 100644 --- a/ge/hybrid/model/graph_item.cc +++ b/ge/hybrid/model/graph_item.cc @@ -15,7 +15,7 @@ */ #include "framework/common/util.h" -#include "graph_item.h" +#include "hybrid/model/graph_item.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc index 5e496c3b..e6ddbb8d 100644 --- a/ge/hybrid/model/hybrid_model.cc +++ b/ge/hybrid/model/hybrid_model.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "hybrid_model.h" +#include "hybrid/model/hybrid_model.h" #include #include "graph/debug/ge_attr_define.h" #include "graph/load/model_manager/model_utils.h" @@ -25,7 +25,7 @@ #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/model/hybrid_model_builder.h" #include "hybrid/node_executor/node_executor.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index 9821242a..3cb936f6 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -27,7 +27,7 @@ #include "hybrid/common/tensor_value.h" #include "hybrid/model/node_item.h" #include "hybrid/model/graph_item.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" namespace ge { namespace hybrid { @@ -147,6 +147,7 @@ class HybridModel { GeRootModelPtr ge_root_model_; std::map input_nodes_; ComputeGraphPtr root_graph_; + ComputeGraphPtr orig_root_graph_; std::map device_variable_nodes_; //lint !e148 std::map host_variable_nodes_; //lint !e148 std::map> variable_tensors_; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 5337a0cf..44115240 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -17,11 +17,11 @@ #include "hybrid/model/hybrid_model_builder.h" #include #include "common/math/math_util.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "graph/ge_context.h" #include "graph/build/memory/var_mem_assign_util.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/load/model_manager/model_utils.h" #include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_var_manager.h" @@ -147,6 +147,7 @@ Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); hybrid_model_.model_name_ = ge_root_model_->GetModelName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); + GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); @@ -171,11 +172,12 @@ Status HybridModelBuilder::Build() { Status HybridModelBuilder::BuildForSingleOp() { GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); + hybrid_model_.root_graph_ = ge_root_model_->GetRootGraph(); hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); - const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; - GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model), + const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()]; + GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model), "[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); @@ -190,6 
+192,27 @@ Status HybridModelBuilder::ValidateParams() { return SUCCESS; } +Status HybridModelBuilder::CopyGraph() { + GELOGD("Copy compute graph begin."); + auto root_graph = ge_root_model_->GetRootGraph(); + + std::string new_graph_name = ge_root_model_->GetRootGraph()->GetName(); + ComputeGraphPtr new_root_graph = MakeShared(new_graph_name); + GE_CHECK_NOTNULL(new_root_graph); + int32_t depth = 0; + std::map node_old_2_new; + std::map op_desc_old_2_new; + graphStatus ret = GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Copy compute graph failed."); + return GRAPH_FAILED; + } + hybrid_model_.root_graph_ = new_root_graph; + + GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str()); + return SUCCESS; +} + Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), @@ -265,10 +288,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n return SUCCESS; } - if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. - node->GetOpDesc()->SetType(IDENTITY); - } - std::unique_ptr new_node; GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); @@ -569,7 +588,10 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { auto dst_node = peer_in_data_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(dst_node); - root_nodes.emplace(dst_node); + const auto in_nodes = dst_node->GetInDataNodes(); + if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) { + root_nodes.emplace(dst_node); + } GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); } @@ -814,12 +836,13 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, } Status HybridModelBuilder::LoadGraph() { - auto root_graph = ge_root_model_->GetRootGraph(); + auto root_graph = hybrid_model_.root_graph_; if (!GetContext().GetHostExecFlag()) { std::shared_ptr merged_graph; GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); + hybrid_model_.orig_root_graph_ = root_graph; GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName()); root_graph = std::move(merged_graph); @@ -877,8 +900,10 @@ Status HybridModelBuilder::LoadGraph() { } for (auto &it : hybrid_model_.known_shape_sub_models_) { auto node_item = MutableNodeItem(it.first); + GE_CHECK_NOTNULL(node_item); AscendString graph_name; GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); + GE_CHECK_NOTNULL(graph_name.GetString()); auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); GE_CHECK_NOTNULL(subgraph); GE_CHK_STATUS_RET(IdentifyVariableOutputs(*node_item, subgraph), @@ -924,7 +949,7 @@ Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_ } int64_t var_size = CalcVarSizeInBytes(*tensor_desc); - // var size is only for 
checking, will not allocate any memory by it + GE_CHECK_GE(var_size, 0); tensor.reset(new(std::nothrow)TensorValue(dev_mem, static_cast(var_size))); GE_CHECK_NOTNULL(tensor); GELOGI("Get var memory addr %p for node %s, size = %ld, mem_type=%u", dev_mem, var_name.c_str(), var_size, mem_type); @@ -946,6 +971,7 @@ Status HybridModelBuilder::HandleDtString(const GeTensor &tensor, void *var_addr auto &mutable_tensor = const_cast(tensor); uint64_t *buff = reinterpret_cast(mutable_tensor.MutableData().data()); + GE_CHECK_NOTNULL(buff); GE_CHK_BOOL_RET_STATUS(ge::CheckInt64Uint32MulOverflow(elem_num, kBytes * kStringHeadElems) == SUCCESS, FAILED, "[Invoke][CheckInt64Uint32MulOverflow] failed because Shape size is invalid."); auto offset = static_cast(elem_num * kBytes * kStringHeadElems); @@ -1023,6 +1049,7 @@ Status HybridModelBuilder::InitConstantOps() { } else { var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); } + GE_CHECK_NOTNULL(var_tensor); } else { GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); @@ -1125,7 +1152,9 @@ Status HybridModelBuilder::InitWeights() { sub_weight_buffer->GetSize()); auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); if (subgraph != ge_root_model_->GetRootGraph()) { - subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); + subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first); + } else { + subgraph = hybrid_model_.root_graph_; } GE_CHECK_NOTNULL(subgraph); hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); @@ -1227,6 +1256,28 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr hybrid_model_.known_shape_sub_models_.emplace(parent_node, ge_model); } + GE_CHK_STATUS_RET_NOLOG(InitHcclExecutorOnDemand(ge_model)); + return SUCCESS; +} + +Status HybridModelBuilder::InitHcclExecutorOnDemand(const GeModelPtr &ge_model) { + if (NodeExecutorManager::GetInstance().IsExecutorInitialized(NodeExecutorManager::ExecutorType::HCCL)) { + return SUCCESS; + } + + // HCCL tasks in known-shaped subgraph which resides in a dynamic root graph + // still depends on the initialization of the HcclExecutor + auto tasks = ge_model->GetModelTaskDefPtr()->task(); + for (int i = 0; i < tasks.size(); ++i) { + const domi::TaskDef &task_def = tasks[i]; + auto task_type = static_cast(task_def.type()); + if (task_type == RT_MODEL_TASK_HCCL) { + const NodeExecutor *unused = nullptr; + GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance() + .GetOrCreateExecutor(NodeExecutorManager::ExecutorType::HCCL, &unused)); + return SUCCESS; + } + } return SUCCESS; } @@ -1282,7 +1333,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const } Status HybridModelBuilder::IndexTaskDefs() { - const auto root_graph = ge_root_model_->GetRootGraph(); + const auto &root_graph = hybrid_model_.root_graph_; const auto &root_graph_name = root_graph->GetName(); if (SetOutputNameAttr(*root_graph) != SUCCESS) { GELOGW("Set output name attr failed."); @@ -1316,7 +1367,7 @@ Status HybridModelBuilder::IndexTaskDefs() { Status HybridModelBuilder::IndexSpecialNodes() { GELOGD("Start to index special nodes"); - const auto &root_graph = ge_root_model_->GetRootGraph(); + const auto &root_graph = hybrid_model_.root_graph_; for (auto &node : root_graph->GetAllNodes()) { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); @@ -1471,7 
+1522,7 @@ Status HybridModelBuilder::InitRuntimeParams() { runtime_param_.session_id = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); runtime_param_.logic_var_base = ret ? static_cast(value) : 0; - runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); + runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID(); value = 0; for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { (void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); @@ -1608,7 +1659,7 @@ Status HybridModelBuilder::TransAllVarData() { } Status HybridModelBuilder::CopyVarData() { - GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), + GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_, runtime_param_.session_id, hybrid_model_.device_id_), "[Invoke][CopyVarData] failed."); @@ -1691,7 +1742,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem } Status HybridModelBuilder::RecoverGraphUnknownFlag() { - const auto &root_graph = ge_root_model_->GetRootGraph(); + const auto &root_graph = hybrid_model_.root_graph_; for (auto &sub_graph : root_graph->GetAllSubgraphs()) { GE_CHECK_NOTNULL(sub_graph); for (const auto &node : sub_graph->GetDirectNode()) { diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 92974441..3592d3d2 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -25,7 +25,7 @@ #include "graph/node.h" #include "hybrid/model/hybrid_model.h" #include "hybrid/model/node_item.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" namespace ge { class VarManager; @@ -56,7 +56,9 @@ class HybridModelBuilder { Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); Status ValidateParams(); Status LoadGraph(); + Status CopyGraph(); Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); + static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); Status LoadTask(NodeItem &node_item); Status LoadTasks(); Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index b339e630..f66d4638 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -14,10 +14,8 @@ * limitations under the License. 
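InitHcclExecutorOnDemand, declared in the header hunk above and implemented earlier in hybrid_model_builder.cc, initializes the HCCL node executor only when a loaded GeModel actually carries an RT_MODEL_TASK_HCCL task, and only once across subgraph models. A sketch of the scan-then-initialize-once shape, with toy enums in place of the domi task types and NodeExecutorManager:

```cpp
#include <iostream>
#include <vector>

enum class TaskType { kKernel, kHccl, kEvent };
enum class ExecutorType { kHccl };

class ExecutorManager {
 public:
  bool IsInitialized(ExecutorType) const { return hccl_ready_; }
  int GetOrCreate(ExecutorType) {
    hccl_ready_ = true;
    std::cout << "HCCL executor initialized\n";
    return 0;
  }

 private:
  bool hccl_ready_ = false;
};

// Initialize the HCCL executor only if some task actually needs it; per the
// diff comment, HCCL tasks in a known-shaped subgraph of a dynamic root graph
// still depend on this initialization.
int InitHcclExecutorOnDemand(ExecutorManager &mgr, const std::vector<TaskType> &tasks) {
  if (mgr.IsInitialized(ExecutorType::kHccl)) {
    return 0;  // already done for an earlier subgraph model
  }
  for (TaskType t : tasks) {
    if (t == TaskType::kHccl) {
      return mgr.GetOrCreate(ExecutorType::kHccl);
    }
  }
  return 0;  // no HCCL task: skip the initialization entirely
}

int main() {
  ExecutorManager mgr;
  InitHcclExecutorOnDemand(mgr, {TaskType::kKernel, TaskType::kEvent});  // no-op
  InitHcclExecutorOnDemand(mgr, {TaskType::kKernel, TaskType::kHccl});   // initializes once
  InitHcclExecutorOnDemand(mgr, {TaskType::kHccl});                      // already initialized
  return 0;
}
```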
*/ -#include "node_item.h" -#include -#include "common/debug/log.h" -#include "graph/common/omg_util.h" +#include "hybrid/model/node_item.h" + #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "hybrid/executor/worker/shape_inference_engine.h" @@ -26,6 +24,8 @@ namespace ge { namespace hybrid { namespace { +const uint8_t kMaxTransCount = 3; +const uint8_t kTransOpIoSize = 1; const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char *const kNodeTypeRetVal = "_RetVal"; const std::set kControlOpTypes{ @@ -41,6 +41,25 @@ const std::set kMergeOpTypes{ MERGE, REFMERGE, STREAMMERGE }; +bool IsEnterFeedNode(NodePtr node) { + // For: Enter -> node + // For: Enter -> Cast -> node + // For: Enter -> TransData -> Cast -> node + for (uint8_t i = 0; i < kMaxTransCount; ++i) { + if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { + GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str()); + return true; + } + + const auto all_nodes = node->GetInDataNodes(); + if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) { + return false; + } + node = all_nodes.at(0); + } + return false; +} + Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { uint32_t parent_index = 0; if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { @@ -98,8 +117,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type)); + const std::string node_type = NodeUtils::GetNodeType(node); if (node_type == DATA) { GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); } else if (node_type == kNodeTypeRetVal) { @@ -398,20 +416,21 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { data_send_.emplace(node_item); node_item->data_recv_[this] = anchor_index; if (is_root_node_) { - node_item->root_data_.emplace(this); + auto &data_anchors = node_item->root_data_[this]; + data_anchors.emplace(anchor_index); } // If Enter feed Not Merge, take as root Node. - if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { - node_item->enter_data_.emplace(this); - node_item->enter_inside_.emplace(anchor_index); + if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) { + auto &data_anchors = node_item->enter_data_[this]; + data_anchors.emplace(anchor_index); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); } void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { if (switch_index < switch_groups_.size()) { - std::vector &switch_group = switch_groups_[switch_index]; - switch_group.emplace_back(node_item); + auto &switch_group = switch_groups_[switch_index]; + switch_group.emplace(node_item); } else { ctrl_send_.insert(node_item); } @@ -421,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { node_item->root_ctrl_.emplace(this); } // If Enter feed control signal, take as root Node. 
- if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { + if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { node_item->enter_ctrl_.emplace(this); } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); @@ -434,8 +453,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { } // this is StreamMerge node, node_item is StreamActive node. - std::vector &switch_group = switch_groups_[merge_index]; - switch_group.emplace_back(node_item); + auto &switch_group = switch_groups_[merge_index]; + switch_group.emplace(node_item); node_item->ctrl_send_.emplace(this); GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 8de15952..f6dcdcf6 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -148,15 +148,14 @@ struct NodeItem { int64_t frame_index_ = -1; int64_t parent_frame_ = -1; std::set root_ctrl_; // Recv ctrl from root node - std::set root_data_; // Recv data from root node + std::map> root_data_; // Recv data from root node std::set enter_ctrl_; // Recv ctrl from Enter node - std::set enter_data_; // Recv data from Enter node + std::map> enter_data_; // Recv data from Enter node std::set data_send_; // Send data notify to std::map data_recv_; // Recv data notify from std::set ctrl_send_; // Send ctrl notify to std::set ctrl_recv_; // Recv ctrl notify from - std::vector> switch_groups_; // Send ctrl notify to - std::set enter_inside_; // Enter feed loop inside Node, Not cross Merge. + std::vector> switch_groups_; // Send ctrl notify to std::shared_ptr kernel_task; std::unique_ptr fused_subgraph; diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 7ebb9e39..6ed6866c 100755 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -14,10 +14,11 @@ * limitations under the License. */ -#include "aicore_node_executor.h" +#include "hybrid/node_executor/aicore/aicore_node_executor.h" #include "framework/common/taskdown_common.h" #include "hybrid/executor/hybrid_execution_context.h" #include "external/runtime/rt_error_codes.h" +#include "single_op/task/build_task_utils.h" namespace ge { namespace hybrid { @@ -196,6 +197,11 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream())); GE_CHK_STATUS_RET_NOLOG(CheckOverflow(context)); + GE_CHECK_NOTNULL(context.GetExecutionContext()->model); + GELOGD("[DEBUG_TASK_INFO : Executor Task] %s/%s %s", + context.GetExecutionContext()->model->GetModelName().c_str(), + (*it)->GetName().empty() ? 
(*it)->GetLogName().c_str() : (*it)->GetName().c_str(), + BuildTaskUtils::GetTaskInfo(context).c_str()); // save profiling data uint32_t task_id = 0; uint32_t stream_id = 0; @@ -208,7 +214,7 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function context.SetTaskId(task_id); context.SetStreamId(stream_id); GELOGD("Aicore node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id); - (void)context.SaveProfilingTaskDescInfo(task_id, stream_id, kTaskTypeAicore, (*it)->GetBlockDim()); + (void)context.SaveProfilingTaskDescInfo(task_id, stream_id, kTaskTypeAicore, (*it)->GetBlockDim(), (*it)->GetOpType()); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); } diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 8cd24bd1..fe9bba9a 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -21,11 +21,11 @@ #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/node_executor/aicore/aicore_task_builder.h" #include "graph/load/model_manager/tbe_handle_store.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "single_op/task/build_task_utils.h" #include "single_op/task/tbe_task_builder.h" -using optiling::OpRunInfo; +using optiling::utils::OpRunInfo; namespace ge { namespace hybrid { @@ -33,6 +33,7 @@ namespace { constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; constexpr char const *kAttrOpParamSize = "op_para_size"; constexpr char const *kAttrAtomicOpParamSize = "atomic_op_para_size"; +const string kAtomicOpType = "DynamicAtomicAddrClean"; std::atomic log_id(0); } // namespace @@ -51,6 +52,7 @@ bool TbeHandleRegistry::AddHandle(std::unique_ptr &&holder) { } Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { + op_type_ = op_desc.GetType(); log_name_ = op_desc.GetName() + "_tvmbin"; log_id_ = log_id++; auto op_desc_ptr = MakeShared(op_desc); @@ -81,7 +83,7 @@ Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) Status AiCoreOpTask::RegisterTbeHandle(const OpDesc &op_desc) { rtError_t rt_ret = rtQueryFunctionRegistered(stub_name_.c_str()); - if (rt_ret != RT_ERROR_NONE || is_single_op_) { + if (rt_ret != RT_ERROR_NONE) { auto op_desc_ptr = MakeShared(op_desc); GE_CHECK_NOTNULL(op_desc_ptr); auto tbe_kernel = op_desc_ptr->TryGetExtAttr(GetKeyForTbeKernel(), TBEKernelPtr()); @@ -194,7 +196,7 @@ Status AiCoreOpTask::RegisterKernelHandle(const OpDesc &op_desc) { Status AiCoreOpTask::InitWithKernelDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { const domi::KernelDef &kernel_def = task_def.kernel(); const domi::KernelContext &context = kernel_def.context(); - stub_name_ = kernel_def.stub_func(); + stub_name_ = is_single_op_ ? 
to_string(log_id_) + kernel_def.stub_func() : kernel_def.stub_func(); GE_CHK_STATUS_RET(RegisterTbeHandle(op_desc)); GE_CHK_RT_RET(rtGetFunctionByName(stub_name_.c_str(), &stub_func_)); args_size_ = kernel_def.args_size(); @@ -359,9 +361,7 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { GE_CHECK_NOTNULL(op_desc); GELOGD("[%s] Start to update tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); - OpRunInfo tiling_info; - tiling_info.block_dim = -1; // codex: Using uninitialized value - tiling_info.clear_atomic = true; + OpRunInfo tiling_info(-1, true, 0); auto execution_context = context.GetExecutionContext(); @@ -370,12 +370,11 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); // update op args by tiling info - block_dim_ = static_cast(tiling_info.block_dim); - op_desc->SetWorkspaceBytes(tiling_info.workspaces); - clear_atomic_ = tiling_info.clear_atomic; + block_dim_ = tiling_info.GetBlockDim(); + clear_atomic_ = tiling_info.GetClearAtomic(); - tiling_data_ = tiling_info.tiling_data.str(); - tiling_key_ = tiling_info.tiling_key; + tiling_data_ = tiling_info.GetAllTilingData().str(); + tiling_key_ = tiling_info.GetTilingKey(); GELOGD("Successfully getting [tiling_key] : %u", tiling_key_); if (tiling_data_.empty()) { GELOGD("[%s] Tiling data is empty.", op_desc->GetName().c_str()); @@ -412,9 +411,14 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { Status AiCoreOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) { GELOGD("[%s] Start to invoke OpParaCalculate.", node->GetName().c_str()); - GE_CHK_STATUS_RET(OpParaCalculate(*node, tiling_info), + GE_CHK_STATUS_RET(optiling::OpParaCalculateV2(*node, tiling_info), "[Invoke][OpParaCalculate]Failed calc tiling data of node %s.", node->GetName().c_str()); + // Only non atomic task need update workspace + auto op_desc = node->GetOpDesc(); + std::vector workspaces; + tiling_info.GetAllWorkspaces(workspaces); + op_desc->SetWorkspaceBytes(workspaces); GELOGD("[%s] Done invoking OpParaCalculate successfully.", node->GetName().c_str()); return SUCCESS; } @@ -538,6 +542,10 @@ const std::string &AiCoreOpTask::GetName() const { return stub_name_; } +const std::string &AiCoreOpTask::GetOpType() const { + return op_type_; +} + std::string AiCoreOpTask::GetKeyForOpParamSize() const { return kAttrOpParamSize; } @@ -631,9 +639,13 @@ std::string AtomicAddrCleanOpTask::GetKeyForKernelName(const OpDesc &op_desc) co return op_desc.GetName() + "_atomic_kernelname"; } +const std::string &AtomicAddrCleanOpTask::GetOpType() const { + return kAtomicOpType; +} + Status AtomicAddrCleanOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) { GELOGD("[%s] Start to invoke OpAtomicCalculate.", node->GetName().c_str()); - GE_CHK_STATUS_RET(OpAtomicCalculate(*node, tiling_info), + GE_CHK_STATUS_RET(optiling::OpAtomicCalculateV2(*node, tiling_info), "[Invoke][OpAtomicCalculate]Failed calc tiling data of node %s.", node->GetName().c_str()); GELOGD("[%s] Done invoking OpAtomicCalculate successfully.", node->GetName().c_str()); diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.h b/ge/hybrid/node_executor/aicore/aicore_op_task.h index 8d7b7f1e..21a947f2 100755 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.h +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.h @@ -19,7 +19,7 @@ #include #include -#include "common/ge_inner_error_codes.h" +#include 
"framework/common/ge_inner_error_codes.h" #include "runtime/stream.h" #include "hybrid/common/tensor_value.h" #include "hybrid/node_executor/task_context.h" @@ -72,12 +72,16 @@ class AiCoreOpTask { const std::string& GetName() const; + const std::string& GetLogName() const {return log_name_;} + bool GetClearAtomic() const {return clear_atomic_;} uint32_t GetBlockDim() const {return block_dim_;} void SetSingleOp(bool is_single_op) {is_single_op_ = is_single_op;}; + virtual const std::string& GetOpType() const; + protected: Status UpdateTilingInfo(TaskContext &context); virtual std::string GetKeyForOpParamSize() const; @@ -85,7 +89,7 @@ class AiCoreOpTask { virtual std::string GetKeyForTvmMagic() const; virtual std::string GetKeyForTvmMetaData() const; virtual std::string GetKeyForKernelName(const OpDesc &op_desc) const; - virtual Status CalcTilingInfo(const NodePtr &node, optiling::OpRunInfo &tiling_info); + virtual Status CalcTilingInfo(const NodePtr &node, optiling::utils::OpRunInfo &tiling_info); std::unique_ptr tiling_buffer_ = nullptr; std::string tiling_data_; @@ -117,12 +121,14 @@ class AiCoreOpTask { uint64_t log_id_ = 0; std::string log_name_; uint32_t offset_ = 0; + std::string op_type_; }; class AtomicAddrCleanOpTask : public AiCoreOpTask { public: Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def) override; Status UpdateArgs(TaskContext &task_context) override; + const std::string& GetOpType() const override; protected: std::string GetKeyForOpParamSize() const override; @@ -130,7 +136,7 @@ class AtomicAddrCleanOpTask : public AiCoreOpTask { std::string GetKeyForTvmMagic() const override; std::string GetKeyForTvmMetaData() const override; std::string GetKeyForKernelName(const OpDesc &op_desc) const override; - Status CalcTilingInfo(const NodePtr &node, optiling::OpRunInfo &tiling_info) override; + Status CalcTilingInfo(const NodePtr &node, optiling::utils::OpRunInfo &tiling_info) override; private: Status InitAtomicAddrCleanIndices(const OpDesc &op_desc); diff --git a/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/ge/hybrid/node_executor/aicore/aicore_task_builder.cc index 114451b3..0ba71fe4 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_builder.cc +++ b/ge/hybrid/node_executor/aicore/aicore_task_builder.cc @@ -14,9 +14,9 @@ * limitations under the License. */ -#include "aicore_task_builder.h" -#include "common/debug/log.h" -#include "aicore_node_executor.h" +#include "hybrid/node_executor/aicore/aicore_task_builder.h" +#include "framework/common/debug/log.h" +#include "hybrid/node_executor/aicore/aicore_node_executor.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/ge/hybrid/node_executor/aicore/aicore_task_builder.h index 6a472a21..e57538ba 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_builder.h +++ b/ge/hybrid/node_executor/aicore/aicore_task_builder.h @@ -19,7 +19,7 @@ #include #include -#include "aicore_op_task.h" +#include "hybrid/node_executor/aicore/aicore_op_task.h" #include "framework/common/debug/ge_log.h" #include "graph/utils/attr_utils.h" #include "graph/op_kernel_bin.h" diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc index 0cdea5d5..43563bb6 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc +++ b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "aicore_task_compiler.h" +#include "hybrid/node_executor/aicore/aicore_task_compiler.h" #include "framework/common/debug/log.h" #include "graph/debug/ge_attr_define.h" #include "opskernel_manager/ops_kernel_builder_manager.h" diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.h b/ge/hybrid/node_executor/aicore/aicore_task_compiler.h index 4cb4dc58..2778aeb0 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_compiler.h +++ b/ge/hybrid/node_executor/aicore/aicore_task_compiler.h @@ -19,7 +19,7 @@ #include #include "opskernel_manager/ops_kernel_manager.h" -#include "aicore_node_executor.h" +#include "hybrid/node_executor/aicore/aicore_node_executor.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc index c607a43e..6e8841b9 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc @@ -81,6 +81,9 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) { case aicpu::FWKAdapter::FWK_ADPT_EXT_TOPIC_TYPE: GE_CHK_STATUS_RET(ParseExtTopicType(aicpu_ext_info), "[Parse][ExtTopicType] failed."); break; + case aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT: + GE_CHK_STATUS_RET(ParseExtAsyncWait(aicpu_ext_info), "[Parse][ExtAsyncWait] failed."); + break; default: GELOGD("Node[%s] ignore infoType=%d, infoLen=%u.", node_name_.c_str(), aicpu_ext_info->infoType, aicpu_ext_info->infoLen); @@ -101,6 +104,22 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) { return SUCCESS; } +Status AicpuExtInfoHandler::ParseExtAsyncWait(AicpuExtInfo *aicpu_ext_info) { + if (aicpu_ext_info->infoLen != sizeof(AsyncWaitInfo)) { + REPORT_INNER_ERROR("E19999", + "Node[%s] parse ext async wait info failed as infoLen must be %zu but %u.", + node_name_.c_str(), sizeof(AsyncWaitInfo), aicpu_ext_info->infoLen); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, + "[Check][DataLen]Node[%s] parse ext async wait info failed as infoLen must be %zu but %u.", + node_name_.c_str(), sizeof(AsyncWaitInfo), aicpu_ext_info->infoLen); + return ACL_ERROR_GE_PARAM_INVALID; + } + + async_wait_ = reinterpret_cast(aicpu_ext_info->infoMsg); + GELOGI("Node[%s] parse async wait info success infoLen=%u.", node_name_.c_str(), aicpu_ext_info->infoLen); + return SUCCESS; +} + Status AicpuExtInfoHandler::ParseExtShapeType(AicpuExtInfo *aicpu_ext_info) { GE_IF_BOOL_EXEC(aicpu_ext_info->infoLen != sizeof(int32_t), REPORT_INNER_ERROR("E19999", "Node[%s] parse ext shape type failed as infoLen must be %zu but %u.", @@ -280,6 +299,17 @@ Status AicpuExtInfoHandler::UpdateSessionInfo(uint64_t session_id, uint64_t kern return SUCCESS; } +Status AicpuExtInfoHandler::UpdateEventId(uint32_t event_id) { + if (async_wait_ == nullptr) { + REPORT_INNER_ERROR("E19999", "async_wait_ is nullptr."); + GELOGE(FAILED, "[Check][async_wait_] async_wait_ is nullptr."); + return FAILED; + } + async_wait_->waitType = 1; + async_wait_->waitId = event_id; + return SUCCESS; +} + Status AicpuExtInfoHandler::UpdateSessionInfoSessionId(uint64_t session_id) { if (session_info_ == nullptr) { GELOGD("There is no session info in ext_info, no need update."); diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h index 46fb7c05..80e3bb92 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h +++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h @@ -27,6 +27,7 @@ namespace ge { namespace hybrid { using AicpuShapeAndType = 
aicpu::FWKAdapter::ShapeAndType; using AicpuExtInfo = aicpu::FWKAdapter::ExtInfo; +using AsyncWaitInfo = aicpu::FWKAdapter::AsyncWait; using AicpuSessionInfo = SessionInfo; class AicpuExtInfoHandler { @@ -59,6 +60,8 @@ class AicpuExtInfoHandler { Status UpdateExecuteMode(bool flag); + Status UpdateEventId(uint32_t event_id); + Status GetOutputShapeAndType(uint32_t output_index, GeShape &shape, DataType &data_type); bool IsNeedRefreshIOAddr(); @@ -73,6 +76,7 @@ class AicpuExtInfoHandler { Status ParseExtBitMap(AicpuExtInfo *aicpu_ext_info); Status ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info); Status ParseExtTopicType(AicpuExtInfo *aicpu_ext_info); + Status ParseExtAsyncWait(AicpuExtInfo *aicpu_ext_info); static Status UpdateShapeAndType(const GeShape &shape, DataType data_type, @@ -90,6 +94,7 @@ class AicpuExtInfoHandler { const uint32_t output_num_; UnknowShapeOpType unknown_type_; AicpuSessionInfo *session_info_ = nullptr; + AsyncWaitInfo *async_wait_ = nullptr; uint64_t *bit_map_ = nullptr; uint32_t *update_addr_ = nullptr; int32_t topic_type_flag_ = -1; diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index c2ebf654..f309ebd0 100755 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -22,6 +22,7 @@ #include "graph/utils/node_utils.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/model/hybrid_model.h" +#include "runtime/rt.h" namespace ge { namespace hybrid { @@ -33,6 +34,12 @@ const char *const kAicpuAllshape = "_AllShape"; REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, AiCpuNodeExecutor); REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_CUSTOM, AiCpuNodeExecutor); +AicpuNodeTaskBase::~AicpuNodeTaskBase() { + if (rt_event_ != nullptr) { + (void)rtEventDestroy(rt_event_); + } +} + Status AicpuNodeTaskBase::AllocTensorBuffer(size_t size, std::unique_ptr &tensor_buffer) { auto allocator = NpuMemoryAllocator::GetAllocator(); GE_CHECK_NOTNULL(allocator); @@ -64,9 +71,12 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_ GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id), "[Update][SessionInfoSessionId] failed, session_id:%ld.", session_id); - bool execute_mode = !aicpu_ext_handle_.IsNeedRefreshIOAddr() && !node_item_->is_dynamic; - GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateExecuteMode(execute_mode), - "[Update][ExecuteMode] failed, node:%s.", node_name_.c_str()); + if (is_blocking_aicpu_op_) { + if (UpdateEventIdForBlockingAicpuOp() != SUCCESS) { + GELOGE(FAILED, "[Call][UpdateEventIdForBlockingAicpuOp] Call UpdateEventIdForBlockingAicpuOp failed"); + return FAILED; + } + } // copy task args buf GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_), @@ -211,7 +221,7 @@ Status AicpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::functionnum_outputs == 0)) { GELOGD("Node[%s] type[%s] unknown_type is %d, output num is %d.", @@ -329,6 +429,9 @@ Status AicpuTfNodeTask::Init(const HybridModel &model) { // init ext info uint64_t ext_session_id = model.GetSessionId(); + const OpDescPtr op_desc = node_item_->GetOpDesc(); + AttrUtils::GetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc->GetName().c_str(), is_blocking_aicpu_op_); GE_CHK_STATUS_RET(InitExtInfo(kernel_ext_info, ext_session_id), 
"[Init][ExtInfo] failed for Node[%s].", node_name_.c_str()); GE_CHK_STATUS_RET(InitForDependComputeTask(), "[Init][DependComputeTask] failed for Node[%s].", node_name_.c_str()); @@ -481,7 +584,7 @@ Status AicpuTfNodeTask::CopyDataToHbm(TaskContext &context, GE_CHK_STATUS_RET_NOLOG(PrepareCopyInputs(context, out_shape_hbm)); RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[LaunchCopy] Start"); - GE_CHK_RT_RET(rtKernelLaunchEx(copy_task_args_buf_->GetData(), sizeof(STR_FWK_OP_KERNEL), + GE_CHK_RT_RET(rtKernelLaunchFwk(node_name_.c_str(), copy_task_args_buf_->GetData(), sizeof(STR_FWK_OP_KERNEL), RT_KERNEL_DEFAULT, context.GetStream())); RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[LaunchCopy] End"); @@ -642,9 +745,16 @@ Status AicpuTfNodeTask::LaunchTask(TaskContext &context) { GELOGD("Node[%s] launch task start, unknown_type=%d.", node_name_.c_str(), unknown_type_); uint32_t flag = RT_KERNEL_DEFAULT; RECORD_EXECUTION_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[AicpuTfNodertKernelLaunchEx] Start"); - GE_CHK_RT_RET(rtKernelLaunchEx(kernel_buf_->GetData(), kernel_buf_->GetSize(), flag, context.GetStream())); + GE_CHK_RT_RET(rtKernelLaunchFwk(node_name_.c_str(), kernel_buf_->GetData(), + kernel_buf_->GetSize(), flag, context.GetStream())); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[AicpuTfNodertKernelLaunchEx] End"); GELOGD("Node[%s] launch end.", node_name_.c_str()); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp(context.GetStream()) != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } if (need_sync_) { GELOGD("[%s] Task needs sync", node_name_.c_str()); GE_CHK_STATUS_RET_NOLOG(context.Synchronize()); @@ -763,6 +873,8 @@ Status AicpuNodeTask::Init(const HybridModel &model) { return FAILED;); uint64_t ext_session_id = model.GetSessionId(); + AttrUtils::GetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc->GetName().c_str(), is_blocking_aicpu_op_); GE_CHK_STATUS_RET(InitExtInfo(kernel_ext_info, ext_session_id), "[Init][ExtInfo] failed for Node[%s].", node_name.c_str()); @@ -823,12 +935,18 @@ Status AicpuNodeTask::LaunchTask(TaskContext &context) { if (kernel_type == ccKernelType::CUST_AI_CPU) { flag |= static_cast(RT_KERNEL_CUSTOM_AICPU); } - auto rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name.c_str()), - reinterpret_cast(kernel_name.c_str()), - 1, // default core dim is 1 - args_.get(), args_size_, - nullptr, context.GetStream(), flag); + rtKernelLaunchNames_t launch_name = {so_name.c_str(), kernel_name.c_str(), node_name_.c_str()}; + auto rt_ret = rtAicpuKernelLaunchWithFlag(&launch_name, + 1, // default core dim is 1 + args_.get(), args_size_, + nullptr, context.GetStream(), flag); GE_CHK_RT_RET(rt_ret); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp(context.GetStream()) != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } GELOGD("Node[%s] launch task end.", node_name_.c_str()); return SUCCESS; } diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h index 7577d486..3911e090 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h +++ 
b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h @@ -20,7 +20,7 @@ #include "external/graph/types.h" #include "cce/aicpu_engine_struct.h" #include "hybrid/node_executor/node_executor.h" -#include "aicpu_ext_info.h" +#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" namespace ge { namespace hybrid { @@ -35,7 +35,7 @@ class AicpuNodeTaskBase : public NodeTask { node_item->num_outputs, node_item->shape_inference_type) {} - ~AicpuNodeTaskBase() override = default; + ~AicpuNodeTaskBase() override; using NodeTask::Init; @@ -61,6 +61,10 @@ class AicpuNodeTaskBase : public NodeTask { static Status AllocTensorBuffer(size_t size, std::unique_ptr &tensor_buffer); + Status DistributeWaitTaskForAicpuBlockingOp(rtStream_t stream); + Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support); + Status UpdateEventIdForBlockingAicpuOp(); + protected: const NodeItem *node_item_; // just reference. @@ -78,6 +82,10 @@ class AicpuNodeTaskBase : public NodeTask { // ext info addr, device mem std::unique_ptr ext_info_addr_dev_; + + // for blocking aicpu op + bool is_blocking_aicpu_op_ = false; + rtEvent_t rt_event_ = nullptr; }; class AicpuTfNodeTask : public AicpuNodeTaskBase { diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc index 8b839849..8b3c691f 100755 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc @@ -20,7 +20,7 @@ #include "framework/common/fmk_error_codes.h" #include "common/dump/dump_manager.h" #include "common/ge/ge_util.h" -#include "graph/attr_value.h" +#include "external/graph/attr_value.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/load/model_manager/model_utils.h" @@ -136,8 +136,7 @@ Status KnownNodeTask::Init(TaskContext &context) { Status KnownNodeTask::InitDavinciModel(const HybridModel &model, TensorBuffer *weight_buffer) { GELOGD("[Init][DavinciModel] start"); davinci_model_->InitRuntimeParams(); - GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), - "[Init][VariableMem] failed"); + GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "[Init][VariableMem] failed"); int32_t device_id = 0; GE_CHK_RT_RET(rtGetDevice(&device_id)); davinci_model_->SetDeviceId(static_cast(device_id)); @@ -145,8 +144,6 @@ Status KnownNodeTask::InitDavinciModel(const HybridModel &model, TensorBuffer *w auto dump_properties = DumpManager::GetInstance().GetDumpProperties(model.GetSessionId()); if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { davinci_model_->SetDumpProperties(dump_properties); - void *global_step = model.GetGlobalStep(); - davinci_model_->SetKnownShapeGlobalStep(global_step); } void *weight = nullptr; @@ -182,6 +179,21 @@ Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) cons return SUCCESS; } +Status KnownNodeExecutor::SetDaviciModel(const HybridModel &model, const NodePtr &node, + std::shared_ptr &davinci_model) const { + // set known node flag as true + davinci_model->SetKnownNode(true); + davinci_model->SetId(model.GetModelId()); + davinci_model->SetDumpModelName(model.GetModelName()); + davinci_model->SetOmName(model.GetOmName()); + void *global_step = model.GetGlobalStep(); + GE_CHECK_NOTNULL(global_step); + davinci_model->SetGlobalStep(global_step, sizeof(int64_t)); + // set model id as root node's node id + davinci_model->SetSubModelId(node->GetOpDesc()->GetId()); + return SUCCESS; +} + Status 
KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); @@ -199,13 +211,7 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr); GE_CHECK_NOTNULL(davinci_model); - // set known node flag as true - davinci_model->SetKnownNode(true); - davinci_model->SetId(model.GetModelId()); - davinci_model->SetDumpModelName(model.GetModelName()); - davinci_model->SetOmName(model.GetOmName()); - // set model id as root node's node id - davinci_model->SetSubModelId(node->GetOpDesc()->GetId()); + GE_CHK_STATUS_RET_NOLOG(SetDaviciModel(model, node, davinci_model)); GELOGD("KnownNodeExecutor::LoadTask node id %ld.", node->GetOpDesc()->GetId()); GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), @@ -241,8 +247,7 @@ Status KnownNodeExecutor::ParseAttrForAllocatingOutputs(NodeItem &node_item, Com GE_CHECK_NOTNULL(net_output_desc); std::map connected_inputs; std::map data_indices; - GE_CHK_STATUS_RET(GetDataNodes(graph, data_indices), - "[%s] Failed to get data node indices", + GE_CHK_STATUS_RET(GetDataNodes(graph, data_indices), "[%s] Failed to get data node indices", node_item.NodeName().c_str()); for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h index 11cda846..37b5a3d8 100644 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h @@ -59,6 +59,8 @@ class KnownNodeExecutor : public NodeExecutor { const NodePtr &node, GeModelPtr &ge_model, ComputeGraphPtr &graph); + Status SetDaviciModel(const HybridModel &model, const NodePtr &node, + std::shared_ptr<DavinciModel> &davinci_model) const; }; } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.cc b/ge/hybrid/node_executor/controlop/control_op_executor.cc index d55607ff..fa44d761 100644 --- a/ge/hybrid/node_executor/controlop/control_op_executor.cc +++ b/ge/hybrid/node_executor/controlop/control_op_executor.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -#include "control_op_executor.h" +#include "hybrid/node_executor/controlop/control_op_executor.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" #include "hybrid/executor/hybrid_execution_context.h" diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc index 72092cd8..6be9849c 100644 --- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc +++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc @@ -15,15 +15,17 @@ */ #include "hybrid/node_executor/hccl/hccl_node_executor.h" + #include "common/ge/plugin_manager.h" #include "common/math/math_util.h" -#include "graph/attr_value.h" +#include "external/graph/attr_value.h" +#include "external/graph/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/util/hcom_util.h" #include "graph/utils/type_utils.h" -#include "graph/types.h" -#include "hybrid/executor/hybrid_execution_context.h" #include "hccl/hcom.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "runtime/event.h" namespace ge { namespace { @@ -266,14 +268,16 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector do rtEvent_t evt = nullptr; if (context.GetExecutionContext()->hccl_stream != nullptr) { - GE_CHK_RT_RET(rtEventCreateWithFlag(&evt, 0x01)); + GE_CHK_RT_RET(rtEventCreateWithFlag(&evt, RT_EVENT_WITH_FLAG)); GE_CHK_RT_RET(rtStreamWaitEvent(context.GetExecutionContext()->hccl_stream, evt)); } TaskContext *p_ctx = &context; @@ -341,6 +345,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function do GE_CHK_RT_RET(rtEventDestroy(evt)); } GELOGI("rdma callback success."); + return SUCCESS; }; HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); @@ -355,8 +360,8 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function do } Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { - void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, - ¶ms.recvcounts, ¶ms.rdispls}; + void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, ¶ms.recvcounts, + ¶ms.rdispls}; for (size_t i = 0; i < kAllToAllVInputNums; ++i) { auto addr = context.MutableInput(i); GE_CHECK_NOTNULL(addr); @@ -381,13 +386,14 @@ Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { } params.sendtype = iter->second; params.recvtype = iter->second; + params.group = nullptr; return SUCCESS; } Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams ¶ms) { - void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, - ¶ms.recvcounts, ¶ms.rdispls}; + void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, ¶ms.recvcounts, + ¶ms.rdispls}; for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) { auto addr = context.MutableInput(i); GE_CHECK_NOTNULL(addr); @@ -415,9 +421,10 @@ Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams } params.recvtype = iter->second; - int64_t addr_len; - (void) ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); + int64_t addr_len = 0; + (void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); params.addrLength = static_cast(addr_len); + params.group = nullptr; return SUCCESS; } @@ -426,7 +433,7 @@ Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::functionGetNodeName()); p_ctx->SetStatus(FAILED); diff --git 
a/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/ge/hybrid/node_executor/hccl/hccl_node_executor.h index b020208d..757f7593 100644 --- a/ge/hybrid/node_executor/hccl/hccl_node_executor.h +++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.h @@ -62,7 +62,7 @@ class RdmaNodeTask : public NodeTask { int32_t local_index_ = 0; std::mutex hccl_mutex_; std::condition_variable cond_; - bool skip_flag_; + bool skip_flag_ = false; }; diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index 5f3d6e45..9e9354d9 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -58,8 +58,8 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, } Status NodeExecutorManager::EnsureInitialized() { - GE_CHK_STATUS_RET(InitializeExecutors()); std::lock_guard<std::mutex> lk(mu_); + ++ref_count_; if (initialized_) { return SUCCESS; } @@ -115,17 +115,14 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node return it->second; } -Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) const { +Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) { auto executor_type = ResolveExecutorType(node); + GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type)); const auto it = executors_.find(executor_type); if (it == executors_.end()) { - REPORT_INNER_ERROR("E19999", "Failed to get executor by type: %d.", static_cast<int>(executor_type)); - GELOGE(INTERNAL_ERROR, "[Check][ExecutorType]Failed to get executor by type: %d.", - static_cast<int>(executor_type)); - return INTERNAL_ERROR; + return GetOrCreateExecutor(executor_type, executor); } - GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type)); *executor = it->second.get(); return SUCCESS; } @@ -178,51 +175,55 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { return OpsKernelBuilderManager::Instance().CalcOpRunningParam(node); } -Status NodeExecutorManager::InitializeExecutors() { +bool NodeExecutorManager::IsExecutorInitialized(NodeExecutorManager::ExecutorType executor_type) { + std::lock_guard<std::mutex> lk(mu_); + return executors_.find(executor_type) != executors_.end(); +} + +Status NodeExecutorManager::GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **out_executor) { std::lock_guard<std::mutex> lk(mu_); - if (executor_initialized_) { - ++ref_count_; - GELOGI("Executor is already initialized. 
add ref count to [%d]", ref_count_); + const auto executor_it = executors_.find(executor_type); + if (executor_it != executors_.end()) { + *out_executor = executor_it->second.get(); return SUCCESS; } - GELOGI("Start to Initialize NodeExecutors"); - for (auto &it : builders_) { - auto engine_type = it.first; - auto build_fn = it.second; - GE_CHECK_NOTNULL(build_fn); - auto executor = std::unique_ptr<NodeExecutor>(build_fn()); - if (executor == nullptr) { - REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for engine type = %d", - static_cast<int>(engine_type)); - GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(engine_type)); - return INTERNAL_ERROR; - } + GELOGI("Start to Initialize NodeExecutor, type = %d", static_cast<int>(executor_type)); + auto it = builders_.find(executor_type); + if (it == builders_.end()) { + REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", + static_cast<int>(executor_type)); + GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for executor type = %d", static_cast<int>(executor_type)); + return INTERNAL_ERROR; + } - GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(engine_type)); - auto ret = executor->Initialize(); - if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(engine_type)); - GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(engine_type)); - for (auto &executor_it : executors_) { - executor_it.second->Finalize(); - } - executors_.clear(); - return ret; - } + auto build_fn = it->second; + GE_CHECK_NOTNULL(build_fn); + auto executor = std::unique_ptr<NodeExecutor>(build_fn()); + if (executor == nullptr) { + REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", + static_cast<int>(executor_type)); + GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(executor_type)); + return INTERNAL_ERROR; + } - executors_.emplace(engine_type, std::move(executor)); + GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(executor_type)); + auto ret = executor->Initialize(); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(executor_type)); + GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(executor_type)); + return ret; } - ++ref_count_; - executor_initialized_ = true; - GELOGI("Initializing NodeExecutors successfully."); + *out_executor = executor.get(); + executors_.emplace(executor_type, std::move(executor)); + GELOGI("Initializing NodeExecutor successfully, type = %d", static_cast<int>(executor_type)); return SUCCESS; } void NodeExecutorManager::FinalizeExecutors() { std::lock_guard<std::mutex> lk(mu_); - if (!executor_initialized_) { + if (ref_count_ <= 0) { GELOGD("No need for finalizing for not initialized."); return; } @@ -237,7 +238,6 @@ void NodeExecutorManager::FinalizeExecutors() { it.second->Finalize(); } executors_.clear(); - executor_initialized_ = false; GELOGD("Done invoking Finalize successfully."); } diff --git a/ge/hybrid/node_executor/node_executor.h b/ge/hybrid/node_executor/node_executor.h index fffd4e7d..0e4a8464 100644 --- a/ge/hybrid/node_executor/node_executor.h +++ b/ge/hybrid/node_executor/node_executor.h @@ -20,7 +20,7 @@ #include "external/ge/ge_api_error_codes.h" #include "common/opskernel/ops_kernel_builder.h" #include "graph/node.h" -#include "task_context.h" +#include "hybrid/node_executor/task_context.h" namespace ge { const uint32_t
MEMORY_ALIGN_RATIO = 2; @@ -179,8 +179,6 @@ class NodeExecutorManager { */ Status EnsureInitialized(); - Status InitializeExecutors(); - void FinalizeExecutors(); /** @@ -196,7 +194,7 @@ class NodeExecutorManager { * @param executor executor * @return SUCCESS on success, error code otherwise */ - Status GetExecutor(Node &node, const NodeExecutor **executor) const; + Status GetExecutor(Node &node, const NodeExecutor **executor); /** * Resolve executor type by node * ... */ ExecutorType ResolveExecutorType(Node &node) const; + Status GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **executor); + + bool IsExecutorInitialized(ExecutorType executor_type); + private: std::map<ExecutorType, std::unique_ptr<NodeExecutor>> executors_; std::map<ExecutorType, std::function<NodeExecutor *()>> builders_; std::map<std::string, ExecutorType> engine_mapping_; std::mutex mu_; bool initialized_ = false; - bool executor_initialized_ = false; int ref_count_ = 0; }; diff --git a/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc b/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc index 28a5dea1..ad1f7e61 100755 --- a/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc +++ b/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "partitioned_call_node_executor.h" +#include "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h" #include "graph/utils/node_utils.h" namespace ge { diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index 3ad791b6..e3058ee3 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -17,13 +17,9 @@ #include "hybrid/node_executor/rts/rts_node_executor.h" #include "hybrid/node_executor/rts/rts_task_factory.h" -#include "common/debug/log.h" #include "common/ge/ge_util.h" -#include "common/types.h" -#include "graph/common/omg_util.h" #include "graph/utils/tensor_utils.h" #include "hybrid/model/hybrid_model.h" -#include "runtime/rt.h" namespace ge { namespace hybrid { @@ -33,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask); REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); +REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask); Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { auto input_desc = context.MutableInputDesc(index); @@ -133,8 +130,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { GE_CHECK_NOTNULL(node); GELOGD("[%s] Load for local task.", node->GetName().c_str()); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); + const std::string node_type = NodeUtils::GetNodeType(node); RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); if (rts_task == nullptr) { GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); diff --git a/ge/hybrid/node_executor/rts/rts_node_task.cc b/ge/hybrid/node_executor/rts/rts_node_task.cc index 104196ee..7b95f98a 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.cc +++ b/ge/hybrid/node_executor/rts/rts_node_task.cc @@ -22,7 +22,7 @@ #include
"graph/utils/type_utils.h" #include "graph/utils/node_utils.h" #include "common/ge/ge_util.h" -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" namespace { constexpr uint8_t kSwitchPredIndex = 0; @@ -43,7 +43,6 @@ namespace hybrid { REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); -REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask); REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); @@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio return SUCCESS; } -Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { - GELOGD("[%s] Start to execute.", task_context.GetNodeName()); - auto input_desc = task_context.MutableInputDesc(0); - GE_CHECK_NOTNULL(input_desc); - int64_t copy_size = 0; - GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size)); - // copy_size would not be negative since GetTensorSizeInBytes returned successfully. - if (copy_size > 0) { - const auto in_v = task_context.MutableInput(0); - const auto out_v = task_context.MutableOutput(0); - GE_CHECK_NOTNULL(in_v); - GE_CHECK_NOTNULL(out_v); - GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(), - in_v->GetSize(), out_v->GetSize(), copy_size); - GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size, - RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream())); - } else { - GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size); - } - - if (done_callback) { - GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); - } - - GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); - return SUCCESS; -} - Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); const auto in_x = task_context.GetInput(0); // x diff --git a/ge/hybrid/node_executor/rts/rts_node_task.h b/ge/hybrid/node_executor/rts/rts_node_task.h index d7d63eb5..e18f9a8f 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.h +++ b/ge/hybrid/node_executor/rts/rts_node_task.h @@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask { Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; }; -class MemcpyAsyncNodeTask : public RtsNodeTask { - public: - Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; -}; - class PassThroughNodeTask : public RtsNodeTask { public: Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index 14eb1222..8bda5084 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -14,11 +14,11 @@ * limitations under the License. 
*/ -#include "task_context.h" +#include "hybrid/node_executor/task_context.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/log.h" #include "graph/utils/tensor_utils.h" -#include "graph/types.h" +#include "external/graph/types.h" #include "graph/debug/ge_attr_define.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/subgraph_executor.h" @@ -43,6 +43,7 @@ TaskContext::~TaskContext() { output_tensor->Destroy(); } } + ReleaseWorkspace(); } void TaskContext::ReleaseWorkspace() { @@ -50,11 +51,10 @@ void TaskContext::ReleaseWorkspace() { for (auto ws_addr : workspaces_) { execution_context_->allocator->Deallocate(ws_addr); } + workspaces_.clear(); } -std::unique_ptr TaskContext::Create(NodeState *node_state, - GraphExecutionContext *execution_context, - SubgraphContext *subgraph_context) { +std::unique_ptr TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) { const NodeItem &node_item = *node_state->GetNodeItem(); GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", node_item.NodeName().c_str(), @@ -75,7 +75,7 @@ std::unique_ptr TaskContext::Create(NodeState *node_state, } auto task_context = std::unique_ptr( - new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context)); + new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context)); if (task_context == nullptr) { REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str()); GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str()); @@ -85,7 +85,7 @@ std::unique_ptr TaskContext::Create(NodeState *node_state, task_context->node_item_ = &node_item; task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; - task_context->iteration_ = execution_context->iteration; + task_context->iteration_ = subgraph_context->execution_context_->iteration; return task_context; } @@ -489,13 +489,9 @@ void TaskContext::ReleaseInputsAndOutputs() { } void TaskContext::ReleaseInput(int index) { - if (node_item_->enter_inside_.count(index) > 0) { - GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); - return; - } - auto input_tensor = MutableInput(index); if (input_tensor != nullptr) { + node_state_->SavePersistTensor(index, *input_tensor); input_tensor->Destroy(); GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); } @@ -575,8 +571,8 @@ Status TaskContext::Synchronize() { return execution_context_->Synchronize(GetStream()); } -Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, - const std::string &task_type, uint32_t block_dim) { +Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, const std::string &task_type, + uint32_t block_dim, const std::string &op_type) { if (ProfilingManager::Instance().ProfilingModelLoadOn()) { const NodeItem &node_item = GetNodeItem(); auto op_desc = node_item.GetOpDesc(); @@ -590,7 +586,7 @@ Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream TaskDescInfo tmp_task_desc_info; tmp_task_desc_info.model_name = dynamic_model_name; tmp_task_desc_info.op_name = op_desc->GetName(); - tmp_task_desc_info.op_type = op_desc->GetType(); + tmp_task_desc_info.op_type = op_type; tmp_task_desc_info.block_dim = block_dim; 
tmp_task_desc_info.task_type = task_type; tmp_task_desc_info.task_id = task_id; diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index ba4c62e6..5304606b 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -36,9 +36,7 @@ class SubgraphContext; class TaskContext { public: - static std::unique_ptr<TaskContext> Create(NodeState *node_state, - GraphExecutionContext *execution_context, - SubgraphContext *subgraph_context); + static std::unique_ptr<TaskContext> Create(NodeState *node_state, SubgraphContext *subgraph_context); ~TaskContext(); @@ -120,8 +118,8 @@ class TaskContext { void *handle_ = nullptr; const std::vector<TaskDescInfo>& GetProfilingTaskDescInfo() const { return task_desc_info; } - Status SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, - const std::string &task_type, uint32_t block_dim); + Status SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, const std::string &task_type, + uint32_t block_dim, const std::string &op_type); void ClearProfilingTaskDescInfo() { task_desc_info.clear(); } private: diff --git a/ge/inc/graph_pass.h b/ge/inc/graph_pass.h index 642b94ea..a9cc7a32 100644 --- a/ge/inc/graph_pass.h +++ b/ge/inc/graph_pass.h @@ -20,9 +20,9 @@ #include #include -#include "common/op/attr_value_util.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" +#include "framework/common/op/attr_value_util.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/common/types.h" #include "framework/common/debug/ge_log.h" #include "graph/compute_graph.h" #include "graph/utils/attr_utils.h" diff --git a/ge/inc/kernel.h b/ge/inc/kernel.h index 84af5234..a83776a9 100644 --- a/ge/inc/kernel.h +++ b/ge/inc/kernel.h @@ -19,9 +19,9 @@ #include -#include "common/op/ge_op_utils.h" +#include "framework/common/op/ge_op_utils.h" #include "graph/compute_graph.h" -#include "graph/graph.h" +#include "external/graph/graph.h" #include "graph/op_desc.h" using std::vector; diff --git a/ge/inc/kernel_factory.h b/ge/inc/kernel_factory.h index 61455836..e532b894 100644 --- a/ge/inc/kernel_factory.h +++ b/ge/inc/kernel_factory.h @@ -24,7 +24,7 @@ #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" -#include "graph/graph.h" +#include "external/graph/graph.h" using std::string; diff --git a/ge/inc/pass.h b/ge/inc/pass.h index 9f8519e1..56f77fef 100644 --- a/ge/inc/pass.h +++ b/ge/inc/pass.h @@ -19,7 +19,7 @@ #include -#include "common/fmk_error_codes.h" +#include "framework/common/fmk_error_codes.h" namespace ge { /// diff --git a/ge/init/gelib.cc b/ge/init/gelib.cc index 2374e75f..2491715b 100644 --- a/ge/init/gelib.cc +++ b/ge/init/gelib.cc @@ -16,7 +16,6 @@ #include "init/gelib.h" -#include #include #include #include @@ -33,12 +32,11 @@ #include "framework/common/util.h" #include "framework/omg/ge_init.h" #include "analyzer/analyzer.h" -#include "ge/ge_api_types.h" +#include "external/ge/ge_api_types.h" #include "ge_local_engine/engine/host_cpu_engine.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/ge_context.h" #include "graph/ge_global_options.h" -#include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_mem_manager.h" #include "graph/manager/host_mem_manager.h" #include "graph/manager/graph_var_manager.h" @@ -160,18 +158,6 @@ Status GELib::InnerInitialize(const map<string, string> &options) { return initOpsBuilderStatus; } - ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther); - 
GELOGI("sessionManager initial."); - GE_TIMESTAMP_START(SessionManagerInitialize); - Status initSmStatus = sessionManager_.Initialize(options); - GE_TIMESTAMP_END(SessionManagerInitialize, "InnerInitialize::SessionManagerInitialize"); - if (initSmStatus != SUCCESS) { - GELOGE(initSmStatus, "[Init][SessionManager] GE session manager initial failed."); - REPORT_CALL_ERROR("E19999", "SessionManager initialize failed."); - RollbackInit(); - return initSmStatus; - } - GELOGI("Start to initialize HostCpuEngine"); GE_TIMESTAMP_START(HostCpuEngineInitialize); Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize(); @@ -209,12 +195,6 @@ Status GELib::SystemInitialize(const map &options) { // In train and infer, profiling is always needed. InitProfiling(this->options_); - auto model_manager = ModelManager::GetInstance(); - GE_CHECK_NOTNULL(model_manager); - GE_IF_BOOL_EXEC(model_manager->EnableExceptionDump(options) != SUCCESS, - REPORT_CALL_ERROR("E19999", "ModelManager EnableExceptionDump failed."); - GELOGE(FAILED, "[Enable][ExceptionDump] failed."); - return FAILED); // 1.`is_train_mode_` means case: train // 2.`(!is_train_mode_) && (options_.device_id != kDefaultDeviceIdForInfer)` means case: online infer // these two case with logical device id @@ -454,12 +434,6 @@ Status GELib::Finalize() { GELOGW("engineManager finalize failed"); final_state = mid_state; } - GELOGI("sessionManager finalization."); - mid_state = sessionManager_.Finalize(); - if (mid_state != SUCCESS) { - GELOGW("sessionManager finalize failed"); - final_state = mid_state; - } GELOGI("opsBuilderManager finalization."); mid_state = OpsKernelBuilderManager::Instance().Finalize(); @@ -539,9 +513,6 @@ void GELib::RollbackInit() { if (opsManager_.init_flag_) { (void)opsManager_.Finalize(); } - if (sessionManager_.init_flag_) { - (void)sessionManager_.Finalize(); - } MemManager::Instance().Finalize(); HostMemManager::Instance().Finalize(); VarManagerPool::Instance().Destory(); diff --git a/ge/init/gelib.h b/ge/init/gelib.h index ed6fe5d4..226dd4c8 100644 --- a/ge/init/gelib.h +++ b/ge/init/gelib.h @@ -22,9 +22,14 @@ #include #include "engine_manager/dnnengine_manager.h" #include "opskernel_manager/ops_kernel_manager.h" -#include "session/session_manager.h" -#include "common/ge_inner_error_codes.h" -#include "common/ge_types.h" +#include "graph/tuning_utils.h" +#include "graph/operator_factory.h" +#include "graph/ge_local_context.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/anchor_utils.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" using std::string; using std::map; @@ -53,20 +58,11 @@ class GE_FUNC_VISIBILITY GELib { // get OpsKernelManager object OpsKernelManager &OpsKernelManagerObj() { return opsManager_; } - // get SessionManager object - SessionManager &SessionManagerObj() { return sessionManager_; } - // get Initial flag bool InitFlag() const { return init_flag_; } // get TrainMode flag - bool isTrainMode() { return is_train_mode_; } - - // get incre build flag - bool IsIncreBuild() const { return is_incre_build_; } - - // get incre build cache path - const std::string &GetIncreBuildCachePath() const { return incre_build_cache_path_; } + bool IsTrainMode() { return is_train_mode_; } void InitProfiling(Options &options); void ShutDownProfiling(); @@ -90,7 +86,6 @@ class GE_FUNC_VISIBILITY GELib { DNNEngineManager engineManager_; OpsKernelManager opsManager_; - SessionManager sessionManager_; std::mutex 
status_mutex_; bool init_flag_ = false; Options options_; @@ -98,8 +93,6 @@ class GE_FUNC_VISIBILITY GELib { bool is_system_inited = false; bool is_shutdown = false; bool is_use_hcom = false; - bool is_incre_build_ = false; - std::string incre_build_cache_path_; }; } // namespace ge diff --git a/ge/ir_build/attr_options/attr_options.h b/ge/ir_build/attr_options/attr_options.h index 7c0f4f4f..9ea2b9a1 100644 --- a/ge/ir_build/attr_options/attr_options.h +++ b/ge/ir_build/attr_options/attr_options.h @@ -1,29 +1,30 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef ATTR_OPTIONS_H_ -#define ATTR_OPTIONS_H_ - -#include -#include "graph/compute_graph.h" -#include "graph/ge_error_codes.h" - -namespace ge { -bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); - -graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); -graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); -} // namespace +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ATTR_OPTIONS_H_ +#define ATTR_OPTIONS_H_ + +#include +#include "graph/compute_graph.h" +#include "graph/ge_error_codes.h" + +namespace ge { +bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); +bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type); +bool IsContainOpType(const std::string &cfg_line, std::string &op_type); +graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); +graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); +} // namespace #endif // ATTR_OPTIONS_H_ \ No newline at end of file diff --git a/ge/ir_build/attr_options/keep_dtype_option.cc b/ge/ir_build/attr_options/keep_dtype_option.cc index dfdd0df3..88f238c0 100644 --- a/ge/ir_build/attr_options/keep_dtype_option.cc +++ b/ge/ir_build/attr_options/keep_dtype_option.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "attr_options.h" +#include "ir_build/attr_options/attr_options.h" #include #include #include @@ -32,18 +32,24 @@ void KeepDtypeReportError(const std::vector &invalid_list, const st size_t list_size = invalid_list.size(); err_msg << "config file contains " << list_size; if (list_size == 1) { - err_msg << " operator not in the graph, op name:"; + err_msg << " operator not in the graph, "; } else { - err_msg << " operators not in the graph, op names:"; + err_msg << " operators not in the graph, "; } - + std::string cft_type; for (size_t i = 0; i < list_size; i++) { if (i == kMaxOpsNum) { err_msg << ".."; break; } - err_msg << invalid_list[i]; - if (i != list_size - 1) { + bool istype = IsContainOpType(invalid_list[i], cft_type); + if (!istype) { + err_msg << "op name:"; + } else { + err_msg << "op type:"; + } + err_msg << cft_type; + if (i != (list_size - 1)) { err_msg << " "; } } @@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { return GRAPH_FAILED; } - std::string op_name; + std::string op_name, op_type; std::vector invalid_list; while (std::getline(ifs, op_name)) { if (op_name.empty()) { @@ -80,13 +86,20 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { } op_name = StringUtils::Trim(op_name); bool is_find = false; - for (auto &node_ptr : graph->GetDirectNode()) { + bool is_type = IsContainOpType(op_name, op_type); + for (auto &node_ptr : graph->GetAllNodes()) { auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - - if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { - is_find = true; - (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); + if (is_type) { + if (IsOpTypeEqual(node_ptr, op_type)) { + is_find = true; + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); + } + } else { + if (op_desc->GetName() == op_name || IsOriginalOpFind(op_desc, op_name)) { + is_find = true; + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); + } } } if (!is_find) { diff --git a/ge/ir_build/attr_options/utils.cc b/ge/ir_build/attr_options/utils.cc index f0b559ec..23bb0b7b 100644 --- a/ge/ir_build/attr_options/utils.cc +++ b/ge/ir_build/attr_options/utils.cc @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "attr_options.h" +#include "ir_build/attr_options/attr_options.h" #include #include "graph/debug/ge_attr_define.h" -#include "common/util/error_manager/error_manager.h" - +#include "framework/common/debug/ge_log.h" +#include "common/omg_util.h" namespace ge { + namespace { + const std::string CFG_PRE_OPTYPE = "OpType::"; +} bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { std::vector original_op_names; if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { @@ -33,4 +36,36 @@ bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { return false; } + +bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type) { + if (op_type != node->GetOpDesc()->GetType()) { + return false; + } + std::string origin_type; + auto ret = GetOriginalType(node, origin_type); + if (ret != SUCCESS) { + GELOGW("[Get][OriginalType] from op:%s failed.", node->GetName().c_str()); + return false; + } + if (op_type != origin_type) { + return false; + } + return true; +} + +bool IsContainOpType(const std::string &cfg_line, std::string &op_type) { + op_type = cfg_line; + size_t pos = op_type.find(CFG_PRE_OPTYPE); + if (pos != std::string::npos) { + if (pos == 0) { + op_type = cfg_line.substr(CFG_PRE_OPTYPE.length()); + return true; + } else { + GELOGW("[Check][Param] %s must be at zero pos of %s", CFG_PRE_OPTYPE.c_str(), cfg_line.c_str()); + } + return false; + } + GELOGW("[Check][Param] %s not contain optype", cfg_line.c_str()); + return false; +} } // namespace ge \ No newline at end of file diff --git a/ge/ir_build/attr_options/weight_compress_option.cc b/ge/ir_build/attr_options/weight_compress_option.cc index 3c057d04..b59c6adc 100644 --- a/ge/ir_build/attr_options/weight_compress_option.cc +++ b/ge/ir_build/attr_options/weight_compress_option.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "attr_options.h" +#include "ir_build/attr_options/attr_options.h" #include #include #include diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 21db83aa..cafc534d 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -25,15 +25,15 @@ #include "framework/common/util.h" #include "framework/omg/omg_inner_types.h" #include "framework/omg/omg_inner_types.h" -#include "ge/ge_api_types.h" -#include "generator/ge_generator.h" +#include "external/ge/ge_api_types.h" +#include "framework/generator/ge_generator.h" #include "graph/compute_graph.h" #include "graph/ge_tensor.h" #include "graph/utils/type_utils.h" #include "graph/ge_global_options.h" #include "init/gelib.h" #include "ir_build/option_utils.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "graph/shape_refiner.h" #include "graph/opsproto_manager.h" #include "inc/pass_manager.h" @@ -263,6 +263,7 @@ class Impl { omg_context_.user_attr_index_valid = false; }; ~Impl() { (void)generator_.Finalize(); }; + graphStatus CheckBuildModeAndBuildStep(); graphStatus GetSupportedOptions(const std::map &in, std::map &out); graphStatus CheckOptions(const std::map &options); @@ -451,6 +452,37 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { return GRAPH_SUCCESS; } +graphStatus Impl::CheckBuildModeAndBuildStep() { + std::string build_mode; + auto it = options_.find(BUILD_MODE); + if (it != options_.end() && !(it->second.empty())) { + if (build_mode_options.find(it->second) == build_mode_options.end()) { + REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), + std::vector({BUILD_MODE, it->second, "value is unsupported. Please check!"})); + GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode]:%s is unsupported. Please check!", it->second.c_str()); + return GRAPH_PARAM_INVALID; + } + build_mode = it->second; + } + it = options_.find(BUILD_STEP); + if (it != options_.end() && !(it->second.empty())) { + if (build_step_options.find(it->second) == build_step_options.end()) { + REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), + std::vector({BUILD_STEP, it->second, "value is unsupported. Please check!"})); + GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildStep]:%s is unsupported. Please check!", it->second.c_str()); + return GRAPH_PARAM_INVALID; + } + } else { + if (build_mode == BUILD_MODE_TUNING) { + REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), + std::vector({BUILD_MODE, it->second, "tuning must specify build step. Please check!"})); + GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode] tuning must specify build step. Please check!"); + return GRAPH_PARAM_INVALID; + } + } + return GRAPH_SUCCESS; +} + graphStatus Impl::GetSupportedOptions(const std::map &in, std::map &out) { for (auto &ele : in) { @@ -475,29 +507,12 @@ graphStatus Impl::CheckOptions(const std::map &options } // Check options build_mode and build_step. - std::string build_mode; - auto it = options_.find(BUILD_MODE); - if (it != options_.end() && !(it->second.empty())) { - if (build_mode_options.find(it->second) == build_mode_options.end()) { - GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode]:%s is unsupported. Please check!", it->second.c_str()); - return GRAPH_PARAM_INVALID; - } - build_mode = it->second; - } - it = options_.find(BUILD_STEP); - if (it != options_.end() && !(it->second.empty())) { - if (build_step_options.find(it->second) == build_step_options.end()) { - GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildStep]:%s is unsupported. 
Please check!", it->second.c_str()); - return GRAPH_PARAM_INVALID; - } - } else { - if (build_mode == BUILD_MODE_TUNING) { - GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode] tuning must specify build step. Please check!"); - return GRAPH_PARAM_INVALID; - } + ret = CheckBuildModeAndBuildStep(); + if (ret != GRAPH_SUCCESS) { + return ret; } // Check option EXEC_DISABLE_REUSED_MEMORY - it = options_.find(ge::ir_option::EXEC_DISABLE_REUSED_MEMORY); + auto it = options_.find(ge::ir_option::EXEC_DISABLE_REUSED_MEMORY); if (it != options_.end() && (CheckDisableReuseMemoryParamValid(it->second) != GRAPH_SUCCESS)) { return GRAPH_PARAM_INVALID; } @@ -505,6 +520,18 @@ graphStatus Impl::CheckOptions(const std::map &options if (ge::CheckModifyMixlistParamValid(options_) != GRAPH_SUCCESS) { return GRAPH_PARAM_INVALID; } + // Check option OP_PRECISION_MODE + it = options_.find(ge::ir_option::OP_PRECISION_MODE); + if (it != options_.end() && !it->second.empty() && !ge::CheckInputPathValid(it->second)) { + REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), + std::vector({ge::ir_option::OP_PRECISION_MODE, it->second, "path is not found"})); + GELOGE(GRAPH_PARAM_INVALID, "[Check][OP_PRECISION_MODE] %s not found", it->second.c_str()); + return GRAPH_PARAM_INVALID; + } + if (it != options_.end()) { + GELOGI("Option set successfully, option_key=%s, option_value=%s", + ge::ir_option::OP_PRECISION_MODE, it->second.c_str()); + } // Check Input Format if (options_.find(kInputFormat) != options_.end()) { return CheckInputFormat(options_[kInputFormat]); @@ -559,8 +586,8 @@ graphStatus Impl::Init(const Graph &graph, const std::map(string(ge::RUN_FLAG), to_string(0))); options_.insert(std::pair(string(ge::TRAIN_FLAG), to_string(0))); options_.insert(std::pair(string(ge::SAVE_ORIGINAL_MODEL), to_string(0))); + options_.insert(std::pair(string(ge::OPTION_GRAPH_RUN_MODE), to_string(0))); // print ge option map ge::PrintOptionMap(options_, "ge option"); @@ -839,6 +867,7 @@ graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const siz graphStatus aclgrphGenerateForOp(const AscendString &op_type, const vector &inputs, const vector &outputs, Graph &graph) { ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); + GE_CHECK_NOTNULL(op_type.GetString()); auto op_type_str = std::string(op_type.GetString()); auto op_name = op_type_str + "_" + std::to_string(ge::GetCurrentTimestamp()); auto op_desc = ge::MakeShared(op_name, op_type_str); diff --git a/ge/ir_build/option_utils.cc b/ge/ir_build/option_utils.cc index cecc2588..7287fe91 100755 --- a/ge/ir_build/option_utils.cc +++ b/ge/ir_build/option_utils.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "option_utils.h" +#include "ir_build/option_utils.h" #include "common/util/error_manager/error_manager.h" #include "external/ge/ge_api_types.h" #include "framework/common/string_util.h" @@ -50,6 +50,8 @@ const std::set kBufferOptimizeSupportOption = {"l1_optimize", "l2_o const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL = "high_precision_for_all"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL = "high_performance_for_all"; const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; const char *const kSplitError1 = "size not equal to 2 split by \":\""; @@ -57,7 +59,8 @@ const char *const kEmptyError = "can not be empty"; const char *const kFloatNumError = "exist float number"; const char *const kDigitError = "is not digit"; const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; -const char *const kSelectImplmodeError = "only support high_performance, high_precision"; +const char *const kSelectImplmodeError = "only support high_performance, high_precision, " + "high_precision_for_all, high_performance_for_all"; const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\""; const char *const kKeepDtypeError = "file not found"; @@ -204,7 +207,7 @@ bool CheckDynamicImagesizeInputShapeValid(map> shape_map if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) { GELOGE(ge::PARAM_INVALID, "[Check][DynamicImagesizeInputShape] input_format [%s] invalid, can not support now.", input_format.c_str()); - REPORT_INPUT_ERROR("E10003", std::vector({"parameter","value","reason"}), + REPORT_INPUT_ERROR("E10003", std::vector({"parameter", "value", "reason"}), std::vector({"input_format", input_format, "this format is not support"})); return false; } @@ -782,7 +785,9 @@ Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std:: op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT; } else { if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT && - op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) { + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON && + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PRECISION_FOR_ALL && + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_HIGH_PERFORMANCE_FOR_ALL) { ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, {"--op_select_implmode", op_select_implmode.c_str(), kSelectImplmodeError}); diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index a520652f..935d8a30 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -22,7 +22,6 @@ target_compile_options(atc_atc.bin PRIVATE target_compile_definitions(atc_atc.bin PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 - COMPILE_OMG_PACKAGE google=ascend_private LOG_CPP FUNC_VISIBILITY @@ -30,25 +29,17 @@ target_compile_definitions(atc_atc.bin PRIVATE target_include_directories(atc_atc.bin PRIVATE ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR} ${GE_CODE_DIR}/ge ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/common/inc/external - 
${GE_CODE_DIR}/common/inc/external/graph ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/graph - ${METADEF_DIR}/inc/register ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/external/register ${PARSER_DIR} ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/common #### blue zone #### ${GE_CODE_DIR}/third_party/fwkacllib/inc ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain @@ -56,6 +47,7 @@ target_include_directories(atc_atc.bin PRIVATE target_link_options(atc_atc.bin PRIVATE -Wl,-Bsymbolic + -Wl,-rpath-link,${ASCEND_ATC_DIR}/stub ) target_link_libraries(atc_atc.bin PRIVATE @@ -70,7 +62,7 @@ target_link_libraries(atc_atc.bin PRIVATE parser_common gflags json - runtime_compile + runtime slog static_mmpa -lrt @@ -99,7 +91,6 @@ target_compile_options(fwk_atc.bin PRIVATE target_compile_definitions(fwk_atc.bin PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 - COMPILE_OMG_PACKAGE google=ascend_private LOG_CPP FUNC_VISIBILITY @@ -107,25 +98,17 @@ target_compile_definitions(fwk_atc.bin PRIVATE target_include_directories(fwk_atc.bin PRIVATE ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR} ${GE_CODE_DIR}/ge ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/common/inc/external - ${GE_CODE_DIR}/common/inc/external/graph ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/framework ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/graph - ${METADEF_DIR}/inc/register ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/external/register ${PARSER_DIR} ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/proto/graphengine_protos #### yellow zone #### ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/common #### blue zone #### ${GE_CODE_DIR}/third_party/fwkacllib/inc ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain diff --git a/ge/offline/main.cc b/ge/offline/main.cc index a1ae476b..a50ff931 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -23,26 +23,26 @@ #include #include #include -#include "common/gflags_util.h" -#include "common/util.h" +#include "framework/common/gflags_util.h" +#include "framework/common/util.h" #include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" -#include "ge/ge_api.h" -#include "generator/ge_generator.h" +#include "external/ge/ge_api.h" +#include "framework/generator/ge_generator.h" #include "graph/anchor.h" #include "graph/debug/ge_attr_define.h" -#include "graph/graph.h" +#include "external/graph/graph.h" #include "graph/op_desc.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" #include "ir_build/option_utils.h" -#include "omg/omg.h" -#include "omg/parser/parser_factory.h" -#include "omg/parser/parser_inner_ctx.h" +#include "framework/omg/omg.h" +#include "framework/omg/parser/parser_factory.h" +#include "framework/omg/parser/parser_inner_ctx.h" #include "parser/common/register_tbe.h" #include "register/op_registry.h" -#include "single_op_parser.h" +#include "offline/single_op_parser.h" #include "external/ge/ge_ir_build.h" using domi::BuildMode; @@ -106,10 +106,14 @@ DEFINE_string(out_nodes, "", "Optional; output nodes designated by users." "Format: \"node_name1:0;node_name1:1;node_name2:0\""); +DEFINE_string(op_precision_mode, "", "Optional; operator precision mode configuration file path"); + DEFINE_string(precision_mode, "force_fp16", "Optional; precision mode." 
"Support force_fp16, force_fp32, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); +DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path"); + DEFINE_string(keep_dtype, "", "Optional; config file to specify the precision used by the operator during compilation."); @@ -139,7 +143,8 @@ DEFINE_string(output_type, "", DEFINE_string(op_select_implmode, "", "Optional; op select implmode! " - "Support high_precision, high_performance."); + "Support high_precision, high_performance, " + "high_precision_for_all, high_performance_for_all."); DEFINE_string(optypelist_for_implmode, "", "Optional; Nodes need use implmode selected in op_select_implmode " @@ -192,8 +197,11 @@ DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, war DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); -DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug;" - "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); +DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug; " + "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler; " + "3: disable debug, and keep generating kernel file (.o and .json); 4: disable debug, " + "keep generation kernel file (.o and .json) and generate the operator CCE file (.cce) " + "and the UB fusion computing description file (.json)"); DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," "multiple names can be set and separated by ','."); DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation"); @@ -210,8 +218,6 @@ DEFINE_string(display_model_info, "0", "Optional; display model info"); DEFINE_string(device_id, "0", "Optional; user device id"); -DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path"); - class GFlagUtils { public: /** @@ -298,14 +304,16 @@ class GFlagUtils { "\"l1_optimize\", \"off_optimize\"\n" " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n" "\n[Operator Tuning]\n" + " --op_precision_mode Set the path of operator precision mode configuration file (.ini)\n" " --precision_mode precision mode, support force_fp16(default), force_fp32, allow_mix_precision, " "allow_fp32_to_fp16, must_keep_origin_dtype.\n" + " --modify_mixlist Set the path of operator mixed precision configuration file.\n" " --keep_dtype Retains the precision of certain operators in inference " "scenarios by using a configuration file.\n" " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n" - " --op_select_implmode Set op select implmode. Support high_precision, high_performance. " - "default: high_performance\n" + " --op_select_implmode Set op select implmode. Support high_precision, high_performance, " + "high_precision_for_all, high_performance_for_all. default: high_performance\n" " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" " Separate multiple nodes with commas (,). Use double quotation marks (\") " "to enclose each argument. 
E.g.: \"node_name1,node_name2\"\n" @@ -315,7 +323,8 @@ class GFlagUtils { " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " "(.json), and enable the CCE compiler -O0-g.\n" " 3: Disable debug, and keep generating kernel file (.o and .json)\n" - " --modify_mixlist Set the path of operator mixed precision configuration file.\n" + " 4: Disable debug, keep generation kernel file (.o and .json) and generate the " + "operator CCE file (.cce) and the UB fusion computing description file (.json)" "\n[Debug]\n" " --save_original_model Control whether to output original model. E.g.: true: output original model\n" " --log Generate log with level. Support debug, info, warning, error, null\n" @@ -365,6 +374,14 @@ class GFlagUtils { FLAGS_op_select_implmode) != ge::SUCCESS, ret = ge::FAILED, "[Check][ImplMode]check optypelist_for_implmode and op_select_implmode failed!"); + if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"op_precision_mode", FLAGS_op_precision_mode.c_str(), + "path is not found"}); + GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str()); + ret = ge::FAILED; + } + if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, {"modify_mixlist", FLAGS_modify_mixlist.c_str(), @@ -847,6 +864,7 @@ domi::Status GenerateInfershapeJson() { ge::Graph graph; std::map atc_params; atc_params.insert(std::pair("input_format", FLAGS_input_format)); + atc_params.insert(std::pair("check_report", FLAGS_check_report)); ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); if (ret != ge::SUCCESS) { @@ -953,8 +971,7 @@ domi::Status GenerateModel(std::map &options, std::string output ge::Model load_model = ge::Model("loadmodel", "version2"); auto ret1 = load_model.LoadFromFile(FLAGS_model); if (ret1 != ge::GRAPH_SUCCESS) { - REPORT_INPUT_ERROR("E10041", std::vector({"file"}), std::vector({FLAGS_model})); - REPORT_CALL_ERROR("E19999", "load from model file:%s failed", FLAGS_model.c_str()); + REPORT_INPUT_ERROR("E10041", std::vector({"parameter"}), std::vector({FLAGS_model})); DOMI_LOGE("Load model from %s failed, please check model file or " "input parameter[--framework] is correct", FLAGS_model.c_str()); (void)ge_generator.Finalize(); @@ -1050,6 +1067,7 @@ static void SetEnvForSingleOp(std::map &options) { options.emplace(ge::RUN_FLAG, flag_off); options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); options.emplace(ge::SINGLE_OP_FLAG, flag_on); + options.emplace(ge::OP_PRECISION_MODE, FLAGS_op_precision_mode); options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); options.emplace(ge::SOC_VERSION, FLAGS_soc_version); options.emplace(ge::CORE_TYPE, FLAGS_core_type); @@ -1077,6 +1095,14 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, return ge::FAILED, "[Check][ImplmodeParam] fail for input optypelist_for_implmode and op_select_implmode."); + if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) { + 
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"op_precision_mode", FLAGS_op_precision_mode.c_str(), + "path is not found"}); + GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str()); + return ge::FAILED; + } + if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, {"modify_mixlist", FLAGS_modify_mixlist.c_str(), @@ -1124,9 +1150,9 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) { if (ret != SUCCESS) { DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index); ret = domi::FAILED; - break; + } else { + GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str()); } - GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str()); index += 1; } @@ -1160,6 +1186,7 @@ domi::Status GenerateOmModel() { options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); + options.insert(std::pair<string, string>(string(ge::OP_PRECISION_MODE), FLAGS_op_precision_mode)); options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode)); options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id)); diff --git a/ge/offline/proto/ge_ir.proto b/ge/offline/proto/ge_ir.proto deleted file mode 100644 index c0ef3071..00000000 --- a/ge/offline/proto/ge_ir.proto +++ /dev/null @@ -1,193 +0,0 @@ -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. 
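Editor's note: the same `--op_precision_mode` guard is added twice above (once in `CheckFlags`, once in `GenerateSingleOp`). As a reading aid, here is a minimal self-contained sketch of that pattern; `PathExists` is a hypothetical stand-in for `ge::CheckInputPathValid`, and plain stderr output stands in for `ErrorManager`/`GELOGE`:
```
#include <cstdio>
#include <string>
#include <sys/stat.h>

// Hypothetical stand-in for ge::CheckInputPathValid: the real helper also
// validates permissions and soft links; here we only test existence.
static bool PathExists(const std::string &path) {
  struct stat st{};
  return stat(path.c_str(), &st) == 0;
}

// Mirrors the guard added above: an empty flag is allowed (feature off),
// while a non-empty flag must name an existing file or the tool fails early.
static bool ValidateOpPrecisionMode(const std::string &flag_value) {
  if (flag_value.empty()) {
    return true;  // flag not set; nothing to validate
  }
  if (!PathExists(flag_value)) {
    // The real code routes this through ATCReportErrMessage("E10001", ...)
    std::fprintf(stderr, "E10001: --op_precision_mode %s: path is not found\n",
                 flag_value.c_str());
    return false;
  }
  return true;
}
```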
- DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ - DT_VARIANT = 26; // variant type - DT_BF16 = 27; // bf16 type - DT_INT4 = 28; // int4 type -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. 
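Editor's note: the AttrDef message above is a tagged union — exactly one member of the oneof is set at a time, and ListValue carries parallel repeated fields plus a val_type tag. A hedged sketch of how the protoc-generated C++ API for this (now deleted) file would be used; the header name ge_ir.pb.h and the function FillAttrs are assumptions for illustration:
```
#include "ge_ir.pb.h"  // assumed protoc output for the ge_ir.proto shown here

void FillAttrs(ge::proto::OpDef &op) {  // OpDef is defined just below
  // Scalar attribute: setting one oneof member clears any other.
  ge::proto::AttrDef scalar;
  scalar.set_i(64);  // selects the int64 branch of the oneof

  // List attribute: fill a repeated field and record which one via val_type.
  ge::proto::AttrDef list_attr;
  auto *list = list_attr.mutable_list();
  list->add_i(1);
  list->add_i(2);
  list->set_val_type(ge::proto::AttrDef::ListValue::VT_LIST_INT);

  // The map<string, AttrDef> on OpDef keys attributes by name.
  (*op.mutable_attr())["N"] = scalar;
  (*op.mutable_attr())["axes"] = list_attr;
}
```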
-message NamedAttrs -{ - string name = 1; - map<string, AttrDef> attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor = 15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map<string, AttrDef> attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. op_name:index - - map<string, AttrDef> attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id = 22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map<string, AttrDef> attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto version - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition; graph[0] represents the main graph in modeldef - - map<string, AttrDef> attr = 11; // Extended field -} - diff --git a/ge/offline/proto/insert_op.proto b/ge/offline/proto/insert_op.proto deleted file mode 100644 index 7d708865..00000000 --- a/ge/offline/proto/insert_op.proto +++ /dev/null @@ -1,140 +0,0 @@ -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; } - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPP mode flag: distinguishes static AIPP from dynamic AIPP - AippMode aipp_mode = 1; - - // related_input_rank is required and of integer type; valid range: >= 0, <= the number of input Data operators; default value is 0. - // It identifies which input of the model AIPP is applied to; e.g. to apply AIPP to the second input, set related_input_rank to 1. - uint32 related_input_rank = 2; - - // related_input_name is optional and the top name of data node which inserts aipp - string related_input_name = 6; - - // input_edge_idx is optional and of integer type; valid range: >= 0. - // It allows different AIPP processing for different outputs of the Data operator; if it is not set, AIPP is applied by default to all outputs of the model input specified by related_input_rank.
- // The configured value must be <= the number of output edges of the Data operator. - repeated uint32 input_edge_idx = 3; - - // [Begin] Dynamic AIPP parameters; ignored when static AIPP is configured - uint32 max_src_image_size = 4; - - // Whether rotation is supported. Not supported by default; enabling rotation incurs extra space and performance cost. - bool support_rotation = 5; - - // [End] Dynamic AIPP parameters - - - // [Begin] Static AIPP parameters; ignored when dynamic AIPP is configured - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - float padding_value = 72; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] Static AIPP parameters - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; // dynamic batch - resolution = 1; // dynamic resolution (reserved for extension) - } - - MultiShapeMode mode = 1; // operator mode - uint32 related_input_rank = 2; // which model input the new operator is inserted at - - - repeated uint32 batch_list = 11; // batch_list values; the number of batch_list entries must be between 2 and 8 -} diff --git a/ge/offline/proto/om.proto b/ge/offline/proto/om.proto deleted file mode 100644 index e15e5f80..00000000 --- a/ge/offline/proto/om.proto +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - 
bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams { - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. 
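Editor's note: in the WeightDef/ShapeDef pair above, data is a flat row-major byte buffer whose length must agree with the product of the dims times the element size. A small consistency-check sketch, assuming protoc-generated classes from this om.proto (the header name om.pb.h is an assumption):
```
#include <cstdint>
#include "om.pb.h"  // assumed protoc output for the om.proto shown here

// Element count implied by a ShapeDef: product of all dims, row-major.
int64_t ElementCount(const domi::ShapeDef &shape) {
  int64_t count = 1;
  for (int i = 0; i < shape.dim_size(); ++i) {
    count *= shape.dim(i);
  }
  return count;
}

// A weight blob is self-consistent when its raw byte buffer matches the
// shape times the per-element size of its data_type (e.g. 4 for float).
bool WeightSizeMatches(const domi::WeightDef &w, int64_t elem_bytes) {
  return static_cast<int64_t>(w.data().size()) ==
         ElementCount(w.shape()) * elem_bytes;
}
```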
- CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/ge/offline/proto/task.proto b/ge/offline/proto/task.proto deleted file mode 100644 index 0da5631e..00000000 --- a/ge/offline/proto/task.proto +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; - KernelDefWithHandle kernel_with_handle = 40; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelDefWithHandle { - KernelContext context = 1; - - uint64 handle = 10; - string dev_func = 11; - uint32 block_dim = 12; - uint32 args_size = 13; - bytes args = 14; - bytes sm_desc = 15; - string original_kernel_key = 16; - string node_info = 17; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - 
uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index dac2e15c..aeb73116 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "single_op_parser.h" +#include "offline/single_op_parser.h" #include #include @@ -24,7 +24,7 @@ #include "framework/common/debug/ge_log.h" #include "common/util/error_manager/error_manager.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "framework/common/util.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" @@ -89,7 +89,8 @@ map<string, DataType> kDataTypeDict = { {"float", DT_FLOAT}, {"float32", DT_FLOAT}, {"double", DT_DOUBLE}, - {"complex64", DT_COMPLEX64} + {"complex64", DT_COMPLEX64}, + {"complex128", DT_COMPLEX128} }; map<string, Format> kFormatDict = { diff --git a/ge/offline/single_op_parser.h b/ge/offline/single_op_parser.h index 11f5512e..25699552 100644 --- a/ge/offline/single_op_parser.h +++ b/ge/offline/single_op_parser.h @@ -21,8 +21,8 @@ #include -#include "ge/ge_api_error_codes.h" -#include "graph/types.h" +#include "external/ge/ge_api_error_codes.h" +#include "external/graph/types.h" #include "graph/ge_attr_value.h" #include "graph/op_desc.h" diff --git a/ge/opskernel_manager/ops_kernel_builder_manager.cc b/ge/opskernel_manager/ops_kernel_builder_manager.cc index 04262e1b..9f981302 100644 --- a/ge/opskernel_manager/ops_kernel_builder_manager.cc +++ b/ge/opskernel_manager/ops_kernel_builder_manager.cc @@ -15,11 +15,12 @@ */ #include "init/gelib.h" -#include "ops_kernel_builder_manager.h" +#include "opskernel_manager/ops_kernel_builder_manager.h" #include "register/ops_kernel_builder_registry.h" namespace ge { namespace { +#ifdef ONLY_COMPILE_OPEN_SRC const std::vector<std::string> kBasicBuilderLibs = { "libge_local_opskernel_builder.so", "libhost_cpu_opskernel_builder.so", @@ -27,6 +28,15 @@ const std::vector<std::string> kBasicBuilderLibs = { "libaicpu_ascend_builder.so", "libaicpu_tf_builder.so" }; +#else +const std::vector<std::string> kBasicBuilderLibs = { + "libge_local_opskernel_builder.so", + "libhost_cpu_opskernel_builder.so", + "librts_engine.so", + "libaicpu_ascend_engine.so", + "libaicpu_tf_engine.so" +}; +#endif const std::vector<std::string> kHcclBuilderLibs = { "libhcom_opskernel_builder.so", diff --git a/ge/opskernel_manager/ops_kernel_manager.cc b/ge/opskernel_manager/ops_kernel_manager.cc index ac5e9153..60958872 100644 --- a/ge/opskernel_manager/ops_kernel_manager.cc +++ b/ge/opskernel_manager/ops_kernel_manager.cc @@ -16,17 +16,7 @@ #include "opskernel_manager/ops_kernel_manager.h" -#include -#include -#include -#include - -#include -#include -#include -#include "../init/gelib.h" -#include "framework/common/debug/ge_log.h" -#include "ge/ge_api.h" +#include "init/gelib.h" #include "proto/optimizer_priority.pb.h" namespace { @@ -279,7 +269,7 @@ void OpsKernelManager::InitOpsKernelInfo() { if (it.second.empty()) { continue; } - auto comp_func = [this, &instance_ptr](const OpInfo &op_a, const OpInfo &op_b) -> bool { + auto comp_func = [&instance_ptr](const OpInfo &op_a, const OpInfo &op_b) -> bool { const string &a = op_a.engine; const string &b = op_b.engine; // check if a or b is registered diff --git a/ge/opskernel_manager/ops_kernel_manager.h b/ge/opskernel_manager/ops_kernel_manager.h index 19d703e3..5a72dc50 100644 --- a/ge/opskernel_manager/ops_kernel_manager.h +++ 
b/ge/opskernel_manager/ops_kernel_manager.h @@ -23,15 +23,15 @@ #include #include -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/ge/plugin_manager.h" #include "common/ge/op_tiling_manager.h" -#include "common/ge_inner_error_codes.h" +#include "framework/common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_store.h" #include "common/optimizer/graph_optimizer.h" #include "graph/optimize/graph_optimize.h" #include "framework/common/ge_inner_error_codes.h" -#include "ge/ge_api_types.h" +#include "external/ge/ge_api_types.h" #include "runtime/base.h" using std::string; diff --git a/ge/plugin/engine/CMakeLists.txt b/ge/plugin/engine/CMakeLists.txt index b4ea9c52..b8628ad1 100644 --- a/ge/plugin/engine/CMakeLists.txt +++ b/ge/plugin/engine/CMakeLists.txt @@ -24,9 +24,8 @@ target_compile_definitions(engine PRIVATE target_include_directories(engine PRIVATE ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc/ + ${GE_CODE_DIR}/inc ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common ${GE_CODE_DIR}/inc/external ${METADEF_DIR}/inc ${METADEF_DIR}/inc/external diff --git a/ge/plugin/engine/dnnengines.h b/ge/plugin/engine/dnnengines.h index 0633c104..829c83f1 100644 --- a/ge/plugin/engine/dnnengines.h +++ b/ge/plugin/engine/dnnengines.h @@ -21,7 +21,7 @@ #include #include -#include "engine/dnnengine.h" +#include "framework/engine/dnnengine.h" #include "plugin/engine/engine_manage.h" namespace ge { diff --git a/ge/plugin/engine/engine_manage.h b/ge/plugin/engine/engine_manage.h index 7eb88805..a047e5de 100644 --- a/ge/plugin/engine/engine_manage.h +++ b/ge/plugin/engine/engine_manage.h @@ -36,7 +36,7 @@ #include #include -#include "engine/dnnengine.h" +#include "framework/engine/dnnengine.h" namespace ge { using DNNEnginePtr = std::shared_ptr<DNNEngine>; diff --git a/ge/proto/caffe/caffe.proto b/ge/proto/caffe/caffe.proto deleted file mode 100644 index 20615fed..00000000 --- a/ge/proto/caffe/caffe.proto +++ /dev/null @@ -1,1829 +0,0 @@ -/** - * This file is part of Open Source Software caffe, version 1.0 https://github.com/BVLC/caffe - * - * This file is included by GraphEngine so as to support model format conversion from caffe model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto2"; - -package domi.caffe; - -// Specifies the shape (dimensions) of a Blob. -message BlobShape { - repeated int64 dim = 1 [packed = true]; -} - -message BlobProto { - optional BlobShape shape = 7; - repeated float data = 5 [packed = true]; - repeated float diff = 6 [packed = true]; - repeated double double_data = 8 [packed = true]; - repeated double double_diff = 9 [packed = true]; - optional bytes int8_data = 10; - repeated int32 int32_data = 11 [packed = true]; - repeated uint64 uint64_data = 12 [packed = true]; - // 4D dimensions -- deprecated. Use "shape" instead. - optional int32 num = 1 [default = 0]; - optional int32 channels = 2 [default = 0]; - optional int32 height = 3 [default = 0]; - optional int32 width = 4 [default = 0]; -} - -// The BlobProtoVector is simply a way to pass multiple blobproto instances -around. 
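Editor's note: the BlobProto message above carries both the newer shape field and the deprecated 4D num/channels/height/width fields, and readers are expected to prefer the former. A hedged sketch of that fallback, assuming protoc-generated classes from this deleted caffe.proto (caffe.pb.h is an assumed header name):
```
#include <cstdint>
#include "caffe.pb.h"  // assumed protoc output for the caffe.proto shown here

// Prefer the BlobShape field; fall back to the deprecated 4D fields
// (num, channels, height, width) only when no shape is present.
int64_t BlobElementCount(const domi::caffe::BlobProto &blob) {
  if (blob.has_shape()) {
    int64_t count = 1;
    for (int i = 0; i < blob.shape().dim_size(); ++i) {
      count *= blob.shape().dim(i);
    }
    return count;
  }
  return static_cast<int64_t>(blob.num()) * blob.channels() *
         blob.height() * blob.width();
}
```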
-message BlobProtoVector { - repeated BlobProto blobs = 1; -} - -message Datum { - optional int32 channels = 1; - optional int32 height = 2; - optional int32 width = 3; - // the actual image data, in bytes - optional bytes data = 4; - optional int32 label = 5; - // Optionally, the datum could also hold float data. - repeated float float_data = 6; - // If true data contains an encoded image that need to be decoded - optional bool encoded = 7 [default = false]; -} - -message FillerParameter { - // The filler type. - optional string type = 1 [default = 'constant']; - optional float value = 2 [default = 0]; // the value in constant filler - optional float min = 3 [default = 0]; // the min value in uniform filler - optional float max = 4 [default = 1]; // the max value in uniform filler - optional float mean = 5 [default = 0]; // the mean value in Gaussian filler - optional float std = 6 [default = 1]; // the std value in Gaussian filler - // The expected number of non-zero output weights for a given input in - // Gaussian filler -- the default -1 means don't perform sparsification. - optional int32 sparse = 7 [default = -1]; - // Normalize the filler variance by fan_in, fan_out, or their average. - // Applies to 'xavier' and 'msra' fillers. - enum VarianceNorm { - FAN_IN = 0; - FAN_OUT = 1; - AVERAGE = 2; - } - optional VarianceNorm variance_norm = 8 [default = FAN_IN]; -} - -message NetParameter { - optional string name = 1; // consider giving the network a name - // DEPRECATED. See InputParameter. The input blobs to the network. - repeated string input = 3; - // DEPRECATED. See InputParameter. The shape of the input blobs. - repeated BlobShape input_shape = 8; - - // 4D input dimensions -- deprecated. Use "input_shape" instead. - // If specified, for each input blob there should be four - // values specifying the num, channels, height and width of the input blob. - // Thus, there should be a total of (4 * #input) numbers. - repeated int32 input_dim = 4; - - // Whether the network will force every layer to carry out backward operation. - // If set False, then whether to carry out backward is determined - // automatically according to the net structure and learning rates. - optional bool force_backward = 5 [default = false]; - // The current "state" of the network, including the phase, level, and stage. - // Some layers may be included/excluded depending on this state and the states - // specified in the layers' include and exclude fields. - optional NetState state = 6; - - // Print debugging information about results while running Net::Forward, - // Net::Backward, and Net::Update. - optional bool debug_info = 7 [default = false]; - - // The layers that make up the net. Each of their configurations, including - // connectivity and behavior, is specified as a LayerParameter. - repeated LayerParameter layer = 100; // ID 100 so layers are printed last. - - // DEPRECATED: use 'layer' instead. - repeated V1LayerParameter layers = 2; -} - -// NOTE -// Update the next available ID when you add a new SolverParameter field. 
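Editor's note: one detail worth making concrete from NetParameter above is that the deprecated input_dim field flattens four values per input blob, so consumers must regroup it. A small illustrative sketch (plain C++; the helper name GroupLegacyInputDims is hypothetical):
```
#include <array>
#include <vector>

// Regroup the deprecated flat input_dim list (4 values per input blob:
// num, channels, height, width) into one 4-tuple per input.
std::vector<std::array<int, 4>> GroupLegacyInputDims(
    const std::vector<int> &input_dim) {
  std::vector<std::array<int, 4>> shapes;
  for (size_t i = 0; i + 3 < input_dim.size(); i += 4) {
    shapes.push_back({input_dim[i], input_dim[i + 1],
                      input_dim[i + 2], input_dim[i + 3]});
  }
  return shapes;
}
```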
-// -// SolverParameter next available ID: 42 (last added: layer_wise_reduce) -message SolverParameter { - ////////////////////////////////////////////////////////////////////////////// - // Specifying the train and test networks - // - // Exactly one train net must be specified using one of the following fields: - // train_net_param, train_net, net_param, net - // One or more test nets may be specified using any of the following fields: - // test_net_param, test_net, net_param, net - // If more than one test net field is specified (e.g., both net and - // test_net are specified), they will be evaluated in the field order given - // above: (1) test_net_param, (2) test_net, (3) net_param/net. - // A test_iter must be specified for each test_net. - // A test_level and/or a test_stage may also be specified for each test_net. - ////////////////////////////////////////////////////////////////////////////// - - // Proto filename for the train net, possibly combined with one or more - // test nets. - optional string net = 24; - // Inline train net param, possibly combined with one or more test nets. - optional NetParameter net_param = 25; - - optional string train_net = 1; // Proto filename for the train net. - repeated string test_net = 2; // Proto filenames for the test nets. - optional NetParameter train_net_param = 21; // Inline train net params. - repeated NetParameter test_net_param = 22; // Inline test net params. - - // The states for the train/test nets. Must be unspecified or - // specified once per net. - // - // By default, all states will have solver = true; - // train_state will have phase = TRAIN, - // and all test_state's will have phase = TEST. - // Other defaults are set according to the NetState defaults. - optional NetState train_state = 26; - repeated NetState test_state = 27; - - // The number of iterations for each test net. - repeated int32 test_iter = 3; - - // The number of iterations between two testing phases. - optional int32 test_interval = 4 [default = 0]; - optional bool test_compute_loss = 19 [default = false]; - // If true, run an initial test pass before the first iteration, - // ensuring memory availability and printing the starting value of the loss. - optional bool test_initialization = 32 [default = true]; - optional float base_lr = 5; // The base learning rate - // the number of iterations between displaying info. If display = 0, no info - // will be displayed. - optional int32 display = 6; - // Display the loss averaged over the last average_loss iterations - optional int32 average_loss = 33 [default = 1]; - optional int32 max_iter = 7; // the maximum number of iterations - // accumulate gradients over `iter_size` x `batch_size` instances - optional int32 iter_size = 36 [default = 1]; - - // The learning rate decay policy. The currently implemented learning rate - // policies are as follows: - // - fixed: always return base_lr. - // - step: return base_lr * gamma ^ (floor(iter / step)) - // - exp: return base_lr * gamma ^ iter - // - inv: return base_lr * (1 + gamma * iter) ^ (- power) - // - multistep: similar to step but it allows non uniform steps defined by - // stepvalue - // - poly: the effective learning rate follows a polynomial decay, to be - // zero by the max_iter. 
return base_lr (1 - iter/max_iter) ^ (power) - // - sigmoid: the effective learning rate follows a sigmod decay - // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) - // - // where base_lr, max_iter, gamma, step, stepvalue and power are defined - // in the solver parameter protocol buffer, and iter is the current iteration. - optional string lr_policy = 8; - optional float gamma = 9; // The parameter to compute the learning rate. - optional float power = 10; // The parameter to compute the learning rate. - optional float momentum = 11; // The momentum value. - optional float weight_decay = 12; // The weight decay. - // regularization types supported: L1 and L2 - // controlled by weight_decay - optional string regularization_type = 29 [default = "L2"]; - // the stepsize for learning rate policy "step" - optional int32 stepsize = 13; - // the stepsize for learning rate policy "multistep" - repeated int32 stepvalue = 34; - - // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, - // whenever their actual L2 norm is larger. - optional float clip_gradients = 35 [default = -1]; - - optional int32 snapshot = 14 [default = 0]; // The snapshot interval - optional string snapshot_prefix = 15; // The prefix for the snapshot. - // whether to snapshot diff in the results or not. Snapshotting diff will help - // debugging but the final protocol buffer size will be much larger. - optional bool snapshot_diff = 16 [default = false]; - enum SnapshotFormat { - HDF5 = 0; - BINARYPROTO = 1; - } - optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; - // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. - enum SolverMode { - CPU = 0; - GPU = 1; - } - optional SolverMode solver_mode = 17 [default = GPU]; - // the device_id will that be used in GPU mode. Use device_id = 0 in default. - optional int32 device_id = 18 [default = 0]; - // If non-negative, the seed with which the Solver will initialize the Caffe - // random number generator -- useful for reproducible results. Otherwise, - // (and by default) initialize using a seed derived from the system clock. - optional int64 random_seed = 20 [default = -1]; - - // type of the solver - optional string type = 40 [default = "SGD"]; - - // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam - optional float delta = 31 [default = 1e-8]; - // parameters for the Adam solver - optional float momentum2 = 39 [default = 0.999]; - - // RMSProp decay value - // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) - optional float rms_decay = 38 [default = 0.99]; - - // If true, print information about the state of the net that may help with - // debugging learning problems. - optional bool debug_info = 23 [default = false]; - - // If false, don't save a snapshot after training finishes. - optional bool snapshot_after_train = 28 [default = true]; - - // DEPRECATED: old solver enum types, use string instead - enum SolverType { - SGD = 0; - NESTEROV = 1; - ADAGRAD = 2; - RMSPROP = 3; - ADADELTA = 4; - ADAM = 5; - } - // DEPRECATED: use type instead of solver_type - optional SolverType solver_type = 30 [default = SGD]; - - // Overlap compute and communication for data parallel training - optional bool layer_wise_reduce = 41 [default = true]; -} - -// A message that stores the solver snapshots -message SolverState { - optional int32 iter = 1; // The current iteration - optional string learned_net = 2; // The file that stores the learned net. 
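Editor's note: the lr_policy comment block above defines each schedule only informally; as a cross-check, here is a direct transcription into C++ (a sketch; base_lr, gamma, power, stepsize and max_iter are the SolverParameter fields, and multistep is omitted because it needs the stepvalue list):
```
#include <cmath>
#include <string>

// Direct transcription of the lr_policy rules described above.
double LearningRate(const std::string &policy, double base_lr, double gamma,
                    double power, int stepsize, int max_iter, int iter) {
  if (policy == "fixed") return base_lr;
  if (policy == "step")
    return base_lr * std::pow(gamma, std::floor(static_cast<double>(iter) / stepsize));
  if (policy == "exp") return base_lr * std::pow(gamma, iter);
  if (policy == "inv") return base_lr * std::pow(1.0 + gamma * iter, -power);
  if (policy == "poly")
    return base_lr * std::pow(1.0 - static_cast<double>(iter) / max_iter, power);
  if (policy == "sigmoid")
    return base_lr / (1.0 + std::exp(-gamma * (iter - stepsize)));
  return base_lr;  // unknown policy: fall back to the base rate
}
```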
- repeated BlobProto history = 3; // The history for sgd solvers - optional int32 current_step = 4 [default = 0]; // The current step for learning rate -} - -enum Phase { - TRAIN = 0; - TEST = 1; -} - -message NetState { - optional Phase phase = 1 [default = TEST]; - optional int32 level = 2 [default = 0]; - repeated string stage = 3; -} - -message NetStateRule { - // Set phase to require the NetState have a particular phase (TRAIN or TEST) - // to meet this rule. - optional Phase phase = 1; - - // Set the minimum and/or maximum levels in which the layer should be used. - // Leave undefined to meet the rule regardless of level. - optional int32 min_level = 2; - optional int32 max_level = 3; - - // Customizable sets of stages to include or exclude. - // The net must have ALL of the specified stages and NONE of the specified - // "not_stage"s to meet the rule. - // (Use multiple NetStateRules to specify conjunctions of stages.) - repeated string stage = 4; - repeated string not_stage = 5; -} - -// Specifies training parameters (multipliers on global learning constants, -// and the name and other settings used for weight sharing). -message ParamSpec { - // The names of the parameter blobs -- useful for sharing parameters among - // layers, but never required otherwise. To share a parameter between two - // layers, give it a (non-empty) name. - optional string name = 1; - - // Whether to require shared weights to have the same shape, or just the same - // count -- defaults to STRICT if unspecified. - optional DimCheckMode share_mode = 2; - enum DimCheckMode { - // STRICT (default) requires that num, channels, height, width each match. - STRICT = 0; - // PERMISSIVE requires only the count (num*channels*height*width) to match. - PERMISSIVE = 1; - } - - // The multiplier on the global learning rate for this parameter. - optional float lr_mult = 3 [default = 1.0]; - - // The multiplier on the global weight decay for this parameter. - optional float decay_mult = 4 [default = 1.0]; -} - -// NOTE -// Update the next available ID when you add a new LayerParameter field. -// -// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) -message LayerParameter { - optional string name = 1; // the layer name - optional string type = 2; // the layer type - repeated string bottom = 3; // the name of each bottom blob - repeated string top = 4; // the name of each top blob - - // The train / test phase for computation. - optional Phase phase = 10; - - // The amount of weight to assign each top blob in the objective. - // Each layer assigns a default value, usually of either 0 or 1, - // to each top blob. - repeated float loss_weight = 5; - - // Specifies training parameters (multipliers on global learning constants, - // and the name and other settings used for weight sharing). - repeated ParamSpec param = 6; - - // The blobs containing the numeric parameters of the layer. - repeated BlobProto blobs = 7; - - // Specifies whether to backpropagate to each bottom. If unspecified, - // Caffe will automatically infer whether each input needs backpropagation - // to compute parameter gradients. If set to true for some inputs, - // backpropagation to those inputs is forced; if set false for some inputs, - // backpropagation to those inputs is skipped. - // - // The size must be either 0 or equal to the number of bottoms. - repeated bool propagate_down = 11; - - // Rules controlling whether and when a layer is included in the network, - // based on the current NetState. 
You may specify a non-zero number of rules - // to include OR exclude, but not both. If no include or exclude rules are - // specified, the layer is always included. If the current NetState meets - // ANY (i.e., one or more) of the specified rules, the layer is - // included/excluded. - repeated NetStateRule include = 8; - repeated NetStateRule exclude = 9; - - // Parameters for data pre-processing. - optional TransformationParameter transform_param = 100; - - // Parameters shared by loss layers. - optional LossParameter loss_param = 101; - - // Layer type-specific parameters. - // - // Note: certain layers may have more than one computational engine - // for their implementation. These layers include an Engine type and - // engine parameter for selecting the implementation. - // The default for the engine is set by the ENGINE switch at compile-time. - optional AccuracyParameter accuracy_param = 102; - optional ArgMaxParameter argmax_param = 103; - optional BatchNormParameter batch_norm_param = 139; - optional BiasParameter bias_param = 141; - optional ConcatParameter concat_param = 104; - optional ContrastiveLossParameter contrastive_loss_param = 105; - optional ConvolutionParameter convolution_param = 106; - optional CropParameter crop_param = 144; - optional DataParameter data_param = 107; - optional DetectionOutputParameter detection_output_param = 150; - optional DropoutParameter dropout_param = 108; - optional DummyDataParameter dummy_data_param = 109; - optional EltwiseParameter eltwise_param = 110; - optional ELUParameter elu_param = 140; - optional EmbedParameter embed_param = 137; - optional ExpParameter exp_param = 111; - optional FlattenParameter flatten_param = 135; - optional HDF5DataParameter hdf5_data_param = 112; - optional HDF5OutputParameter hdf5_output_param = 113; - optional HingeLossParameter hinge_loss_param = 114; - optional ImageDataParameter image_data_param = 115; - optional InfogainLossParameter infogain_loss_param = 116; - optional InnerProductParameter inner_product_param = 117; - optional InputParameter input_param = 143; - optional LogParameter log_param = 134; - optional LRNParameter lrn_param = 118; - optional MemoryDataParameter memory_data_param = 119; - optional MVNParameter mvn_param = 120; - optional ParameterParameter parameter_param = 145; - optional PoolingParameter pooling_param = 121; - optional PowerParameter power_param = 122; - optional PReLUParameter prelu_param = 131; - optional PythonParameter python_param = 130; - optional RecurrentParameter recurrent_param = 146; - optional ReductionParameter reduction_param = 136; - optional ReLUParameter relu_param = 123; - optional ReshapeParameter reshape_param = 133; - optional ScaleParameter scale_param = 142; - optional SigmoidParameter sigmoid_param = 124; - optional SmoothL1LossParameter smooth_l1_loss_param = 148; - optional SoftmaxParameter softmax_param = 125; - optional SPPParameter spp_param = 132; - optional SliceParameter slice_param = 126; - optional TanHParameter tanh_param = 127; - optional ThresholdParameter threshold_param = 128; - optional TileParameter tile_param = 138; - optional WindowDataParameter window_data_param = 129; - optional PermuteParameter permute_param = 202; - optional PriorBoxParameter prior_box_param = 203; - optional NormalizeParameter norm_param = 206; - optional PSROIPoolingParameter psroi_pooling_param = 207; - optional FreespaceExtractParameter freespace_extract_param = 151; - optional PostprocessParameter postprocess_param = 152; - optional 
SpatialTransformParameter spatial_transform_param = 153; - optional ROIAlignParameter roi_align_param = 154; - optional ReorgParameter reorg_param = 155; - optional RegionParameter region_param = 156; - optional ReverseParameter reverse_param = 157; - optional InterpParameter interp_param = 158; - optional ShuffleChannelParameter shuffle_channel_param = 159; - optional UpsampleParameter upsample_param = 160; - optional ROIPoolingParameter roi_pooling_param = 161; - optional YoloParameter yolo_param = 199; - optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; - optional ProposalParameter proposal_param = 201; - optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; - optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; - optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; - optional QuantParameter quant_param = 208; - optional CondTakeParameter condtake_param = 233; - optional MatrixInverseParameter matrix_inverse_param = 210; - optional WarpPerspectiveParameter warp_perspective_param = 234; - optional BatchMatMulParameter batch_matmul_param = 235; - optional SpatialTransformerParameter st_param = 5000; - optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; -} - -// Message that stores parameters used to apply transformation -// to the data layer's data -message TransformationParameter { - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 1 [default = 1]; - // Specify if we want to randomly mirror data. - optional bool mirror = 2 [default = false]; - // Specify if we would like to randomly crop an image. - optional uint32 crop_size = 3 [default = 0]; - // mean_file and mean_value cannot be specified at the same time - optional string mean_file = 4; - // if specified can be repeated once (would substract it from all the channels) - // or can be repeated the same number of times as channels - // (would subtract them from the corresponding channel) - repeated float mean_value = 5; - // Force the decoded image to have 3 color channels. - optional bool force_color = 6 [default = false]; - // Force the decoded image to have 1 color channels. - optional bool force_gray = 7 [default = false]; -} - -// Message that stores parameters shared by loss layers -message LossParameter { - // If specified, ignore instances with the given label. - optional int32 ignore_label = 1; - // How to normalize the loss for loss layers that aggregate across batches, - // spatial dimensions, or other dimensions. Currently only implemented in - // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. - enum NormalizationMode { - // Divide by the number of examples in the batch times spatial dimensions. - // Outputs that receive the ignore label will NOT be ignored in computing - // the normalization factor. - FULL = 0; - // Divide by the total number of output locations that do not take the - // ignore_label. If ignore_label is not set, this behaves like FULL. - VALID = 1; - // Divide by the batch size. - BATCH_SIZE = 2; - // Do not normalize the loss. - NONE = 3; - } - // For historical reasons, the default normalization for - // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. - optional NormalizationMode normalization = 3 [default = VALID]; - // Deprecated. Ignored if normalization is specified. 
If normalization - // is not specified, then setting this to false will be equivalent to - // normalization = BATCH_SIZE to be consistent with previous behavior. - optional bool normalize = 2; -} - -// Messages that store parameters used by individual layer types follow, in -// alphabetical order. - -message AccuracyParameter { - // When computing accuracy, count as correct by comparing the true label to - // the top k scoring classes. By default, only compare to the top scoring - // class (i.e. argmax). - optional uint32 top_k = 1 [default = 1]; - - // The "label" axis of the prediction blob, whose argmax corresponds to the - // predicted label -- may be negative to index from the end (e.g., -1 for the - // last axis). For example, if axis == 1 and the predictions are - // (N x C x H x W), the label blob is expected to contain N*H*W ground truth - // labels with integer values in {0, 1, ..., C-1}. - optional int32 axis = 2 [default = 1]; - - // If specified, ignore instances with the given label. - optional int32 ignore_label = 3; -} - -message ArgMaxParameter { - // If true produce pairs (argmax, maxval) - optional bool out_max_val = 1 [default = false]; - optional uint32 top_k = 2 [default = 1]; - // The axis along which to maximise -- may be negative to index from the - // end (e.g., -1 for the last axis). - // By default ArgMaxLayer maximizes over the flattened trailing dimensions - // for each index of the first / num dimension. - optional int32 axis = 3; -} - -message ConcatParameter { - // The axis along which to concatenate -- may be negative to index from the - // end (e.g., -1 for the last axis). Other axes must have the - // same dimension for all the bottom blobs. - // By default, ConcatLayer concatenates blobs along the "channels" axis (1). - optional int32 axis = 2 [default = 1]; - - // DEPRECATED: alias for "axis" -- does not support negative indexing. - optional uint32 concat_dim = 1 [default = 1]; -} - -message BatchNormParameter { - // If false, normalization is performed over the current mini-batch - // and global statistics are accumulated (but not yet used) by a moving - // average. - // If true, those accumulated mean and variance values are used for the - // normalization. - // By default, it is set to false when the network is in the training - // phase and true when the network is in the testing phase. - optional bool use_global_stats = 1; - // What fraction of the moving average remains each iteration? - // Smaller values make the moving average decay faster, giving more - // weight to the recent values. - // Each iteration updates the moving average @f$S_{t-1}@f$ with the - // current mean @f$ Y_t @f$ by - // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ - // is the moving_average_fraction parameter. - optional float moving_average_fraction = 2 [default = .999]; - // Small value to add to the variance estimate so that we don't divide by - // zero. - optional float eps = 3 [default = 1e-5]; -} - -message BiasParameter { - // The first axis of bottom[0] (the first input Blob) along which to apply - // bottom[1] (the second input Blob). May be negative to index from the end - // (e.g., -1 for the last axis). 
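Editor's note: several parameters in this file (ArgMax, Concat, Bias, Flatten) allow a negative axis counted from the end; the normalization is the same everywhere, so a one-function sketch covers them all:
```
// Map a possibly-negative axis (counted from the end, as in the ArgMax/
// Concat/Bias parameters here) onto a canonical index in [0, num_axes).
int CanonicalAxis(int axis, int num_axes) {
  return axis < 0 ? axis + num_axes : axis;
}
```
For instance, axis == -1 with num_axes == 4 resolves to 3, the last axis of an (N x C x H x W) blob.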
- // - // For example, if bottom[0] is 4D with shape 100x3x40x60, the output - // top[0] will have the same shape, and bottom[1] may have any of the - // following shapes (for the given value of axis): - // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 - // (axis == 1 == -3) 3; 3x40; 3x40x60 - // (axis == 2 == -2) 40; 40x60 - // (axis == 3 == -1) 60 - // Furthermore, bottom[1] may have the empty shape (regardless of the value of - // "axis") -- a scalar bias. - optional int32 axis = 1 [default = 1]; - - // (num_axes is ignored unless just one bottom is given and the bias is - // a learned parameter of the layer. Otherwise, num_axes is determined by the - // number of axes by the second bottom.) - // The number of axes of the input (bottom[0]) covered by the bias - // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. - // Set num_axes := 0, to add a zero-axis Blob: a scalar. - optional int32 num_axes = 2 [default = 1]; - - // (filler is ignored unless just one bottom is given and the bias is - // a learned parameter of the layer.) - // The initialization for the learned bias parameter. - // Default is the zero (0) initialization, resulting in the BiasLayer - // initially performing the identity operation. - optional FillerParameter filler = 3; - optional bool bias_from_blob = 4 [default = true]; -} - -message ContrastiveLossParameter { - // margin for dissimilar pair - optional float margin = 1 [default = 1.0]; - // The first implementation of this cost did not exactly match the cost of - // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. - // legacy_version = false (the default) uses (margin - d)^2 as proposed in the - // Hadsell paper. New models should probably use this version. - // legacy_version = true uses (margin - d^2). This is kept to support / - // reproduce existing models and results - optional bool legacy_version = 2 [default = false]; -} - -message ConvolutionParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - optional bool bias_term = 2 [default = true]; // whether to have bias terms - - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in all spatial dimensions, or once per spatial dimension. - repeated uint32 pad = 3; // The padding size; defaults to 0 - repeated uint32 kernel_size = 4; // The kernel size - repeated uint32 stride = 6; // The stride; defaults to 1 - // Factor used to dilate the kernel, (implicitly) zero-filling the resulting - // holes. (Kernel dilation is sometimes referred to by its use in the - // algorithme à trous from Holschneider et al. 1987.) - repeated uint32 dilation = 18; // The dilation; defaults to 1 - - // For 2D convolution only, the *_h and *_w versions may also be used to - // specify both spatial dimensions. 
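Editor's note: the pad/kernel_size/stride/dilation fields of ConvolutionParameter jointly determine the output extent per spatial axis. The standard relation, which also makes the dilation comment above concrete, in a short sketch:
```
// Output extent of one spatial axis for the convolution described above:
// the dilated kernel spans dilation * (kernel - 1) + 1 input elements.
int ConvOutputDim(int input, int pad, int kernel, int stride, int dilation) {
  int dilated_kernel = dilation * (kernel - 1) + 1;
  return (input + 2 * pad - dilated_kernel) / stride + 1;
}
```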
- optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) - optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) - optional uint32 kernel_h = 11; // The kernel height (2D only) - optional uint32 kernel_w = 12; // The kernel width (2D only) - optional uint32 stride_h = 13; // The stride height (2D only) - optional uint32 stride_w = 14; // The stride width (2D only) - - optional uint32 group = 5 [default = 1]; // The group size for group conv - - optional FillerParameter weight_filler = 7; // The filler for the weight - optional FillerParameter bias_filler = 8; // The filler for the bias - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 15 [default = DEFAULT]; - - // The axis to interpret as "channels" when performing convolution. - // Preceding dimensions are treated as independent inputs; - // succeeding dimensions are treated as "spatial". - // With (N, C, H, W) inputs, and axis == 1 (the default), we perform - // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for - // groups g>1) filters across the spatial axes (H, W) of the input. - // With (N, C, D, H, W) inputs, and axis == 1, we perform - // N independent 3D convolutions, sliding (C/g)-channels - // filters across the spatial axes (D, H, W) of the input. - optional int32 axis = 16 [default = 1]; - - // Whether to force use of the general ND convolution, even if a specific - // implementation for blobs of the appropriate number of spatial dimensions - // is available. (Currently, there is only a 2D-specific convolution - // implementation; for input blobs with num_axes != 2, this option is - // ignored and the ND implementation will be used.) - optional bool force_nd_im2col = 17 [default = false]; -} - -message CropParameter { - // To crop, elements of the first bottom are selected to fit the dimensions - // of the second, reference bottom. The crop is configured by - // - the crop `axis` to pick the dimensions for cropping - // - the crop `offset` to set the shift for all/each dimension - // to align the cropped bottom with the reference bottom. - // All dimensions up to but excluding `axis` are preserved, while - // the dimensions including and trailing `axis` are cropped. - // If only one `offset` is set, then all dimensions are offset by this amount. - // Otherwise, the number of offsets must equal the number of cropped axes to - // shift the crop in each dimension accordingly. - // Note: standard dimensions are N,C,H,W so the default is a spatial crop, - // and `axis` may be negative to index from the end (e.g., -1 for the last - // axis). - optional int32 axis = 1 [default = 2]; - repeated uint32 offset = 2; -} - -message DataParameter { - enum DB { - LEVELDB = 0; - LMDB = 1; - } - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 4; - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - // DEPRECATED. Each solver accesses a different subset of the database. - optional uint32 rand_skip = 7 [default = 0]; - optional DB backend = 8 [default = LEVELDB]; - // DEPRECATED. See TransformationParameter. For data pre-processing, we can do - // simple scaling and subtracting the data mean, if provided. 
Note that the - // mean subtraction is always carried out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // DEPRECATED. See TransformationParameter. Specify if we would like to randomly - // crop an image. - optional uint32 crop_size = 5 [default = 0]; - // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror - // data. - optional bool mirror = 6 [default = false]; - // Force the encoded image to have 3 color channels - optional bool force_encoded_color = 9 [default = false]; - // Prefetch queue (Increase if data feeding bandwidth varies, within the - // limit of device memory for GPU training) - optional uint32 prefetch = 10 [default = 4]; -} - -message DropoutParameter { - optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio - optional bool scale_train = 2 [default = true]; // scale train or test phase -} - -// DummyDataLayer fills any number of arbitrarily shaped blobs with random -// (or constant) data generated by "Fillers" (see "message FillerParameter"). -message DummyDataParameter { - // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N - // shape fields, and 0, 1 or N data_fillers. - // - // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. - // If 1 data_filler is specified, it is applied to all top blobs. If N are - // specified, the ith is applied to the ith top blob. - repeated FillerParameter data_filler = 1; - repeated BlobShape shape = 6; - - // 4D dimensions -- deprecated. Use "shape" instead. - repeated uint32 num = 2; - repeated uint32 channels = 3; - repeated uint32 height = 4; - repeated uint32 width = 5; -} - -message EltwiseParameter { - enum EltwiseOp { - PROD = 0; - SUM = 1; - MAX = 2; - } - optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation - repeated float coeff = 2; // blob-wise coefficient for SUM operation - - // Whether to use an asymptotically slower (for >2 inputs) but stabler method - // of computing the gradient for the PROD operation. (No effect for SUM op.) - optional bool stable_prod_grad = 3 [default = true]; -} - -// Message that stores parameters used by ELULayer -message ELUParameter { - // Described in: - // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate - // Deep Network Learning by Exponential Linear Units (ELUs). arXiv - optional float alpha = 1 [default = 1]; -} - -// Message that stores parameters used by EmbedLayer -message EmbedParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - // The input is given as integers to be interpreted as one-hot - // vector indices with dimension num_input. Hence num_input should be - // 1 greater than the maximum possible input value. - optional uint32 input_dim = 2; - - optional bool bias_term = 3 [default = true]; // Whether to use a bias term - optional FillerParameter weight_filler = 4; // The filler for the weight - optional FillerParameter bias_filler = 5; // The filler for the bias - -} - -// Message that stores parameters used by ExpLayer -message ExpParameter { - // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. - // Or if base is set to the default (-1), base is set to e, - // so y = exp(shift + scale * x). 
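A sketch of the ExpParameter formula above in prototxt (hypothetical names; with the default base = -1, the layer computes y = exp(shift + scale * x)):

  layer {
    name: "exp"  type: "Exp"  bottom: "in"  top: "out"
    exp_param { scale: 2.0 shift: 0.5 }  # y = exp(0.5 + 2 * x)
  }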
- optional float base = 1 [default = -1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -/// Message that stores parameters used by FlattenLayer -message FlattenParameter { - // The first axis to flatten: all preceding axes are retained in the output. - // May be negative to index from the end (e.g., -1 for the last axis). - optional int32 axis = 1 [default = 1]; - - // The last axis to flatten: all following axes are retained in the output. - // May be negative to index from the end (e.g., the default -1 for the last - // axis). - optional int32 end_axis = 2 [default = -1]; -} - -// Message that stores parameters used by HDF5DataLayer -message HDF5DataParameter { - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 2; - - // Specify whether to shuffle the data. - // If shuffle == true, the ordering of the HDF5 files is shuffled, - // and the ordering of data within any given HDF5 file is shuffled, - // but data between different files are not interleaved; all of a file's - // data are output (in a random order) before moving onto another file. - optional bool shuffle = 3 [default = false]; -} - -message HDF5OutputParameter { - optional string file_name = 1; -} - -message HingeLossParameter { - enum Norm { - L1 = 1; - L2 = 2; - } - // Specify the Norm to use L1 or L2 - optional Norm norm = 1 [default = L1]; -} - -message ImageDataParameter { - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 4 [default = 1]; - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - optional uint32 rand_skip = 7 [default = 0]; - // Whether or not ImageLayer should shuffle the list of files at every epoch. - optional bool shuffle = 8 [default = false]; - // It will also resize images if new_height or new_width are not zero. - optional uint32 new_height = 9 [default = 0]; - optional uint32 new_width = 10 [default = 0]; - // Specify if the images are color or gray - optional bool is_color = 11 [default = true]; - // DEPRECATED. See TransformationParameter. For data pre-processing, we can do - // simple scaling and subtracting the data mean, if provided. Note that the - // mean subtraction is always carried out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // DEPRECATED. See TransformationParameter. Specify if we would like to randomly - // crop an image. - optional uint32 crop_size = 5 [default = 0]; - // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror - // data. - optional bool mirror = 6 [default = false]; - optional string root_folder = 12 [default = ""]; -} - -message InfogainLossParameter { - // Specify the infogain matrix source. 
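A minimal prototxt sketch of the FlattenParameter message above (hypothetical names); with the defaults axis: 1 and end_axis: -1, an N x C x H x W input becomes N x (C*H*W):

  layer {
    name: "flatten"  type: "Flatten"  bottom: "conv_out"  top: "flat"
    flatten_param { axis: 1 end_axis: -1 }
  }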
- optional string source = 1; - optional int32 axis = 2 [default = 1]; // axis of prob -} - -message InnerProductParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - optional bool bias_term = 2 [default = true]; // whether to have bias terms - optional FillerParameter weight_filler = 3; // The filler for the weight - optional FillerParameter bias_filler = 4; // The filler for the bias - - // The first axis to be lumped into a single inner product computation; - // all preceding axes are retained in the output. - // May be negative to index from the end (e.g., -1 for the last axis). - optional int32 axis = 5 [default = 1]; - // Specify whether to transpose the weight matrix or not. - // If transpose == true, any operations will be performed on the transpose - // of the weight matrix. The weight matrix itself is not going to be transposed - // but rather the transfer flag of operations will be toggled accordingly. - optional bool transpose = 6 [default = false]; -} - -message InputParameter { - // This layer produces N >= 1 top blob(s) to be assigned manually. - // Define N shapes to set a shape for each top. - // Define 1 shape to set the same shape for every top. - // Define no shape to defer to reshaping manually. - repeated BlobShape shape = 1; -} - -// Message that stores parameters used by LogLayer -message LogParameter { - // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. - // Or if base is set to the default (-1), base is set to e, - // so y = ln(shift + scale * x) = log_e(shift + scale * x) - optional float base = 1 [default = -1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -// Message that stores parameters used by LRNLayer -message LRNParameter { - optional uint32 local_size = 1 [default = 5]; - optional float alpha = 2 [default = 1.]; - optional float beta = 3 [default = 0.75]; - enum NormRegion { - ACROSS_CHANNELS = 0; - WITHIN_CHANNEL = 1; - } - optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; - optional float k = 5 [default = 1.]; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 6 [default = DEFAULT]; -} - -message MemoryDataParameter { - optional uint32 batch_size = 1; - optional uint32 channels = 2; - optional uint32 height = 3; - optional uint32 width = 4; -} - -message MVNParameter { - // This parameter can be set to false to normalize mean only - optional bool normalize_variance = 1 [default = true]; - - // This parameter can be set to true to perform DNN-like MVN - optional bool across_channels = 2 [default = false]; - - // Epsilon for not dividing by zero while normalizing variance - optional float eps = 3 [default = 1e-9]; -} - -message ParameterParameter { - optional BlobShape shape = 1; -} - -message PoolingParameter { - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional PoolMethod pool = 1 [default = MAX]; // The pooling method - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in height and width or as Y, X pairs. 
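Illustrating the InnerProductParameter message above with a prototxt sketch (hypothetical names; axis: 1 lumps all axes from 1 onward into one inner product):

  layer {
    name: "fc"  type: "InnerProduct"  bottom: "flat"  top: "fc_out"
    inner_product_param {
      num_output: 10
      axis: 1
      weight_filler { type: "xavier" }
      bias_filler { type: "constant" value: 0 }
    }
  }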
- optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) - optional uint32 pad_h = 9 [default = 0]; // The padding height - optional uint32 pad_w = 10 [default = 0]; // The padding width - optional uint32 kernel_size = 2; // The kernel size (square) - optional uint32 kernel_h = 5; // The kernel height - optional uint32 kernel_w = 6; // The kernel width - optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) - optional uint32 stride_h = 7; // The stride height - optional uint32 stride_w = 8; // The stride width - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 11 [default = DEFAULT]; - // If global_pooling then it will pool over the size of the bottom by doing - // kernel_h = bottom->height and kernel_w = bottom->width - optional bool global_pooling = 12 [default = false]; - optional bool ceil_mode = 13 [default = true]; - // How to calculate the output size - using ceil (default) or floor rounding. - enum RoundMode { - CEIL = 0; - FLOOR = 1; - } - optional RoundMode round_mode = 14 [default = CEIL]; -} - -message PowerParameter { - // PowerLayer computes outputs y = (shift + scale * x) ^ power. - optional float power = 1 [default = 1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -message PythonParameter { - optional string module = 1; - optional string layer = 2; - // This value is set to the attribute `param_str` of the `PythonLayer` object - // in Python before calling the `setup()` method. This could be a number, - // string, dictionary in Python dict format, JSON, etc. You may parse this - // string in `setup` method and use it in `forward` and `backward`. - optional string param_str = 3 [default = '']; - // Whether this PythonLayer is shared among worker solvers during data parallelism. - // If true, each worker solver sequentially run forward from this layer. - // This value should be set true if you are using it as a data layer. - optional bool share_in_parallel = 4 [default = false]; -} - -// Message that stores parameters used by RecurrentLayer -message RecurrentParameter { - // The dimension of the output (and usually hidden state) representation -- - // must be explicitly set to non-zero. - optional uint32 num_output = 1 [default = 0]; - - optional FillerParameter weight_filler = 2; // The filler for the weight - optional FillerParameter bias_filler = 3; // The filler for the bias - - // Whether to enable displaying debug_info in the unrolled recurrent net. - optional bool debug_info = 4 [default = false]; - - // Whether to add as additional inputs (bottoms) the initial hidden state - // blobs, and add as additional outputs (tops) the final timestep hidden state - // blobs. The number of additional bottom/top blobs required depends on the - // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. - optional bool expose_hidden = 5 [default = false]; -} - -// Message that stores parameters used by ReductionLayer -message ReductionParameter { - enum ReductionOp { - SUM = 1; - ASUM = 2; - SUMSQ = 3; - MEAN = 4; - } - - optional ReductionOp operation = 1 [default = SUM]; // reduction operation - - // The first axis to reduce to a scalar -- may be negative to index from the - // end (e.g., -1 for the last axis). - // (Currently, only reduction along ALL "tail" axes is supported; reduction - // of axis M through N, where N < num_axes - 1, is unsupported.) 
- // Suppose we have an n-axis bottom Blob with shape: - // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). - // If axis == m, the output Blob will have shape - // (d0, d1, d2, ..., d(m-1)), - // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) - // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. - // If axis == 0 (the default), the output Blob always has the empty shape - // (count 1), performing reduction across the entire input -- - // often useful for creating new loss functions. - optional int32 axis = 2 [default = 0]; - - optional float coeff = 3 [default = 1.0]; // coefficient for output -} - -// Message that stores parameters used by ReLULayer -message ReLUParameter { - // Allow non-zero slope for negative inputs to speed up optimization - // Described in: - // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities - // improve neural network acoustic models. In ICML Workshop on Deep Learning - // for Audio, Speech, and Language Processing. - optional float negative_slope = 1 [default = 0]; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 2 [default = DEFAULT]; -} - -message ReshapeParameter { - // Specify the output dimensions. If some of the dimensions are set to 0, - // the corresponding dimension from the bottom layer is used (unchanged). - // Exactly one dimension may be set to -1, in which case its value is - // inferred from the count of the bottom blob and the remaining dimensions. - // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: - // - // layer { - // type: "Reshape" bottom: "input" top: "output" - // reshape_param { ... } - // } - // - // If "input" is 2D with shape 2 x 8, then the following reshape_param - // specifications are all equivalent, producing a 3D blob "output" with shape - // 2 x 2 x 4: - // - // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } - // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } - // - optional BlobShape shape = 1; - - // axis and num_axes control the portion of the bottom blob's shape that are - // replaced by (included in) the reshape. By default (axis == 0 and - // num_axes == -1), the entire bottom blob shape is included in the reshape, - // and hence the shape field must specify the entire output shape. - // - // axis may be non-zero to retain some portion of the beginning of the input - // shape (and may be negative to index from the end; e.g., -1 to begin the - // reshape after the last axis, including nothing in the reshape, - // -2 to include only the last axis, etc.). - // - // For example, suppose "input" is a 2D blob with shape 2 x 8. - // Then the following ReshapeLayer specifications are all equivalent, - // producing a blob "output" with shape 2 x 2 x 4: - // - // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } - // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } - // - // num_axes specifies the extent of the reshape. - // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on - // input axes in the range [axis, axis+num_axes]. - // num_axes may also be -1, the default, to include all remaining axes - // (starting from axis). - // - // For example, suppose "input" is a 2D blob with shape 2 x 8. 
- // Then the following ReshapeLayer specifications are equivalent,
- // producing a blob "output" with shape 1 x 2 x 8.
- //
- // reshape_param { shape { dim: 1 dim: 2 dim: 8 } }
- // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 }
- // reshape_param { shape { dim: 1 } num_axes: 0 }
- //
- // On the other hand, these would produce output blob shape 2 x 1 x 8:
- //
- // reshape_param { shape { dim: 2 dim: 1 dim: 8 } }
- // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 }
- //
- optional int32 axis = 2 [default = 0];
- optional int32 num_axes = 3 [default = -1];
-}
-
-
-message ScaleParameter {
- // The first axis of bottom[0] (the first input Blob) along which to apply
- // bottom[1] (the second input Blob). May be negative to index from the end
- // (e.g., -1 for the last axis).
- //
- // For example, if bottom[0] is 4D with shape 100x3x40x60, the output
- // top[0] will have the same shape, and bottom[1] may have any of the
- // following shapes (for the given value of axis):
- // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60
- // (axis == 1 == -3) 3; 3x40; 3x40x60
- // (axis == 2 == -2) 40; 40x60
- // (axis == 3 == -1) 60
- // Furthermore, bottom[1] may have the empty shape (regardless of the value of
- // "axis") -- a scalar multiplier.
- optional int32 axis = 1 [default = 1];
-
- // (num_axes is ignored unless just one bottom is given and the scale is
- // a learned parameter of the layer. Otherwise, num_axes is determined by the
- // number of axes of the second bottom.)
- // The number of axes of the input (bottom[0]) covered by the scale
- // parameter, or -1 to cover all axes of bottom[0] starting from `axis`.
- // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar.
- optional int32 num_axes = 2 [default = 1];
-
- // (filler is ignored unless just one bottom is given and the scale is
- // a learned parameter of the layer.)
- // The initialization for the learned scale parameter.
- // Default is the unit (1) initialization, resulting in the ScaleLayer
- // initially performing the identity operation.
- optional FillerParameter filler = 3;
-
- // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but
- // may be more efficient). Initialized with bias_filler (defaults to 0).
- optional bool bias_term = 4 [default = false];
- optional FillerParameter bias_filler = 5;
- optional bool scale_from_blob = 6 [default = true];
-}
-
-message SigmoidParameter {
- enum Engine {
- DEFAULT = 0;
- CAFFE = 1;
- CUDNN = 2;
- }
- optional Engine engine = 1 [default = DEFAULT];
-}
-
-message SliceParameter {
- // The axis along which to slice -- may be negative to index from the end
- // (e.g., -1 for the last axis).
- // By default, SliceLayer slices blobs along the "channels" axis (1).
- optional int32 axis = 3 [default = 1];
- repeated uint32 slice_point = 2;
-
- // DEPRECATED: alias for "axis" -- does not support negative indexing.
- optional uint32 slice_dim = 1 [default = 1];
-}
-
-message SmoothL1LossParameter {
- // SmoothL1Loss(x) =
- // 0.5 * (sigma * x) ** 2 -- if |x| < 1.0 / sigma / sigma
- // |x| - 0.5 / sigma / sigma -- otherwise
- optional float sigma = 1 [default = 1];
-}
-
-// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer
-message SoftmaxParameter {
- enum Engine {
- DEFAULT = 0;
- CAFFE = 1;
- CUDNN = 2;
- }
- optional Engine engine = 1 [default = DEFAULT];
-
- // The axis along which to perform the softmax -- may be negative to index
- // from the end (e.g., -1 for the last axis).
- // Any other axes will be evaluated as independent softmaxes. - optional int32 axis = 2 [default = 1]; -} - -message TanHParameter { - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 1 [default = DEFAULT]; -} - -// Message that stores parameters used by TileLayer -message TileParameter { - // The index of the axis to tile. - optional int32 axis = 1 [default = 1]; - - // The number of copies (tiles) of the blob to output. - optional int32 tiles = 2; -} - -// Message that stores parameters used by ThresholdLayer -message ThresholdParameter { - optional float threshold = 1 [default = 0]; // Strictly positive values -} - -message WindowDataParameter { - // Specify the data source. - optional string source = 1; - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // Specify the batch size. - optional uint32 batch_size = 4; - // Specify if we would like to randomly crop an image. - optional uint32 crop_size = 5 [default = 0]; - // Specify if we want to randomly mirror data. - optional bool mirror = 6 [default = false]; - // Foreground (object) overlap threshold - optional float fg_threshold = 7 [default = 0.5]; - // Background (non-object) overlap threshold - optional float bg_threshold = 8 [default = 0.5]; - // Fraction of batch that should be foreground objects - optional float fg_fraction = 9 [default = 0.25]; - // Amount of contextual padding to add around a window - // (used only by the window_data_layer) - optional uint32 context_pad = 10 [default = 0]; - // Mode for cropping out a detection window - // warp: cropped window is warped to a fixed size and aspect ratio - // square: the tightest square around the window is cropped - optional string crop_mode = 11 [default = "warp"]; - // cache_images: will load all images in memory for faster access - optional bool cache_images = 12 [default = false]; - // append root_folder to locate images - optional string root_folder = 13 [default = ""]; -} - -message SPPParameter { - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional uint32 pyramid_height = 1; - optional PoolMethod pool = 2 [default = MAX]; // The pooling method - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 6 [default = DEFAULT]; -} - -// DEPRECATED: use LayerParameter. 
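Tying together the ReshapeParameter semantics documented earlier, a prototxt sketch (hypothetical names) that reshapes a 2 x 8 blob to 2 x 2 x 4 by copying dim 0 from the input and inferring the last dimension:

  layer {
    name: "reshape"  type: "Reshape"  bottom: "input"  top: "output"
    reshape_param { shape { dim: 0 dim: 2 dim: -1 } }
  }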
-message V1LayerParameter { - repeated string bottom = 2; - repeated string top = 3; - optional string name = 4; - repeated NetStateRule include = 32; - repeated NetStateRule exclude = 33; - enum LayerType { - NONE = 0; - ABSVAL = 35; - ACCURACY = 1; - ARGMAX = 30; - BNLL = 2; - CONCAT = 3; - CONTRASTIVE_LOSS = 37; - CONVOLUTION = 4; - DATA = 5; - DECONVOLUTION = 39; - DROPOUT = 6; - DUMMY_DATA = 32; - EUCLIDEAN_LOSS = 7; - ELTWISE = 25; - EXP = 38; - FLATTEN = 8; - HDF5_DATA = 9; - HDF5_OUTPUT = 10; - HINGE_LOSS = 28; - IM2COL = 11; - IMAGE_DATA = 12; - INFOGAIN_LOSS = 13; - INNER_PRODUCT = 14; - LRN = 15; - MEMORY_DATA = 29; - MULTINOMIAL_LOGISTIC_LOSS = 16; - MVN = 34; - POOLING = 17; - POWER = 26; - RELU = 18; - SIGMOID = 19; - SIGMOID_CROSS_ENTROPY_LOSS = 27; - SILENCE = 36; - SOFTMAX = 20; - SOFTMAX_LOSS = 21; - SPLIT = 22; - SLICE = 33; - TANH = 23; - WINDOW_DATA = 24; - THRESHOLD = 31; - QUANT = 208; - DEQUANT = 209; - } - optional LayerType type = 5; - repeated BlobProto blobs = 6; - repeated string param = 1001; - repeated DimCheckMode blob_share_mode = 1002; - enum DimCheckMode { - STRICT = 0; - PERMISSIVE = 1; - } - repeated float blobs_lr = 7; - repeated float weight_decay = 8; - repeated float loss_weight = 35; - optional AccuracyParameter accuracy_param = 27; - optional ArgMaxParameter argmax_param = 23; - optional ConcatParameter concat_param = 9; - optional ContrastiveLossParameter contrastive_loss_param = 40; - optional ConvolutionParameter convolution_param = 10; - optional DataParameter data_param = 11; - optional DropoutParameter dropout_param = 12; - optional DummyDataParameter dummy_data_param = 26; - optional EltwiseParameter eltwise_param = 24; - optional ExpParameter exp_param = 41; - optional HDF5DataParameter hdf5_data_param = 13; - optional HDF5OutputParameter hdf5_output_param = 14; - optional HingeLossParameter hinge_loss_param = 29; - optional ImageDataParameter image_data_param = 15; - optional InfogainLossParameter infogain_loss_param = 16; - optional InnerProductParameter inner_product_param = 17; - optional LRNParameter lrn_param = 18; - optional MemoryDataParameter memory_data_param = 22; - optional MVNParameter mvn_param = 34; - optional PoolingParameter pooling_param = 19; - optional PowerParameter power_param = 21; - optional ReLUParameter relu_param = 30; - optional SigmoidParameter sigmoid_param = 38; - optional SoftmaxParameter softmax_param = 39; - optional SliceParameter slice_param = 31; - optional TanHParameter tanh_param = 37; - optional ThresholdParameter threshold_param = 25; - optional WindowDataParameter window_data_param = 20; - optional TransformationParameter transform_param = 36; - optional LossParameter loss_param = 42; - optional V0LayerParameter layer = 1; -} - -// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters -// in Caffe. We keep this message type around for legacy support. -message V0LayerParameter { - optional string name = 1; // the layer name - optional string type = 2; // the string to specify the layer type - - // Parameters to specify layers with inner products. 
- optional uint32 num_output = 3; // The number of outputs for the layer - optional bool biasterm = 4 [default = true]; // whether to have bias terms - optional FillerParameter weight_filler = 5; // The filler for the weight - optional FillerParameter bias_filler = 6; // The filler for the bias - - optional uint32 pad = 7 [default = 0]; // The padding size - optional uint32 kernelsize = 8; // The kernel size - optional uint32 group = 9 [default = 1]; // The group size for group conv - optional uint32 stride = 10 [default = 1]; // The stride - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional PoolMethod pool = 11 [default = MAX]; // The pooling method - optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio - - optional uint32 local_size = 13 [default = 5]; // for local response norm - optional float alpha = 14 [default = 1.]; // for local response norm - optional float beta = 15 [default = 0.75]; // for local response norm - optional float k = 22 [default = 1.]; - - // For data layers, specify the data source - optional string source = 16; - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 17 [default = 1]; - optional string meanfile = 18; - // For data layers, specify the batch size. - optional uint32 batchsize = 19; - // For data layers, specify if we would like to randomly crop an image. - optional uint32 cropsize = 20 [default = 0]; - // For data layers, specify if we want to randomly mirror data. - optional bool mirror = 21 [default = false]; - - // The blobs containing the numeric parameters of the layer - repeated BlobProto blobs = 50; - // The ratio that is multiplied on the global learning rate. If you want to - // set the learning ratio for one blob, you need to set it for all blobs. - repeated float blobs_lr = 51; - // The weight decay that is multiplied on the global weight decay. - repeated float weight_decay = 52; - - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - optional uint32 rand_skip = 53 [default = 0]; - - // Fields related to detection (det_*) - // foreground (object) overlap threshold - optional float det_fg_threshold = 54 [default = 0.5]; - // background (non-object) overlap threshold - optional float det_bg_threshold = 55 [default = 0.5]; - // Fraction of batch that should be foreground objects - optional float det_fg_fraction = 56 [default = 0.25]; - - // optional bool OBSOLETE_can_clobber = 57 [default = true]; - - // Amount of contextual padding to add around a window - // (used only by the window_data_layer) - optional uint32 det_context_pad = 58 [default = 0]; - - // Mode for cropping out a detection window - // warp: cropped window is warped to a fixed size and aspect ratio - // square: the tightest square around the window is cropped - optional string det_crop_mode = 59 [default = "warp"]; - - // For ReshapeLayer, one needs to specify the new dimensions. - optional int32 new_num = 60 [default = 0]; - optional int32 new_channels = 61 [default = 0]; - optional int32 new_height = 62 [default = 0]; - optional int32 new_width = 63 [default = 0]; - - // Whether or not ImageLayer should shuffle the list of files at every epoch. 
- // It will also resize images if new_height or new_width are not zero.
- optional bool shuffle_images = 64 [default = false];
-
- // For ConcatLayer, one needs to specify the dimension for concatenation, and
- // the other dimensions must be the same for all the bottom blobs.
- // By default it will concatenate blobs along the channels dimension.
- optional uint32 concat_dim = 65 [default = 1];
-
- optional HDF5OutputParameter hdf5_output_param = 1001;
-}
-
-message PReLUParameter {
- // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers:
- // Surpassing Human-Level Performance on ImageNet Classification, 2015.
-
- // Initial value of a_i. Default is a_i=0.25 for all i.
- optional FillerParameter filler = 1;
- // Whether or not slope parameters are shared across channels.
- optional bool channel_shared = 2 [default = false];
-}
-
-// Message that stores parameters used by DetectionOutputLayer
-//message DetectionOutputParameter {
-// optional int32 num_classes = 1 [default = 21];
-// optional float nms_threshold = 2 [default = 0.3];
-// optional int32 top_k = 3;
-// optional float confidence_threshold = 4 [default = 0.8];
-//}
-
-// Message that stores parameters used by PriorBoxLayer
-message PriorBoxParameter {
- // Encode/decode type.
- enum CodeType {
- CORNER = 1;
- CENTER_SIZE = 2;
- CORNER_SIZE = 3;
- }
- // Minimum box size (in pixels). Required!
- repeated float min_size = 1;
- // Maximum box size (in pixels). Required!
- repeated float max_size = 2;
- // Various aspect ratios. Duplicate ratios will be ignored.
- // If none is provided, we use the default ratio 1.
- repeated float aspect_ratio = 3;
- // If true, will flip each aspect ratio.
- // For example, if there is aspect ratio "r",
- // we will generate aspect ratio "1.0/r" as well.
- optional bool flip = 4 [default = true];
- // If true, will clip the prior so that it is within [0, 1]
- optional bool clip = 5 [default = false];
- // Variance for adjusting the prior bboxes.
- repeated float variance = 6;
- // By default, we calculate img_height, img_width, step_x, step_y based on
- // bottom[0] (feat) and bottom[1] (img), unless these values are explicitly
- // provided.
- // Explicitly provide the img_size.
- optional uint32 img_size = 7;
- // Either img_size or img_h/img_w should be specified; not both.
- optional uint32 img_h = 8;
- optional uint32 img_w = 9;
-
- // Explicitly provide the step size.
- optional float step = 10;
- // Either step or step_h/step_w should be specified; not both.
- optional float step_h = 11;
- optional float step_w = 12;
-
- // Offset to the top left corner of each cell.
- optional float offset = 13 [default = 0.5];
-}
-
-// Message that stores parameters used by PermuteLayer
-message PermuteParameter {
- // The new order of the axes of the data. Note the values should be
- // within the same range as the input axes, and they start from 0.
- // Do not provide repeated order values.
- repeated uint32 order = 1;
-}
-
-message NormalizeParameter {
- optional bool across_spatial = 1 [default = true];
- // Initial value of scale. Default is 1.0 for all
- optional FillerParameter scale_filler = 2;
- // Whether or not scale parameters are shared across channels.
- optional bool channel_shared = 3 [default = true];
- // Epsilon for not dividing by zero while normalizing variance
- optional float eps = 4 [default = 1e-10];
-}
-
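A prototxt sketch of the PermuteParameter message above (hypothetical names), reordering an N x C x H x W blob to N x H x W x C:

  layer {
    name: "permute"  type: "Permute"  bottom: "data"  top: "permuted"
    permute_param { order: 0 order: 2 order: 3 order: 1 }
  }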
-// needed by ssd
-message SaveOutputParameter {
- // Output directory. If not empty, we will save the results.
- optional string output_directory = 1;
- // Output name prefix.
- optional string output_name_prefix = 2;
- // Output format.
- // VOC - PASCAL VOC output format.
- // COCO - MS COCO output format.
- optional string output_format = 3;
- // If you want to output results, you must also provide the following two
- // files. Otherwise, we will ignore saving results.
- // Label map file.
- optional string label_map_file = 4;
- // A file which contains a list of names and sizes in the same order
- // as the input DB. The file is in the following format:
- // name height width
- // ...
- optional string name_size_file = 5;
- // Number of test images. It can be less than the lines specified in
- // name_size_file. For example, when we only want to evaluate on part
- // of the test images.
- optional uint32 num_test_image = 6;
- // The resize parameter used in saving the data.
- // optional ResizeParameter resize_param = 7;
-}
-
-message NonMaximumSuppressionParameter {
- // Threshold to be used in nms.
- optional float nms_threshold = 1 [default = 0.3];
- // Maximum number of results to be kept.
- optional int32 top_k = 2;
- // Parameter for adaptive nms.
- optional float eta = 3 [default = 1.0];
-}
-
-message GeneralNmsParameter {
- optional int32 post_top_k = 1;
- optional float nms_threshold = 2 [default = 0];
- optional float iou_threshold_decay = 3 [default = 1.0];
- optional float coor_scale_factor = 4 [default = 1.0];
-}
-
-// Message that stores parameters used by DetectionOutputLayer, ssd/fasterRcnn
-message DetectionOutputParameter {
- optional int32 num_classes = 1;
- optional bool share_location = 2 [default = true];
- optional int32 background_label_id = 3 [default = 0];
- optional NonMaximumSuppressionParameter nms_param = 4;
- optional SaveOutputParameter save_output_param = 5;
- optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE];
- optional bool variance_encoded_in_target = 8 [default = true];
- optional int32 keep_top_k = 7;
- optional float confidence_threshold = 9;
- optional float nms_threshold = 13;
- optional int32 top_k = 14;
- optional int32 boxes = 15 [default = 1];
- optional bool relative = 17 [default = true];
- optional float objectness_threshold = 18 [default = 0.5];
- optional float class_threshold = 19 [default = 0.5];
- repeated float biases = 20;
- optional GeneralNmsParameter general_nms_param = 21;
- optional float objectness_score = 22;
-}
-message PSROIPoolingParameter {
- required float spatial_scale = 1;
- required int32 output_dim = 2; // output channel number
- required int32 group_size = 3; // number of groups to encode position-sensitive score maps
-}
-// Message that stores parameters used by FreespaceExtractLayer
-message FreespaceExtractParameter {
- optional float org_height = 1;
-}
-
-// Message that stores parameters used by DetectpostprocessLayer
-message PostprocessParameter {
- optional float nms_thresh = 1 [default = 0.3];
- optional float conf_thresh = 2 [default = 0.5];
- optional uint32 post_nms_topn = 3 [default = 100];
- optional uint32 cls_num = 4 [default = 12];
- repeated float bbox_reg_weights = 5;
-}
-
-// Message that stores parameters used by SpatialTransformLayer
-message SpatialTransformParameter {
- optional uint32 output_h = 1 [default = 0];
- optional uint32 output_w = 2 [default = 0];
- optional float border_value = 3 [default = 0];
- repeated float affine_transform = 4;
- enum Engine {
- DEFAULT = 0;
- CAFFE = 1;
- CUDNN = 2;
- }
- optional Engine engine = 15 [default = DEFAULT];
-}
-message ROIAlignParameter {
- // Pad, kernel size, and stride are all given as a single value for equal
- // dimensions in height and width or as Y, X pairs.
- optional uint32 pooled_h = 1 [default = 0]; // The pooled output height
- optional uint32 pooled_w = 2 [default = 0]; // The pooled output width
- // Multiplicative spatial scale factor to translate ROI coords from their
- // input scale to the scale used when pooling
- optional float spatial_scale = 3 [default = 1];
- optional int32 sampling_ratio = 4 [default = -1];
- optional int32 roi_end_mode = 5 [default = 0];
-}
-
-message RegionParameter {
- optional uint32 classes = 1 [default = 20]; // Category of classification
- optional uint32 coords = 2 [default = 4]; // Coordinates of box
- optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid
- optional uint32 softmax = 4 [default = 0];
- optional string softmax_tree = 5 [default = ""];
- optional uint32 background = 6 [default = 0];
-}
-message ReorgParameter {
- optional uint32 stride = 2 [default = 2];
- optional bool reverse = 1 [default = false];
-}
-message ReverseParameter {
- repeated int32 axis = 1;
-}
-message InterpParameter {
- optional int32 height = 1 [default = 0]; // Height of the output
- optional int32 width = 2 [default = 0]; // Width of the output
- optional int32 zoom_factor = 3 [default = 1]; // zoom factor
- optional int32 shrink_factor = 4 [default = 1]; // shrink factor
- optional int32 pad_beg = 5 [default = 0]; // padding at the beginning of the input
- optional int32 pad_end = 6 [default = 0]; // padding at the end of the input
-}
-message ShuffleChannelParameter {
- optional uint32 group = 1 [default = 1]; // The number of groups
-}
-message UpsampleParameter {
- optional float scale = 1 [default = 1];
- optional int32 stride = 2 [default = 2];
- optional int32 stride_h = 3 [default = 2];
- optional int32 stride_w = 4 [default = 2];
-}
-message ROIPoolingParameter {
- required int32 pooled_h = 1;
- required int32 pooled_w = 2;
- optional float spatial_scale = 3 [default = 0.0625];
- optional float spatial_scale_h = 4;
- optional float spatial_scale_w = 5;
-}
-
-message YoloParameter {
- optional int32 boxes = 1 [default = 3];
- optional int32 coords = 2 [default = 4];
- optional int32 classes = 3 [default = 80];
- optional string yolo_version = 4 [default = "V3"];
- optional bool softmax = 5 [default = false];
- optional bool background = 6 [default = false];
- optional bool softmaxtree = 7 [default = false];
-}
-
-message YoloV3DetectionOutputParameter {
- optional int32 boxes = 1 [default = 3];
- optional int32 classes = 2 [default = 80];
- optional bool relative = 3 [default = true];
- optional float obj_threshold = 4 [default = 0.5];
- optional float score_threshold = 5 [default = 0.5];
- optional float iou_threshold = 6 [default = 0.45];
- optional int32 pre_nms_topn = 7 [default = 512];
- optional int32 post_nms_topn = 8 [default = 1024];
- repeated float biases_high = 9;
- repeated float biases_mid = 10;
- repeated float biases_low = 11;
- optional int32 coords = 12 [default = 4];
- repeated float biases = 13;
- optional bool resize_origin_img_to_net = 14 [default = false];
-}
-
-message YoloV3DetectionOutputV2Parameter {
- optional int32 boxes = 1 [default = 3];
- optional int32 classes = 2 [default = 80];
- optional bool relative = 3 [default = true];
- optional float obj_threshold = 4 [default = 0.5];
- optional float score_threshold = 5 [default = 0.5];
- optional float iou_threshold = 6 [default = 0.45];
- optional int32 pre_nms_topn = 7 [default = 512];
- optional int32 post_nms_topn = 8 [default = 1024];
- repeated float biases_high = 9;
- repeated float biases_mid = 10;
- repeated float biases_low = 11;
- optional int32 coords = 12 [default = 4];
- repeated float biases = 13;
- optional bool resize_origin_img_to_net = 14 [default = false];
- optional int32 out_box_dim = 15 [default = 3];
-}
-
-message ProposalParameter {
- optional float feat_stride = 1 [default = 16];
- optional float base_size = 2 [default = 16];
- optional float min_size = 3 [default = 16];
- repeated float ratio = 4;
- repeated float scale = 5;
- optional int32 pre_nms_topn = 6 [default = 3000];
- optional int32 post_nms_topn = 7 [default = 304];
- optional float iou_threshold = 8 [default = 0.7];
- optional bool output_actual_rois_num = 9 [default = false];
-}
-
-message FSRDetectionOutputParameter {
- required int32 num_classes = 1;
- required float score_threshold = 2;
- required float iou_threshold = 3;
- optional int32 batch_rois = 4 [default = 1];
-}
-
-message SSDDetectionOutputParameter {
- required int32 num_classes = 1 [default = 2];
- optional bool share_location = 2 [default = true];
- optional int32 background_label_id = 3 [default = 0];
- optional float iou_threshold = 4 [default = 0.3];
- optional int32 top_k = 5 [default = 200];
- optional float eta = 6 [default = 1.0];
- optional bool variance_encoded_in_target = 7 [default = false];
- optional int32 code_type = 8 [default = 1];
- optional int32 keep_top_k = 9 [default = -1];
- optional float confidence_threshold = 10 [default = 0.0];
-}
-message YoloV2DetectionOutputParameter {
- optional int32 boxes = 1 [default = 5];
- optional int32 classes = 2 [default = 80];
- optional bool relative = 3 [default = true];
- optional float obj_threshold = 4 [default = 0.5];
- optional float score_threshold = 5 [default = 0.5];
- optional float iou_threshold = 6 [default = 0.45];
- optional int32 pre_nms_topn = 7 [default = 512];
- optional int32 post_nms_topn = 8 [default = 1024];
- repeated float biases = 9;
- optional int32 coords = 10 [default = 4];
- optional bool resize_origin_img_to_net = 11 [default = false];
-}
-
-message QuantParameter {
- optional float scale = 2;
- optional bytes offset = 3;
-}
-
-message BatchMatMulParameter {
- optional bool adj_x1 = 1 [default = false];
- optional bool adj_x2 = 2 [default = false];
-}
-
-message CondTakeParameter {
- required string mode = 1;
- required float val = 2;
- optional float eps = 3 [default = 1e-06];
-}
-
-message MatrixInverseParameter {
- optional bool adjoint = 1 [default = false];
-}
-
-message WarpPerspectiveParameter {
- required int32 out_height = 1;
- required int32 out_width = 2;
- optional float constant = 3;
- optional string border_type = 4 [default = 'BORDER_CONSTANT'];
-}
-
-message SpatialTransformerParameter {
- // How to use the parameters passed by the localisation network
- optional string transform_type = 1 [default = "affine"];
- // Which sampling technique to use
- optional string sampler_type = 2 [default = "bilinear"];
-
- // If not set, stays the same as the input dimensions H and W
- optional int32 output_H = 3;
- optional int32 output_W = 4;
- // If false, only compute dTheta, DO NOT compute dU
- optional bool to_compute_dU = 5 [default = true];
-
- // Default values for some parameters
- optional double theta_1_1 = 6;
- optional double theta_1_2 = 7;
- optional double theta_1_3 = 8;
- optional double theta_2_1 = 9;
- optional double theta_2_2 = 10;
- optional double theta_2_3 = 11;
-}
diff --git a/ge/proto/dump_task.proto b/ge/proto/dump_task.proto
deleted file mode 100644 index a2411ddb..00000000 --- a/ge/proto/dump_task.proto +++ /dev/null @@ -1,113 +0,0 @@ -syntax = "proto3"; -package toolkit.dump; - -enum OutputDataType { - DT_UNDEFINED = 0; - DT_FLOAT = 1; - DT_FLOAT16 = 2; - DT_INT8 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_UINT16 = 6; - DT_INT32 = 7; - DT_INT64 = 8; - DT_UINT32 = 9; - DT_UINT64 = 10; - DT_BOOL = 11; - DT_DOUBLE = 12; - DT_STRING = 13; - DT_DUAL_SUB_INT8 = 14; - DT_DUAL_SUB_UINT8 = 15; - DT_COMPLEX64 = 16; - DT_COMPLEX128 = 17; - DT_QINT8 = 18; - DT_QINT16 = 19; - DT_QINT32 = 20; - DT_QUINT8 = 21; - DT_QUINT16 = 22; - DT_RESOURCE = 23; - DT_STRING_REF = 24; - DT_DUAL = 25; - DT_VARIANT = 26; -} - -enum OutputFormat { - FORMAT_NCHW = 0; - FORMAT_NHWC = 1; - FORMAT_ND = 2; - FORMAT_NC1HWC0 = 3; - FORMAT_FRACTAL_Z = 4; - FORMAT_NC1C0HWPAD = 5; - FORMAT_NHWC1C0 = 6; - FORMAT_FSR_NCHW = 7; - FORMAT_FRACTAL_DECONV = 8; - FORMAT_C1HWNC0 = 9; - FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; - FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; - FORMAT_NC1HWC0_C04 = 12; - FORMAT_FRACTAL_Z_C04 = 13; - FORMAT_CHWN = 14; - FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; - FORMAT_HWCN = 16; - FORMAT_NC1KHKWHWC0 = 17; - FORMAT_BN_WEIGHT = 18; - FORMAT_FILTER_HWCK = 19; - FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; - FORMAT_HASHTABLE_LOOKUP_KEYS = 21; - FORMAT_HASHTABLE_LOOKUP_VALUE = 22; - FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; - FORMAT_HASHTABLE_LOOKUP_HITS=24; - FORMAT_C1HWNCoC0 = 25; - FORMAT_MD = 26; - FORMAT_NDHWC = 27; - FORMAT_FRACTAL_ZZ = 28; - FORMAT_FRACTAL_NZ = 29; - FORMAT_RESERVED = 30; -} - -message OriginalOp { - string name = 1; - uint32 output_index = 2; - OutputDataType data_type = 3; - OutputFormat format = 4; -} - -message Shape { - repeated uint64 dim = 1; -} - -message OpOutput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - OriginalOp original_op = 4; // the original op corresponding to the output - bytes data = 5; - uint64 size = 6; -} - -message OpInput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - bytes data = 4; - uint64 size = 5; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - bytes data = 2; - uint64 size = 3; -} - -message DumpData{ - string version = 1; - uint64 dump_time = 2; - repeated OpOutput output = 3; - repeated OpInput input = 4; - repeated OpBuffer buffer = 5; - string op_name = 6; -} diff --git a/ge/proto/fusion_model.proto b/ge/proto/fusion_model.proto deleted file mode 100755 index c92c5581..00000000 --- a/ge/proto/fusion_model.proto +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -import "om.proto"; - -package domi; - -message FusionModelDef { - string version = 1; - repeated OpDef fusion_op = 2; -} \ No newline at end of file diff --git a/ge/proto/fwk_adapter.proto b/ge/proto/fwk_adapter.proto deleted file mode 100644 index 9335c926..00000000 --- a/ge/proto/fwk_adapter.proto +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (C) 2018. 
Huawei Technologies Co., Ltd. All rights reserved.
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * Apache License for more details at
- * http://www.apache.org/licenses/LICENSE-2.0
- */
-syntax = "proto3";
-
-package aicpu.FWKAdapter;
-option cc_enable_arenas = true;
-
-
-// Defines a struct for input and output.
-message TensorDataInfo {
-
- // value DataType
- uint32 dtype = 1;
-
- // shape dim
- repeated int64 dim = 2;
-
- // data pointer addr
- int64 data_addr = 3;
-}
-
-message KernelRunParam {
- // input
- repeated TensorDataInfo input = 1;
- // output
- repeated TensorDataInfo output = 2;
-}
-
diff --git a/ge/proto/ge_api.proto b/ge/proto/ge_api.proto
deleted file mode 100755
index 331c5aea..00000000
--- a/ge/proto/ge_api.proto
+++ /dev/null
@@ -1,88 +0,0 @@
-syntax = "proto3";
-package ge.api_pb;
-
-import "ge_ir.proto";
-
-// GE initialize
-message GEInitialize {
- map<string, string> options = 1;
-};
-
-// initialize response
-message GEInitializeResponse {
- uint32 status = 1;
- uint32 clientId = 2;
-};
-
-// GE finalize
-message GEFinalize {
- bool final = 1;
- uint32 clientId = 2;
-};
-
-message GEFinalizeResponse {
- uint32 status = 1;
-};
-
-// GE Session
-message CreateSession {
- map<string, string> options = 1;
-};
-
-message CreateSessionResponse {
- uint32 status = 1;
- uint64 sessionId = 2;
-};
-
-// GE AddGraph
-// model serialize :: serializegraph
-message SessionAddGraph {
- uint32 graphId = 1;
- uint64 sessionId = 2;
- ge.proto.GraphDef graph = 3;
-};
-
-message SessionAddGraphResponse {
- uint32 status = 1;
-};
-
-// GE SessionRemoveGraph
-message SessionRemoveGraph {
- uint32 graphId = 1;
- uint64 sessionId = 2;
-};
-
-message SessionRemoveGraphResponse {
- uint32 status = 1;
-};
-
-message SessionRunGraph {
- uint32 graphId = 1;
- uint64 sessionId = 2;
- repeated ge.proto.TensorDef tensor = 3;
-};
-
-message SessionBuildGraph {
- uint32 graphId = 1;
- uint64 sessionId = 2;
- repeated ge.proto.TensorDef tensor = 3;
- string savePath = 4;
-};
-
-message SessionRunGraphResponse {
- uint32 status = 1;
- repeated ge.proto.TensorDef tensor = 2;
-};
-
-message SessionBuildGraphResponse {
- uint32 status = 1;
-};
-
-message DestroySession {
- bool final = 1;
- uint64 sessionId = 2;
-};
-
-message DestroySessionResponse {
- uint32 status = 1;
-};
diff --git a/ge/proto/ge_ir.proto b/ge/proto/ge_ir.proto
deleted file mode 100644
index c0ef3071..00000000
--- a/ge/proto/ge_ir.proto
+++ /dev/null
@@ -1,193 +0,0 @@
-syntax = "proto3";
-
-package ge.proto;
-
-enum DataType
-{
- DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
- DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ - DT_VARIANT = 26; // variant type - DT_BF16 = 27; // bf16 type - DT_INT4 = 28; // int4 type -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. 
-message NamedAttrs -{ - string name = 1; - map attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. op_name:index - - map attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef - - map attr = 11; // Extended field -} - diff --git a/ge/proto/insert_op.proto b/ge/proto/insert_op.proto deleted file mode 100644 index 7d708865..00000000 --- a/ge/proto/insert_op.proto +++ /dev/null @@ -1,140 +0,0 @@ -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPPģʽ־̬AIPPͶ̬AIPP - AippMode aipp_mode = 1; - - // related_input_rankΪΪͣ÷Χ>=0, <=DataӵĸĬֵΪ0 - // ʶģ͵ĵڼAIPPģ룬ҪԵ2AIPPrelated_input_rankΪ1 - uint32 related_input_rank = 2; - - // related_input_name is optional and the top name of data node which inserts aipp - string related_input_name = 6; - - // input_edge_idxΪѡΪͣ÷ΧΪ>=0 - // øòãڶDataӲͬͬAIPPòûãĬ϶related_input_rankָģAIPP 
- // ֵ <= Dataߵĸ - repeated uint32 input_edge_idx = 3; - - // [Begin] ̬AIPPþ̬AIPPʱЧ - uint32 max_src_image_size = 4; - - // Ƿ֧תĬϲ֧֣֧תʱжĿռʧ - bool support_rotation = 5; - - // [End] ̬AIPP - - - // [Begin] ̬AIPPö̬AIPPʱЧ - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - float padding_value = 72; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] ̬AIPP - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; //̬batch - resolution = 1; //ֱ̬ʣչ - } - - MultiShapeMode mode = 1; //ģʽ - uint32 related_input_rank = 2; //Ӳ뵽ĸ - - - repeated uint32 batch_list = 11; //batch_listֵbatch_listĸ28֮ -} diff --git a/ge/proto/om.proto b/ge/proto/om.proto deleted file mode 100644 index e15e5f80..00000000 --- a/ge/proto/om.proto +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map<string, AttrDef> attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map<string, AttrDef> attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; -
bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4; // optional, [default = 1e-5] - bool use_global_stats = 5; // optional, true by default (testing mode) - float moving_average_fraction = 6; // optional, [default = .999] - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams { - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // By default, we will use NPU.
- CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load direction, 0: col load, 1: row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map<string, AttrDef> attr = 2; -} - diff --git a/ge/proto/op_mapping.proto b/ge/proto/op_mapping.proto deleted file mode 100644 index d626eb49..00000000 --- a/ge/proto/op_mapping.proto +++ /dev/null @@ -1,75 +0,0 @@ -syntax = "proto3"; -package toolkit.aicpu.dump; - -message Shape { - repeated uint64 dim = 1; -} - -message Output { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - string original_name = 5; - int32 original_output_index = 6; - int32 original_output_data_type = 7; - int32 original_output_format = 8; - uint64 size = 9; - Shape origin_shape = 10; -} - -message Input { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - uint64 size = 5; - Shape origin_shape = 6; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - uint64 address = 2; - uint64 size = 3; -} - -message Op { - string op_name = 1; - string op_type = 2; -} - -message Task { - uint32 task_id = 1; - uint32 stream_id = 2; - Op op = 3; - repeated Output output = 4; - bool end_graph = 5; - repeated Input input = 6; - repeated OpBuffer buffer = 7; -} - -message OpMappingInfo { - string dump_path = 1; - oneof model_name_param { - string model_name = 2; - } - oneof model_id_param { - uint32 model_id = 3; - } - oneof step_id { - uint64 step_id_addr = 4; - } - oneof iterations_per_loop { - uint64 iterations_per_loop_addr = 5; - } - oneof loop_cond { - uint64 loop_cond_addr = 6; - } - uint32 flag = 7; // 0x01 load, 0x00 unload - repeated Task task = 8; - string dump_step = 9; -} \ No newline at end of file diff --git a/ge/proto/optimizer_priority.proto b/ge/proto/optimizer_priority.proto deleted file mode 100644 index 769619cf..00000000 --- a/ge/proto/optimizer_priority.proto +++ /dev/null @@ -1,7 +0,0 @@ -syntax = "proto3"; -package ge.optimizers; - -// Default: GE>FE>AICPU -message Priority{ - repeated string optimizer = 1; -} \ No newline at end of file diff --git a/ge/proto/task.proto b/ge/proto/task.proto deleted file
mode 100644 index 0da5631e..00000000 --- a/ge/proto/task.proto +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * Apache License for more details at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map<string, AttrDef> attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; - KernelDefWithHandle kernel_with_handle = 40; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelDefWithHandle { - KernelContext context = 1; - - uint64 handle = 10; - string dev_func = 11; - uint32 block_dim = 12; - uint32 args_size = 13; - bytes args = 14; - bytes sm_desc = 15; - string original_kernel_key = 16; - string node_info = 17; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; -
uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/proto/tensorflow/attr_value.proto b/ge/proto/tensorflow/attr_value.proto deleted file mode 100644 index 438d7163..00000000 --- a/ge/proto/tensorflow/attr_value.proto +++ /dev/null @@ -1,70 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensor.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled. -message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - repeated NameAttrList func = 9; // "list(attr)" - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. - NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. 
-message NameAttrList { - string name = 1; - map<string, AttrValue> attr = 2; -} diff --git a/ge/proto/tensorflow/function.proto b/ge/proto/tensorflow/function.proto deleted file mode 100644 index 44681e32..00000000 --- a/ge/proto/tensorflow/function.proto +++ /dev/null @@ -1,108 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "node_def.proto"; -import "op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have a binding for every attr defined in the signature. -// * device spec, etc. -message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map<string, AttrValue> attr = 5; - - // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. - reserved 2; - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list. - // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name of one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives the last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function. Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above.
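The reference format above is easiest to see in generated code. Below is a minimal sketch, assuming the C++ classes protoc would emit for this file; the function and node names are illustrative, and the `ret` map it fills is the field declared just below.

```cpp
#include "function.pb.h"  // assumed output of protoc for this function.proto

// Bind the signature output "y" of a one-node function to the first
// output of node "id", using the "node:out:0" form described above.
void BuildIdentityFunction(domi::tensorflow::FunctionDef &fdef) {
  auto *sig = fdef.mutable_signature();
  sig->set_name("Identity2");
  sig->add_input_arg()->set_name("x");
  sig->add_output_arg()->set_name("y");

  auto *node = fdef.add_node_def();
  node->set_name("id");
  node->set_op("Identity");
  node->add_input("x");  // "fun_in": the whole function input arg

  // "node:out:0": first element of the "output" arg of node "id".
  (*fdef.mutable_ret())["y"] = "id:output:0";
}
```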
- - // By convention, "op" in node_def is resolved by consulting with a - // user-defined library first. If not resolved, "func" is assumed to - // be a builtin op. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map<string, string> ret = 4; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must satisfy the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g' is a -// function that takes N + M inputs and produces N outputs. -// -// I.e. if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-valued function of (x1, x2, ..., xN) (e.g., the -// loss function). dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/ge/proto/tensorflow/graph.proto b/ge/proto/tensorflow/graph.proto deleted file mode 100644 index 73bfc6ee..00000000 --- a/ge/proto/tensorflow/graph.proto +++ /dev/null @@ -1,64 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "node_def.proto"; -import "function.proto"; -import "versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. - VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3 [deprecated = true]; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... }} - // map<string, FunctionDefLibrary> named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime. - // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready.
The caller may want to use the Tuple() mechanism to - // ensure all inputs are ready at the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. The - // consumer may want to use the Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. - FunctionDefLibrary library = 2; -}; diff --git a/ge/proto/tensorflow/graph_library.proto b/ge/proto/tensorflow/graph_library.proto deleted file mode 100644 index 7bca0838..00000000 --- a/ge/proto/tensorflow/graph_library.proto +++ /dev/null @@ -1,22 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; - -import "graph.proto"; - -message GeGraphDef { - string name = 1; - GraphDef graph = 2; -} - -message GraphDefLibrary { - repeated GeGraphDef graph_def = 1; -}; \ No newline at end of file diff --git a/ge/proto/tensorflow/node_def.proto b/ge/proto/tensorflow/node_def.proto deleted file mode 100644 index 50cf5cac..00000000 --- a/ge/proto/tensorflow/node_def.proto +++ /dev/null @@ -1,71 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. - // Op names starting with an underscore are reserved for internal use. - string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed.
- // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= PARTIAL_SPEC - // - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) - // * "/job:worker/device:GPU:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. - string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // Add some examples here showing best practices. - map<string, AttrValue> attr = 5; -}; diff --git a/ge/proto/tensorflow/op_def.proto b/ge/proto/tensorflow/op_def.proto deleted file mode 100644 index 7f0e8ce2..00000000 --- a/ge/proto/tensorflow/op_def.proto +++ /dev/null @@ -1,172 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "types.proto"; - -// Defines an operation. A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of an OpDef. -// LINT.IfChange -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human-readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. - // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)".
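As a concrete reading of the legal combinations above, here is a hedged sketch of an AddN-style OpDef that takes N >= 1 tensors of one type T, assuming the C++ classes protoc would emit for this file; the op and attr names are illustrative, and the AttrDef fields it sets are declared further down in this message.

```cpp
#include "op_def.pb.h"  // assumed output of protoc for this op_def.proto

// "inputs" is a sequence arg: its length comes from the int attr "N" and
// its element type from the type attr "T" (the second combination above).
void BuildAddNLikeOpDef(domi::tensorflow::OpDef &op) {
  op.set_name("MyAddN");

  auto *in = op.add_input_arg();
  in->set_name("inputs");
  in->set_number_attr("N");
  in->set_type_attr("T");

  auto *out = op.add_output_arg();
  out->set_name("sum");
  out->set_type_attr("T");

  auto *n = op.add_attr();
  n->set_name("N");
  n->set_type("int");
  n->set_has_minimum(true);
  n->set_minimum(1);  // at least one input tensor

  auto *t = op.add_attr();
  t->set_name("T");
  t->set_type("type");
}
```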
- DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. - string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. - message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. - // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. - string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // Ops are marked as stateful if their behavior depends on some state beyond - // their input tensors (e.g. variable reading op) or if they have - // a side-effect (e.g. 
printing or asserting ops). Equivalently, stateless ops - // must always produce the same output for the same input and have - // no side-effects. - // - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. - - // By default, all inputs to an Op must be initialized Tensors. Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. - int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/ge/proto/tensorflow/resource_handle.proto b/ge/proto/tensorflow/resource_handle.proto deleted file mode 100644 index 91c46c9a..00000000 --- a/ge/proto/tensorflow/resource_handle.proto +++ /dev/null @@ -1,37 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandle"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandleProto { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. - string maybe_type_name = 5; -}; diff --git a/ge/proto/tensorflow/tensor.proto b/ge/proto/tensorflow/tensor.proto deleted file mode 100644 index 48eeb6c4..00000000 --- a/ge/proto/tensorflow/tensor.proto +++ /dev/null @@ -1,102 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). 
- * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "resource_handle.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a tensor. -message TensorProto { - DataType dtype = 1; - - // Shape of the tensor. - TensorShapeProto tensor_shape = 2; - - // Only one of the representations below is set, one of "tensor_contents" and - // the "xxx_val" attributes. We are not using oneof because as oneofs cannot - // contain repeated fields it would require another extra set of messages. - - // Version number. - // - // In version 0, if the "repeated xxx" representations contain only one - // element, that element is repeated to fill the shape. This makes it easy - // to represent a constant Tensor with a single value. - int32 version_number = 3; - - // Serialized raw tensor content from either Tensor::AsProtoTensorContent or - // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation - // can be used for all tensor types. The purpose of this representation is to - // reduce serialization overhead during RPC call by avoiding serialization of - // many repeated small items. - bytes tensor_content = 4; - - // Type specific representations that make it easy to create tensor protos in - // all languages. Only the representation corresponding to "dtype" can - // be set. The values hold the flattened representation of the tensor in - // row major order. - - // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll - // have some pointless zero padding for each value here. - repeated int32 half_val = 13 [packed = true]; - - // DT_FLOAT. - repeated float float_val = 5 [packed = true]; - - // DT_DOUBLE. - repeated double double_val = 6 [packed = true]; - - // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7 [packed = true]; - - // DT_STRING - repeated bytes string_val = 8; - - // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - repeated float scomplex_val = 9 [packed = true]; - - // DT_INT64 - repeated int64 int64_val = 10 [packed = true]; - - // DT_BOOL - repeated bool bool_val = 11 [packed = true]; - - // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real - // and imaginary parts of i-th double precision complex. - repeated double dcomplex_val = 12 [packed = true]; - - // DT_RESOURCE - repeated ResourceHandleProto resource_handle_val = 14; - - // DT_VARIANT - repeated VariantTensorDataProto variant_val = 15; - - // DT_UINT32 - repeated uint32 uint32_val = 16 [packed = true]; - - // DT_UINT64 - repeated uint64 uint64_val = 17 [packed = true]; -}; - -// Protocol buffer representing the serialization format of DT_VARIANT tensors. -message VariantTensorDataProto { - // Name of the type of objects being serialized. - string type_name = 1; - // Portions of the object that are not Tensors. - bytes metadata = 2; - // Tensors contained within objects being serialized. 
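The version-0 splat convention described above ("one element is repeated to fill the shape") is worth a concrete example for the TensorProto message. A minimal sketch, assuming the C++ classes protoc would emit for this file:

```cpp
#include "tensor.pb.h"  // assumed output of protoc for this tensor.proto

// A constant 2x3 float tensor holding 1.5 in all six positions: with
// version_number 0, the single float_val entry fills the whole shape.
void BuildSplatTensor(domi::tensorflow::TensorProto &t) {
  t.set_dtype(domi::tensorflow::DT_FLOAT);
  t.mutable_tensor_shape()->add_dim()->set_size(2);
  t.mutable_tensor_shape()->add_dim()->set_size(3);
  t.set_version_number(0);
  t.add_float_val(1.5f);  // only the "xxx_val" field matching dtype is used
}
```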
- repeated TensorProto tensors = 3; -} diff --git a/ge/proto/tensorflow/tensor_shape.proto b/ge/proto/tensorflow/tensor_shape.proto deleted file mode 100644 index 3a6d8c5a..00000000 --- a/ge/proto/tensorflow/tensor_shape.proto +++ /dev/null @@ -1,53 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -// Protocol buffer representing the shape of tensors. - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorShapeProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -package domi.tensorflow; - -// Dimensions of a tensor. -message TensorShapeProto { - // One dimension of the tensor. - message Dim { - // Size of the tensor in that dimension. - // This value must be >= -1, but values of -1 are reserved for "unknown" - // shapes (values of -1 mean "unknown" dimension). Certain wrappers - // that work with TensorShapeProto may fail at runtime when deserializing - // a TensorShapeProto containing a dim value of -1. - int64 size = 1; - - // Optional name of the tensor dimension. - string name = 2; - }; - - // Dimensions of the tensor, such as {"input", 30}, {"output", 40} - // for a 30 x 40 2D tensor. If an entry has size -1, this - // corresponds to a dimension of unknown size. The names are - // optional. - // - // The order of entries in "dim" matters: It indicates the layout of the - // values in the tensor in-memory representation. - // - // The first entry in "dim" is the outermost dimension used to layout the - // values, the last entry is the innermost dimension. This matches the - // in-memory layout of RowMajor Eigen tensors. - // - // If "dim.size()" > 0, "unknown_rank" must be false. - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/ge/proto/tensorflow/types.proto b/ge/proto/tensorflow/types.proto deleted file mode 100644 index f40e49cb..00000000 --- a/ge/proto/tensorflow/types.proto +++ /dev/null @@ -1,82 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. 
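To make the -1 and unknown_rank conventions of TensorShapeProto above concrete, a minimal sketch assuming the generated C++ classes (the shape values are illustrative):

```cpp
#include "tensor_shape.pb.h"  // assumed output of protoc for tensor_shape.proto

// Shape [?, 128]: the rank is known, so unknown_rank stays false, and the
// unknown first dimension is encoded as size -1.
void BuildBatchShape(domi::tensorflow::TensorShapeProto &shape) {
  auto *batch = shape.add_dim();
  batch->set_size(-1);       // -1 encodes "unknown" for this dimension
  batch->set_name("batch");  // dimension names are optional

  shape.add_dim()->set_size(128);  // innermost dimension
}
```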
- DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. - DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - DT_VARIANT = 21; // Arbitrary C++ data types - DT_UINT32 = 22; - DT_UINT64 = 23; - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/c_api.h, -// https://www.tensorflow.org/code/tensorflow/go/tensor.go, -// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, -// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, -// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/ge/proto/tensorflow/versions.proto b/ge/proto/tensorflow/versions.proto deleted file mode 100644 index 4e81548f..00000000 --- a/ge/proto/tensorflow/versions.proto +++ /dev/null @@ -1,39 +0,0 @@ -/** - * This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow - * - * This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model. - * This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications"). - * All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd. - */ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). 
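The acceptance rule quoted in the VersionDef comment above maps directly onto a small check. A minimal sketch, assuming the C++ classes protoc would emit for this file; it also consults the bad_consumers list declared right after this comment:

```cpp
#include <cstdint>
#include "versions.pb.h"  // assumed output of protoc for this versions.proto

// A consumer may read the data iff producer >= min_producer,
// consumer >= min_consumer, and consumer is not in bad_consumers.
bool CanConsume(const domi::tensorflow::VersionDef &v,
                int32_t consumer, int32_t min_producer) {
  if (v.producer() < min_producer || consumer < v.min_consumer()) {
    return false;
  }
  for (const int32_t bad : v.bad_consumers()) {
    if (consumer == bad) {
      return false;
    }
  }
  return true;
}
```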
- repeated int32 bad_consumers = 3; -}; diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index 8248eecf..b9c44ef1 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -24,18 +24,18 @@ #include "adx_datadump_server.h" #include "common/dump/dump_properties.h" #include "common/dump/dump_manager.h" -#include "common/util.h" +#include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" #include "graph/ge_global_options.h" #include "graph/ge_local_context.h" -#include "graph/common/local_context.h" -#include "graph/load/model_manager/model_manager.h" +#include "common/local_context.h" #include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_mem_manager.h" #include "graph/utils/tensor_adapter.h" #include "runtime/mem.h" #include "ir_build/option_utils.h" +#include "common/profiling/profiling_manager.h" namespace ge { namespace { @@ -82,6 +82,18 @@ Status InnerSession::Initialize() { return ret; } + // Check option OP_PRECISION_MODE + auto iter = all_options.find(ge::OP_PRECISION_MODE); + if (iter != all_options.end() && !iter->second.empty() && !ge::CheckInputPathValid(iter->second)) { + REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), + std::vector<std::string>({ge::OP_PRECISION_MODE, iter->second, "path is not found"})); + GELOGE(PARAM_INVALID, "[Check][OP_PRECISION_MODE] %s not found", iter->second.c_str()); + return FAILED; + } + if (iter != all_options.end()) { + GELOGI("Option set successfully, option_key=%s, option_value=%s", + ge::OP_PRECISION_MODE.c_str(), iter->second.c_str()); + } // Check option modify_mixlist if (ge::CheckModifyMixlistParamValid(all_options) != ge::SUCCESS) { return FAILED; } @@ -109,10 +121,10 @@ Status InnerSession::Initialize() { GE_CHK_RT_RET(rtSetDevice(GetContext().DeviceId())); DumpProperties dump_properties; - dump_properties.InitByOptions(); + GE_CHK_STATUS_RET(dump_properties.InitByOptions(), "Init dump properties failed."); GE_CHK_STATUS_RET(AddDumpProperties(dump_properties), "[Add][DumpProperties] failed."); - ret = graph_manager_.Initialize(options_); + ret = InnerInitialize(); if (ret != SUCCESS) { GELOGE(ret, "[Init][GraphManager] failed, InnerSession:%lu.", session_id_); REPORT_CALL_ERROR("E19999", "GraphManager initialize failed, InnerSession:%lu.", session_id_); @@ -124,7 +136,7 @@ if (ret != SUCCESS) { GELOGE(ret, "[Set][MemoryMallocSize] failed."); REPORT_CALL_ERROR("E19999", "VarManager SetMemoryMallocSize failed, InnerSession:%lu.", session_id_); - (void)graph_manager_.Finalize(); + (void)InnerFinalize(); GE_CHK_STATUS(RemoveDumpProperties(), "[Remove][DumpProperties] failed."); GE_CHK_RT(rtDeviceReset(static_cast<int32_t>(GetContext().DeviceId()))); return ret; @@ -150,14 +162,13 @@ Status InnerSession::Finalize() { return SUCCESS; } UpdateThreadContext(std::map<std::string, std::string>{}); - Status ret = graph_manager_.Finalize(); + Status ret = InnerFinalize(); if (ret != SUCCESS) { // Subsequent code execution is required, so no return is required GELOGE(ret, "[Finalize][GraphManager] failed, InnerSession:%lu.", session_id_); REPORT_CALL_ERROR("E19999", "GraphManager Finalize failed, InnerSession:%lu.", session_id_); } - ModelManager::GetInstance()->DestroyAicpuSession(session_id_); init_flag_ = false; // release var memory GELOGI("VarManager free var memory."); @@ -176,6 +187,44 @@ return ret; } +Status InnerSession::InnerInitialize() { + Status ret =
model_executor_.Initialize(options_, session_id_); + if (ret != SUCCESS) { + GELOGE(ret, "[Init][GraphExecutor] failed, InnerSession:%lu.", session_id_); + REPORT_CALL_ERROR("E19999", "GraphExecutor initialize failed, InnerSession:%lu.", session_id_); + GE_CHK_STATUS(RemoveDumpProperties(), "[Remove][DumpProperties] failed."); + return ret; + } + + ret = graph_manager_.Initialize(options_, &model_executor_); + if (ret != SUCCESS) { + GELOGE(ret, "[Init][GraphManager] failed, InnerSession:%lu.", session_id_); + REPORT_CALL_ERROR("E19999", "GraphManager initialize failed, InnerSession:%lu.", session_id_); + GE_CHK_STATUS(RemoveDumpProperties(), "[Remove][DumpProperties] failed."); + return ret; + } + + return SUCCESS; +} + +Status InnerSession::InnerFinalize() { + Status ret = graph_manager_.Finalize(); + if (ret != SUCCESS) { + // Subsequent code execution is required, so no return is required + GELOGE(ret, "[Finalize][GraphManager] failed, InnerSession:%lu.", session_id_); + REPORT_CALL_ERROR("E19999", "GraphManager Finalize failed, InnerSession:%lu.", session_id_); + } + + ret = model_executor_.Finalize(); + if (ret != SUCCESS) { + // Subsequent code execution is required, so no return is required + GELOGE(ret, "[Finalize][GraphExecutor] failed, InnerSession:%lu.", session_id_); + REPORT_CALL_ERROR("E19999", "GraphExecutor Finalize failed, InnerSession:%lu.", session_id_); + } + + return SUCCESS; +} + Status InnerSession::GetVariable(const std::string &name, Tensor &val) { UpdateThreadContext(std::map<std::string, std::string>{}); return graph_manager_.GetVariable(name, val); @@ -183,6 +232,9 @@ Status InnerSession::AddGraph(uint32_t graph_id, const Graph &graph) { std::map<std::string, std::string> options; + auto device_id = GetContext().DeviceId(); + GELOGD("Device id is %u", device_id); + ProfilingManager::Instance().SetGraphIdToDeviceMap(graph_id, device_id); return AddGraph(graph_id, graph, options); } diff --git a/ge/session/inner_session.h b/ge/session/inner_session.h index a2ec35df..afc273ac 100644 --- a/ge/session/inner_session.h +++ b/ge/session/inner_session.h @@ -21,8 +21,9 @@ #include #include #include "framework/common/ge_types.h" -#include "ge/ge_api_types.h" +#include "external/ge/ge_api_types.h" #include "graph/manager/graph_manager.h" +#include "graph/execute/model_executor.h" namespace ge { class InnerSession { @@ -82,10 +83,14 @@ class InnerSession { void SetRtSocVersion(); private: + Status InnerInitialize(); + Status InnerFinalize(); + bool init_flag_; uint64_t session_id_; std::map<std::string, std::string> options_; GraphManager graph_manager_; + ModelExecutor model_executor_; std::mutex resource_mutex_; // AddGraph, RemoveGraph and Finalize use void UpdateThreadContext(const std::map<std::string, std::string> &options); void UpdateThreadContext(uint32_t graph_id); diff --git a/ge/session/omg.cc b/ge/session/omg.cc index 878b0b39..f7f3def7 100755 --- a/ge/session/omg.cc +++ b/ge/session/omg.cc @@ -14,21 +14,21 @@ * limitations under the License.
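For the OP_PRECISION_MODE validation added to InnerSession::Initialize above, a hedged usage sketch from the client side; the file path is hypothetical, and it assumes the public ge::Session API, whose option map ends up in all_options:

```cpp
#include <map>
#include <string>
#include "external/ge/ge_api.h"  // public Session API (assumed entry point)

// The option value must be an existing, readable file, otherwise session
// initialization now fails with E10001 ("path is not found").
void CreateSessionWithOpPrecisionMode() {
  std::map<std::string, std::string> options;
  options[ge::OP_PRECISION_MODE] = "/path/to/op_precision_mode.ini";  // hypothetical path
  ge::Session session(options);  // validated during InnerSession::Initialize
}
```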
*/ -#include "omg/omg.h" +#include "framework/omg/omg.h" #include #include #include #include "common/auth/file_saver.h" -#include "common/debug/log.h" +#include "framework/common/debug/log.h" #include "common/debug/memory_dumper.h" #include "common/ge/ge_util.h" -#include "common/helper/model_helper.h" +#include "framework/common/helper/model_helper.h" #include "common/model_parser/model_parser.h" #include "common/model_saver.h" #include "common/properties_manager.h" -#include "common/string_util.h" -#include "common/types.h" -#include "common/util.h" +#include "framework/common/string_util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" #include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/omg/parser/parser_inner_ctx.h" @@ -39,10 +39,10 @@ #include "graph/optimize/common/params.h" #include "graph/utils/type_utils.h" #include "ir_build/option_utils.h" -#include "omg/omg_inner_types.h" -#include "omg/parser/model_parser.h" -#include "omg/parser/parser_factory.h" -#include "omg/parser/weights_parser.h" +#include "framework/omg/omg_inner_types.h" +#include "framework/omg/parser/model_parser.h" +#include "framework/omg/parser/parser_factory.h" +#include "framework/omg/parser/weights_parser.h" #include "parser/common/pre_checker.h" #include "parser/common/convert/pb2json.h" #include "proto/ge_ir.pb.h" diff --git a/ge/session/session_manager.cc b/ge/session/session_manager.cc index fdf37d06..486dfd58 100755 --- a/ge/session/session_manager.cc +++ b/ge/session/session_manager.cc @@ -20,7 +20,6 @@ #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" -#include "graph/load/model_manager/model_manager.h" #include "graph/manager/util/rt_context_util.h" using std::map; @@ -105,10 +104,6 @@ Status SessionManager::DestroySession(SessionId session_id) { return GE_SESSION_NOT_EXIST; } - if (ModelManager::GetInstance() != nullptr) { - ModelManager::GetInstance()->DestroyAicpuSession(session_id); - } - // Unified destruct rt_context RtContextUtil::GetInstance().DestroyRtContexts(session_id); diff --git a/ge/session/session_manager.h b/ge/session/session_manager.h index 17152b0a..4a0b9d66 100644 --- a/ge/session/session_manager.h +++ b/ge/session/session_manager.h @@ -22,8 +22,8 @@ #include #include #include -#include "common/ge_inner_error_codes.h" -#include "ge/ge_api_types.h" +#include "framework/common/ge_inner_error_codes.h" +#include "external/ge/ge_api_types.h" #include "session/inner_session.h" #include "runtime/base.h" @@ -31,9 +31,26 @@ namespace ge { using SessionPtr = std::shared_ptr; class SessionManager { - friend class GELib; - public: + SessionManager() = default; + + ~SessionManager() = default; + + /// + /// @ingroup ge_session + /// @brief initialize session manager + /// @param [in] options session manager config options + /// @return Status result of function + /// + Status Initialize(const std::map &options); + + /// + /// @ingroup ge_session + /// @brief finalize session manager + /// @return Status result of function + /// + Status Finalize(); + /// /// @ingroup ge_session /// @brief create session @@ -181,25 +198,6 @@ class SessionManager { bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); private: - SessionManager() = default; - - ~SessionManager() = default; - - /// - /// @ingroup ge_session - /// @brief initialize session manager - /// @param [in] options session manager config options - /// @return Status result of function - 
diff --git a/ge/single_op/single_op.cc b/ge/single_op/single_op.cc
index d09e8398..a82c30ba 100755
--- a/ge/single_op/single_op.cc
+++ b/ge/single_op/single_op.cc
@@ -16,8 +16,8 @@
 
 #include "single_op/single_op.h"
 
-#include "common/fmk_types.h"
-#include "common/ge_types.h"
+#include "framework/common/fmk_types.h"
+#include "framework/common/ge_types.h"
 #include "common/math/math_util.h"
 #include "common/profiling/profiling_manager.h"
 #include "framework/common/debug/ge_log.h"
@@ -58,7 +58,7 @@ Status ProfilingTaskInfo(OpTask *op_task, const string &shape_type) {
           tmp_task_desc_info.op_name.c_str(), tmp_task_desc_info.model_name.c_str());
   tmp_task_desc_info.shape_type = shape_type;
-  tmp_task_desc_info.cur_iter_num = 0;
+  tmp_task_desc_info.cur_iter_num = ProfilingManager::Instance().GetStepInfoIndex();
   tmp_task_desc_info.task_type = op_task->GetTaskType();
 
   std::vector<TaskDescInfo> task_desc_info;
@@ -297,6 +297,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c
   for (auto &task : tasks_) {
     ret = task->LaunchKernel(stream_);
+    GELOGD("[DEBUG_TASK_INFO : Static Task] %s %s",
+           task->GetTaskName().c_str(),
+           BuildTaskUtils::GetTaskInfo(task->GetOpdesc(), inputs, outputs).c_str());
     if (ret != SUCCESS) {
       return ret;
     }
@@ -447,6 +450,8 @@ Status DynamicSingleOp::ExecuteAsync(const vector<GeTensorDesc> &input_desc,
   } else {
     GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_));
   }
+  GELOGD("[DEBUG_TASK_INFO : Dynamic Task] %s",
+         BuildTaskUtils::GetTaskInfo(op_task_->GetOpdesc(), input_buffers, output_buffers).c_str());
   GE_CHK_STATUS_RET_NOLOG(op_task_->OpenDump(stream_));
   GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get(), kShapeTypeDynamic));
   return SUCCESS;
diff --git a/ge/single_op/single_op.h b/ge/single_op/single_op.h
index deb4532e..7e05dd5f 100755
--- a/ge/single_op/single_op.h
+++ b/ge/single_op/single_op.h
@@ -23,10 +23,10 @@
 #include
 #include
 
-#include "common/ge_inner_error_codes.h"
+#include "framework/common/ge_inner_error_codes.h"
 #include "framework/executor/ge_executor.h"
 #include "runtime/stream.h"
-#include "task/op_task.h"
+#include "single_op/task/op_task.h"
 #include "cce/aicpu_engine_struct.h"
 #include "hybrid/executor/hybrid_model_executor.h"
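Two behavioral notes on the single_op.cc hunks above. First, cur_iter_num was previously hard-coded to 0; it now comes from ProfilingManager::Instance().GetStepInfoIndex(), so single-op profiling records can be correlated with the current step index. Second, every launch now emits a [DEBUG_TASK_INFO : ...] debug line built by the new BuildTaskUtils::GetTaskInfo overloads. A condensed sketch of the record being filled (field names as in ProfilingTaskInfo; surrounding plumbing omitted, and the shape_type constants are the ones this file already uses):

    TaskDescInfo tmp_task_desc_info;
    tmp_task_desc_info.shape_type = shape_type;  // e.g. kShapeTypeDynamic on the dynamic path
    tmp_task_desc_info.cur_iter_num = ProfilingManager::Instance().GetStepInfoIndex();  // was: 0
    tmp_task_desc_info.task_type = op_task->GetTaskType();
    std::vector<TaskDescInfo> task_desc_info = {tmp_task_desc_info};
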
diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc
index 67642f2e..ca07d2ae 100755
--- a/ge/single_op/single_op_model.cc
+++ b/ge/single_op/single_op_model.cc
@@ -28,10 +28,10 @@
 #include "graph/utils/graph_utils.h"
 #include "graph/utils/tensor_utils.h"
 #include "runtime/rt.h"
-#include "task/aicpu_task_builder.h"
-#include "task/aicpu_kernel_task_builder.h"
-#include "task/rts_kernel_task_builder.h"
-#include "task/tbe_task_builder.h"
+#include "single_op/task/aicpu_task_builder.h"
+#include "single_op/task/aicpu_kernel_task_builder.h"
+#include "single_op/task/rts_kernel_task_builder.h"
+#include "single_op/task/tbe_task_builder.h"
 #include "hybrid/executor/hybrid_model_executor.h"
 #include "hybrid/node_executor/node_executor.h"
@@ -44,48 +44,58 @@ using std::vector;
 namespace ge {
 namespace {
 const size_t kDataOutputNum = 1;
+const uint32_t kInputIndexOfData = 0;
 const uint32_t kOutputIndexOfData = 0;
+const size_t kNumTaskWithAtomicAddrCleanTask = 2;
+const size_t kNumTaskWithMemCpyTask = 2;
 constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
+const char *const kEngineNameAiCore = "AIcoreEngine";
+const char *const kEngineNameAiCpu = "aicpu_ascend_kernel";
+const char *const kEngineNameAiCpuTf = "aicpu_tf_kernel";
+
+Status CheckHostMem(const std::vector<std::string> &dependencies, const NodePtr &node, bool &is_host_mem) {
+  auto op_desc = node->GetOpDesc();
+  for (const auto &input_name : dependencies) {
+    int input_index = op_desc->GetInputIndexByName(input_name);
+    if (input_index < 0) {
+      GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.",
+             node->GetName().c_str(), input_name.c_str());
+      REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.",
+                        node->GetName().c_str(), input_name.c_str());
+      return INTERNAL_ERROR;
+    }
 
-Status IfInferDepend(GeModelPtr &ge_model, bool &flag) {
+    const auto &src_node = NodeUtils::GetInDataNodeByIndex(*node, input_index);
+    GE_CHECK_NOTNULL(src_node);
+    auto src_op_desc = src_node->GetOpDesc();
+    GE_CHECK_NOTNULL(src_op_desc);
+    if (src_op_desc->GetType() == DATA) {
+      auto tensor = src_op_desc->MutableInputDesc(kInputIndexOfData);
+      if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) {
+        GELOGD("Get hostmem from node %s, inputname: %s.", src_node->GetName().c_str(), input_name.c_str());
+        continue;
+      }
+    }
+    is_host_mem = false;
+    return SUCCESS;
+  }
+  is_host_mem = true;
+  return SUCCESS;
+}
+
+Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_host_mem) {
   auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
   GE_CHECK_NOTNULL(comp_graph);
   for (const auto &node : comp_graph->GetAllNodes()) {
+    GE_CHECK_NOTNULL(node);
     auto op_desc = node->GetOpDesc();
     GE_CHECK_NOTNULL(op_desc);
     const auto &depends = op_desc->GetOpInferDepends();
     bool support_dynamic_shape = false;
     (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape);
     if (!depends.empty() && support_dynamic_shape) {
-      flag = true;
-      return SUCCESS;
-    }
-  }
-  return SUCCESS;
-}
-
-Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) {
-  bool infer_depend_flag = false;
-  GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed.");
-  auto tasks = ge_model->GetModelTaskDefPtr()->task();
-  int32_t kernel_task_num = 0;
-  for (int i = 0; i < tasks.size(); ++i) {
-    auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type());
-    if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
-      const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() :
-                                                                tasks[i].kernel_with_handle().context();
-      auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
-      if (kernel_type == ccKernelType::TE) {
-        if (infer_depend_flag) {
-          flag = true;
-          return SUCCESS;
-        }
-        kernel_task_num++;
-        if (kernel_task_num > 1) {
-          flag = true;
-          return SUCCESS;
-        }
-      }
+      is_infer_depend = true;
+      return CheckHostMem(depends, node, is_host_mem);
     }
   }
   return SUCCESS;
@@ -342,11 +352,10 @@ Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &s
     } else if (task_type == RT_MODEL_TASK_KERNEL_EX) {
       GELOGD("Building AICPU_TF task");
       AiCpuTask *aicpu_task = nullptr;
-      bool depend_compute_flag = false;
       uint64_t singleop_kernel_id = aicpu_kernel_id++;
       GELOGI("Build singleOp TfTask, kernel_id = %lu", singleop_kernel_id);
       GE_CHK_STATUS_RET_NOLOG(
-          BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, false, depend_compute_flag, singleop_kernel_id));
+          BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, singleop_kernel_id));
       aicpu_task->SetModelArgs(model_name_, model_id_);
       ParseArgTable(aicpu_task, single_op);
       single_op.tasks_.emplace_back(aicpu_task);
@@ -391,7 +400,7 @@ void SingleOpModel::ParseArgTable(OpTask *task, SingleOp &op) {
     }
   }
 }
-
+
 Status SingleOpModel::BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task) {
   GE_CHECK_NOTNULL(task);
   auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
@@ -404,7 +413,7 @@ Status SingleOpModel::BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask *
     return ACL_ERROR_GE_INTERNAL_ERROR;
   }
 
-  auto *tbe_task = new (std::nothrow) TbeOpTask();
+  std::unique_ptr<TbeOpTask> tbe_task(new (std::nothrow) TbeOpTask());
   if (tbe_task == nullptr) {
     GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Create][TbeOpTask]failed.");
     REPORT_INNER_ERROR("E19999", "BuildKernelTask fail for new TbeOpTask.");
@@ -414,17 +423,45 @@ Status SingleOpModel::BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask *
   auto builder = TbeTaskBuilder(model_name_, iter->second, task_def);
   auto ret = builder.BuildTask(*tbe_task, model_params_);
   if (ret != SUCCESS) {
-    delete tbe_task;
-    tbe_task = nullptr;
+    GELOGE(ret, "[Build][TbeOpTask]failed.");
+    REPORT_INNER_ERROR("E19999", "[Build][TbeOpTask]failed.");
    return ret;
   }
 
-  *task = tbe_task;
+  *task = tbe_task.release();
   return SUCCESS;
 }
 
-Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task,
-                                        bool dynamic_flag, bool& depend_compute_flag, uint64_t kernel_id) {
+Status SingleOpModel::BuildAtomicTask(const domi::TaskDef &task_def, AtomicAddrCleanOpTask **task) {
+  GE_CHECK_NOTNULL(task);
+  const auto &context = task_def.kernel().context();
+  auto iter = op_list_.find(context.op_index());
+  if (iter == op_list_.end()) {
+    GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Param:TaskDef]op desc not found. op index = %u", context.op_index());
+    REPORT_INNER_ERROR("E19999", "BuildKernelTask fail for op desc not found. op index = %u", context.op_index());
+    return ACL_ERROR_GE_INTERNAL_ERROR;
+  }
+
+  std::unique_ptr<AtomicAddrCleanOpTask> atomic_task(new (std::nothrow) AtomicAddrCleanOpTask());
+  if (atomic_task == nullptr) {
+    GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Create][AtomicAddrCleanOpTask]failed.");
+    REPORT_INNER_ERROR("E19999", "BuildKernelTask fail for new AtomicAddrCleanOpTask.");
+    return ACL_ERROR_GE_MEMORY_ALLOCATION;
+  }
+
+  auto builder = AtomicAddrCleanTaskBuilder(model_name_, iter->second, task_def);
+  auto ret = builder.BuildTask(*atomic_task, model_params_);
+  if (ret != SUCCESS) {
+    GELOGE(ret, "[Build][AtomicAddrCleanOpTask]failed.");
+    REPORT_INNER_ERROR("E19999", "[Build][AtomicAddrCleanOpTask]failed.");
+    return ret;
+  }
+
+  *task = atomic_task.release();
+  return SUCCESS;
+}
+
+Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id) {
   auto iter = op_list_.find(kernel_def.op_index());
   if (iter == op_list_.end()) {
     GELOGE(ACL_ERROR_GE_INTERNAL_ERROR,
@@ -442,12 +479,11 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, AiC
     return ACL_ERROR_GE_MEMORY_ALLOCATION;
   }
   auto builder = AiCpuTaskBuilder(iter->second->GetOpDesc(), kernel_def);
-  auto ret = builder.BuildTask(*aicpu_task, model_params_, dynamic_flag, kernel_id);
+  auto ret = builder.BuildTask(*aicpu_task, model_params_, kernel_id);
   if (ret != SUCCESS) {
     GELOGE(ret, "[Build][Task] failed, kernel_id:%lu.", kernel_id);
     return ret;
   }
-  depend_compute_flag = (aicpu_task->GetUnknownType() == DEPEND_COMPUTE);
 
   *task = aicpu_task.release();
   return SUCCESS;
@@ -517,7 +553,8 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) {
   auto ge_model = model_helper_.GetGeModel();
   GE_CHECK_NOTNULL(ge_model);
   bool infer_depend_flag = false;
-  GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed.");
+  bool is_host_mem = false;
+  GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend_flag, is_host_mem), "[Check][InferDepend] failed.");
   if (infer_depend_flag) {
     // construct single_op, do single op with HybridModelExecutor
     GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor.");
@@ -526,15 +563,36 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) {
   return BuildTaskList(&resource, single_op);
 }
 
-Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, const TaskDef &task_def,
-                                           DynamicSingleOp &single_op) {
-  auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
-  const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() :
-                                                            task_def.kernel_with_handle().context();
+Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) {
+  auto ge_model = model_helper_.GetGeModel();
+  GE_CHECK_NOTNULL(ge_model);
 
-  auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
-  if (kernel_type == ccKernelType::TE) {
+  auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
+  GE_CHECK_NOTNULL(compute_graph);
+  single_op.compute_graph_ = compute_graph;
+
+  if (node_tasks_.size() != 1) {
+    GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size]Node size must be 1, but get %zu.", node_tasks_.size());
+    REPORT_INNER_ERROR("E19999", "[Check][Size]Node size must be 1, but get %zu.", node_tasks_.size());
+    return ACL_ERROR_GE_PARAM_INVALID;
+  }
+
+  auto iter = node_tasks_.begin();
+  auto node = iter->first;
+  const auto &task_defs = iter->second;
+  if (task_defs.size() <= 0 || task_defs.size() > kNumTaskWithAtomicAddrCleanTask) {
+    GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size]task_defs size must be 1 or 2, but get %zu.", task_defs.size());
+    REPORT_INNER_ERROR("E19999", "[Check][Size]task_defs size must be 1 or 2, but get %zu.", task_defs.size());
+    return ACL_ERROR_GE_PARAM_INVALID;
+  }
+
+  GE_CHECK_NOTNULL(node);
+  auto op_desc = node->GetOpDesc();
+  GE_CHECK_NOTNULL(op_desc);
+  const auto &lib_name = op_desc->GetOpKernelLibName();
+  if (lib_name == kEngineNameAiCore) {
     GELOGD("Building TBE task.");
+    const auto &task_def = task_defs.back();
     TbeOpTask *tbe_task = nullptr;
     GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task));
     tbe_task->SetModelArgs(model_name_, model_id_);
@@ -542,8 +600,16 @@ Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, cons
       GELOGD("tiling buffer is not nullptr.");
       tbe_task->stream_resource_ = stream_resource;
     }
+    if (task_defs.size() == kNumTaskWithAtomicAddrCleanTask) {
+      const auto &atomic_task_def = task_defs.front();
+      AtomicAddrCleanOpTask *atomic_task = nullptr;
+      GE_CHK_STATUS_RET_NOLOG(BuildAtomicTask(atomic_task_def, &atomic_task));
+      GE_CHK_STATUS_RET_NOLOG(atomic_task->InitAtomicAddrCleanIndices());
+      tbe_task->SetAtomicAddrCleanTask(atomic_task);
+    }
     single_op.op_task_.reset(tbe_task);
-  } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) {
+  } else if (lib_name == kEngineNameAiCpu) {
+    const auto &task_def = task_defs[0];
     GELOGD("Building AICPU_CC task");
     OpTask *task = nullptr;
     uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
@@ -551,69 +617,69 @@
     GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id));
     task->SetModelArgs(model_name_, model_id_);
     single_op.op_task_.reset(task);
-  } else {
-    GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID,
-           "[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u",
-           context.kernel_type());
-    REPORT_INNER_ERROR("E19999",
-        "BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.",
-        context.kernel_type());
-    return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID;
+  } else if (lib_name == kEngineNameAiCpuTf) {
+    const auto &task_def = task_defs[0];
+    GELOGD("Building AICPU_TF task");
+    AiCpuTask *aicpu_task = nullptr;
+    uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
+    GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id);
+    GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, dynamic_singleop_kernel_id));
+    if (aicpu_task->GetUnknownType() == DEPEND_COMPUTE) {
+      if (task_defs.size() < kNumTaskWithMemCpyTask) {
+        GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found.");
+        REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found.");
+        return ACL_ERROR_GE_PARAM_INVALID;
+      }
+      const TaskDef &copy_task_def = task_defs[1];
+      GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex()));
+    }
+    aicpu_task->SetModelArgs(model_name_, model_id_);
+    single_op.op_task_.reset(aicpu_task);
   }
+
   return SUCCESS;
 }
 
-Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) {
+Status SingleOpModel::NeedHybridModel(GeModelPtr &ge_model, bool &need_hybrid_model) {
+  bool is_infer_depend = false;
+  bool is_host_mem = false;
+  GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed.");
+  bool need_d2h_cpy = is_infer_depend && !is_host_mem;
+  need_hybrid_model = need_d2h_cpy || node_tasks_.size() > 1;
+  return SUCCESS;
+}
+
+Status SingleOpModel::ParseTasks() {
   auto ge_model = model_helper_.GetGeModel();
   GE_CHECK_NOTNULL(ge_model);
 
-  auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
-  GE_CHECK_NOTNULL(compute_graph);
-  single_op.compute_graph_ = compute_graph;
   auto tasks = ge_model->GetModelTaskDefPtr()->task();
   for (int i = 0; i < tasks.size(); ++i) {
-    const TaskDef &task_def = tasks[i];
+    TaskDef &task_def = tasks[i];
     GELOGI("[%s] Task[%d], type = [%u], DebugString = [%s]", model_name_.c_str(), i, task_def.type(),
            task_def.DebugString().c_str());
     auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
-    if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
-      if (single_op.op_task_ != nullptr) {
-        GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks.");
-        REPORT_INNER_ERROR("E19999",
-                           "BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks.");
-        return ACL_ERROR_GE_OP_TASK_TYPE_INVALID;
-      }
-      GE_CHK_STATUS_RET_NOLOG(BuildModelTaskKernel(stream_resource, task_def, single_op));
+    uint32_t op_index = 0;
+    if (task_type == RT_MODEL_TASK_KERNEL) {
+      op_index = task_def.kernel().context().op_index();
     } else if (task_type == RT_MODEL_TASK_KERNEL_EX) {
-      if (single_op.op_task_ != nullptr) {
-        GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "[Check][TaskType]Do not support dynamic op with multiple tasks.");
-        REPORT_INNER_ERROR("E19999",
-                           "BuildTaskListForDynamicOp fail for Do not support dynamic op with multiple tasks.");
-        return ACL_ERROR_GE_OP_TASK_TYPE_INVALID;
-      }
-      GELOGD("Building AICPU_TF task");
-      AiCpuTask *aicpu_task = nullptr;
-      bool depend_compute_flag = false;
-      uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++;
-      GELOGI("Build dynamic singleOp TfTask, kernel_id = %lu", dynamic_singleop_kernel_id);
-      GE_CHK_STATUS_RET_NOLOG(BuildKernelExTask(task_def.kernel_ex(), &aicpu_task, true,
-                                                depend_compute_flag, dynamic_singleop_kernel_id));
-      if (depend_compute_flag) {
-        if (i >= tasks.size() - 1) {
-          GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Task]The copy task of the fourth operator was not found.");
-          REPORT_INNER_ERROR("E19999", "The copy task of the fourth operator was not found.");
-          return ACL_ERROR_GE_PARAM_INVALID;
-        }
-        ++i;
-        const TaskDef &copy_task_def = tasks[i];
-        GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex()));
-      }
-      aicpu_task->SetModelArgs(model_name_, model_id_);
-      single_op.op_task_.reset(aicpu_task);
+      op_index = task_def.kernel_ex().op_index();
+    } else if (task_type == RT_MODEL_TASK_ALL_KERNEL) {
+      op_index = task_def.kernel_with_handle().context().op_index();
     } else {
-      // skip
       GELOGD("Skip task type: %d", static_cast<int>(task_type));
+      continue;
+    }
+    GELOGD("op_index = %u, task_type = %d", op_index, task_type);
+
+    auto iter = op_list_.find(op_index);
+    if (iter == op_list_.end()) {
+      GELOGE(INTERNAL_ERROR, "[Find][Node]Failed to get node by op_index = %u", op_index);
+      REPORT_INNER_ERROR("E19999", "Failed to get node by op_index = %u.", op_index);
+      return INTERNAL_ERROR;
     }
+    auto &node = iter->second;
+    node_tasks_[node].emplace_back(task_def);
   }
   return SUCCESS;
 }
@@ -624,6 +690,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &
   GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource));
   model_params_.memory_size = UINT64_MAX;
   model_params_.graph_is_dynamic = true;
+  GE_CHK_STATUS_RET(ParseTasks(), "[Parse][Tasks] failed.");
 
   auto ge_model = model_helper_.GetGeModel();
   GE_CHECK_NOTNULL(ge_model);
@@ -646,7 +713,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &
                                                                        device_id,
                                                                        resource.GetStream()));
     GE_CHECK_NOTNULL(single_op.hybrid_model_executor_);
-    GE_CHK_STATUS_RET(single_op.hybrid_model_executor_->Init(), "[Init][HybridModelExecutor]Failed.");
+    ThreadPool *thread_pool = nullptr;
+    GE_CHK_STATUS_RET_NOLOG(resource.GetThreadPool(&thread_pool));
+    GE_CHK_STATUS_RET(single_op.hybrid_model_executor_->Init(thread_pool), "[Init][HybridModelExecutor]Failed.");
     return SUCCESS;
   }
   return BuildTaskListForDynamicOp(&resource, single_op);
diff --git a/ge/single_op/single_op_model.h b/ge/single_op/single_op_model.h
index 529a442d..b1cd161c 100755
--- a/ge/single_op/single_op_model.h
+++ b/ge/single_op/single_op_model.h
@@ -23,7 +23,7 @@
 #include
 #include
 
-#include "common/helper/model_helper.h"
+#include "framework/common/helper/model_helper.h"
 #include "single_op/single_op.h"
 #include "single_op/stream_resource.h"
 #include "single_op/task/op_task.h"
@@ -69,17 +69,18 @@ class SingleOpModel {
   Status BuildTaskList(StreamResource *stream_resource, SingleOp &single_op);
   Status BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &dynamic_single_op);
   Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task);
-  Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task,
-                           bool dynamic_flag, bool& depend_compute_flag, uint64_t kernel_id);
+  Status BuildAtomicTask(const domi::TaskDef &task_def, AtomicAddrCleanOpTask **task);
+  Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id);
   Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id);
-  Status BuildModelTaskKernel(StreamResource *stream_resource, const domi::TaskDef &task_def,
-                              DynamicSingleOp &single_op);
 
   static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam &param);
   void ParseArgTable(OpTask *task, SingleOp &op);
   Status InitHybridModelExecutor(const StreamResource &resource, const GeModelPtr &ge_model, SingleOp &single_op);
   Status SetHostMemTensor(DynamicSingleOp &single_op);
+  Status NeedHybridModel(GeModelPtr &ge_model, bool &flag);
+  Status ParseTasks();
 
+  std::map<NodePtr, std::vector<domi::TaskDef>> node_tasks_;
   std::string model_name_;
   uint32_t model_id_ = 0;
   const void *ori_model_data_;
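The net effect of the single_op_model.cc restructuring above: ParseTasks() first buckets every kernel TaskDef under its owning node via op_index, and BuildTaskListForDynamicOp() then dispatches on the node's kernel-lib name instead of probing ccKernelType per task. A sketch of the invariant the new code relies on, using only names introduced in this diff:

    // After ParseTasks(), for a dynamic single op:
    //   node_tasks_.size() == 1 (exactly one node), and for that node either
    //   task_defs.size() == 1: the kernel task only, or
    //   task_defs.size() == 2 (kNumTaskWithAtomicAddrCleanTask): task_defs.front() is the
    //     atomic-addr-clean task and task_defs.back() the TBE kernel; on the AICPU_TF path,
    //     task_defs[1] is instead the memcpy task (kNumTaskWithMemCpyTask) required
    //     when the op is DEPEND_COMPUTE.
    auto iter = node_tasks_.begin();
    const NodePtr &node = iter->first;
    const std::vector<domi::TaskDef> &task_defs = iter->second;
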
diff --git a/ge/single_op/stream_resource.cc b/ge/single_op/stream_resource.cc
index 9fe8f26a..10a8f72b 100755
--- a/ge/single_op/stream_resource.cc
+++ b/ge/single_op/stream_resource.cc
@@ -25,6 +25,7 @@ namespace ge {
 namespace {
 // limit available device mem size 1M
 const uint32_t kFuzzDeviceBufferSize = 1 * 1024 * 1024;
+constexpr int kDefaultThreadNum = 4;
 }
 
 StreamResource::StreamResource(uintptr_t resource_id) : resource_id_(resource_id) {
@@ -219,6 +220,16 @@ Status StreamResource::BuildOperator(const ModelData &model_data, SingleOp **sin
   return SUCCESS;
 }
 
+Status StreamResource::GetThreadPool(ThreadPool **thread_pool) {
+  GE_CHECK_NOTNULL(thread_pool);
+  if (thread_pool_ == nullptr) {
+    thread_pool_.reset(new (std::nothrow) ThreadPool(kDefaultThreadNum));
+    GE_CHECK_NOTNULL(thread_pool_);
+  }
+  *thread_pool = thread_pool_.get();
+  return SUCCESS;
+}
+
 const uint8_t *StreamResource::GetMemoryBase() const {
   if (memory_list_.empty()) {
     return nullptr;
diff --git a/ge/single_op/stream_resource.h b/ge/single_op/stream_resource.h
index aecb38c8..f1e1bebb 100755
--- a/ge/single_op/stream_resource.h
+++ b/ge/single_op/stream_resource.h
@@ -23,7 +23,7 @@
 #include
 #include
 
-#include "common/ge_inner_error_codes.h"
+#include "framework/common/ge_inner_error_codes.h"
 #include "runtime/stream.h"
 #include "single_op/single_op.h"
@@ -54,6 +54,8 @@ class StreamResource {
     return device_buffer_;
   }
 
+  Status GetThreadPool(ThreadPool **thread_pool);
+
  private:
   uint8_t *DoMallocMemory(const std::string &purpose,
                           size_t size,
@@ -66,6 +68,7 @@ class StreamResource {
   std::vector<uint8_t *> weight_list_;
   std::unordered_map<std::string, std::unique_ptr<SingleOp>> op_map_;
   std::unordered_map<std::string, std::unique_ptr<DynamicSingleOp>> dynamic_op_map_;
+  std::unique_ptr<ThreadPool> thread_pool_;
   rtStream_t stream_ = nullptr;
   std::mutex mu_;
   std::mutex stream_mu_;
diff --git a/ge/single_op/task/aicpu_kernel_task_builder.cc b/ge/single_op/task/aicpu_kernel_task_builder.cc
index 18f13691..2f0856bf 100755
--- a/ge/single_op/task/aicpu_kernel_task_builder.cc
+++ b/ge/single_op/task/aicpu_kernel_task_builder.cc
@@ -17,7 +17,7 @@
 #include "single_op/task/aicpu_kernel_task_builder.h"
 #include "framework/common/taskdown_common.h"
 #include "graph/load/model_manager/model_manager.h"
-#include "build_task_utils.h"
+#include "single_op/task/build_task_utils.h"
 
 namespace ge {
 AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def)
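GetThreadPool() above creates one pool per StreamResource on first use and keeps ownership in thread_pool_, so repeated dynamic-op builds on the same stream share the same kDefaultThreadNum (4) worker threads. The consuming call site, quoted from the BuildDynamicOp hunk earlier in this patch:

    ThreadPool *thread_pool = nullptr;
    GE_CHK_STATUS_RET_NOLOG(resource.GetThreadPool(&thread_pool));  // lazily constructed, owned by the resource
    GE_CHK_STATUS_RET(single_op.hybrid_model_executor_->Init(thread_pool), "[Init][HybridModelExecutor]Failed.");
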
diff --git a/ge/single_op/task/aicpu_task_builder.cc b/ge/single_op/task/aicpu_task_builder.cc
index 805b1306..1b945280 100755
--- a/ge/single_op/task/aicpu_task_builder.cc
+++ b/ge/single_op/task/aicpu_task_builder.cc
@@ -63,7 +63,7 @@ namespace ge {
     return SUCCESS;
   }
 
-  Status AiCpuTaskBuilder::InitWorkspaceAndIO(AiCpuTask &task, const SingleOpModelParam &param, bool dynamic_flag) {
+  Status AiCpuTaskBuilder::InitWorkspaceAndIO(AiCpuTask &task, const SingleOpModelParam &param) {
     if (kernel_def_.args_size() > sizeof(STR_FWK_OP_KERNEL)) {
       GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size]sizeof STR_FWK_OP_KERNEL is: %lu, but args_size is: %d",
              sizeof(STR_FWK_OP_KERNEL), kernel_def_.args_size());
@@ -83,9 +83,8 @@ namespace ge {
     return SUCCESS;
   }
 
-  Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam &param,
-                                     bool dynamic_flag, uint64_t kernel_id) {
-    GE_CHK_STATUS_RET_NOLOG(InitWorkspaceAndIO(task, param, dynamic_flag));
+  Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam &param, uint64_t kernel_id) {
+    GE_CHK_STATUS_RET_NOLOG(InitWorkspaceAndIO(task, param));
     STR_FWK_OP_KERNEL fwk_op_kernel = {0};
     auto ret = SetFmkOpKernel(task.io_addr_, task.workspace_addr_, fwk_op_kernel);
@@ -124,7 +123,6 @@ namespace ge {
     task.arg_size_ = sizeof(STR_FWK_OP_KERNEL);
     task.op_type_ = op_desc_->GetName();
     task.task_info_ = kernel_def_.task_info();
-    task.dynamic_flag_ = dynamic_flag;
     task.kernel_id_ = kernel_id;
 
     auto debug_info = BuildTaskUtils::GetTaskInfo(op_desc_);
diff --git a/ge/single_op/task/aicpu_task_builder.h b/ge/single_op/task/aicpu_task_builder.h
index fe9c9bc2..eca91254 100755
--- a/ge/single_op/task/aicpu_task_builder.h
+++ b/ge/single_op/task/aicpu_task_builder.h
@@ -29,12 +29,12 @@ namespace ge {
     AiCpuTaskBuilder(const OpDescPtr &op_desc, const domi::KernelExDef &kernel_def);
     ~AiCpuTaskBuilder() = default;
 
-    Status BuildTask(AiCpuTask &task, const SingleOpModelParam &param, bool dynamic_flag, uint64_t kernel_id);
+    Status BuildTask(AiCpuTask &task, const SingleOpModelParam &param, uint64_t kernel_id);
 
    private:
     static Status SetKernelArgs(void **args, STR_FWK_OP_KERNEL &kernel);
     Status SetFmkOpKernel(void *io_addr, void *ws_addr, STR_FWK_OP_KERNEL &kernel);
-    Status InitWorkspaceAndIO(AiCpuTask &task, const SingleOpModelParam &param, bool dynamic_flag);
+    Status InitWorkspaceAndIO(AiCpuTask &task, const SingleOpModelParam &param);
 
     const OpDescPtr op_desc_;
     const domi::KernelExDef &kernel_def_;
outputs.end(), [&](const DataBuffer &db) { output_addrs.push_back(db.data); }); + } + return InnerGetTaskInfo(op_desc, input_addrs, output_addrs); +} + +std::string BuildTaskUtils::GetTaskInfo(const hybrid::TaskContext &task_context) { + auto &node_item = task_context.GetNodeItem(); + auto op_desc = node_item.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return ""); + vector input_addrs; + vector output_addrs; + if (op_desc->GetAllInputsSize() == static_cast(task_context.NumInputs())) { + for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { + input_addrs.push_back(task_context.GetInput(i)->GetData()); + } + } + if (op_desc->GetOutputsSize() == static_cast(task_context.NumOutputs())) { + for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { + output_addrs.push_back(task_context.GetOutput(i)->GetData()); + } + } + return InnerGetTaskInfo(op_desc, input_addrs, output_addrs); +} } // namespace ge diff --git a/ge/single_op/task/build_task_utils.h b/ge/single_op/task/build_task_utils.h index 7a2369e4..68894f5b 100644 --- a/ge/single_op/task/build_task_utils.h +++ b/ge/single_op/task/build_task_utils.h @@ -23,6 +23,7 @@ #include "graph/op_desc.h" #include "single_op/single_op.h" #include "single_op/single_op_model.h" +#include "hybrid/node_executor/task_context.h" namespace ge { class BuildTaskUtils { @@ -35,7 +36,14 @@ class BuildTaskUtils { bool keep_workspace = true); static std::vector JoinAddresses(const std::vector> &addresses); static std::vector GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); + static std::string InnerGetTaskInfo(const OpDescPtr &op_desc, + const std::vector &input_addrs, + const std::vector &output_addrs); static std::string GetTaskInfo(const OpDescPtr &op_desc); + static std::string GetTaskInfo(const OpDescPtr &op_desc, + const std::vector &inputs, + const std::vector &outputs); + static std::string GetTaskInfo(const hybrid::TaskContext& task_context); template static std::string VectorToString(const std::vector &values) { std::stringstream ss; diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index e48677f8..83cb0529 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -27,9 +27,8 @@ #include "common/formats/formats.h" #include "common/math/math_util.h" #include "framework/common/debug/log.h" -#include "register/op_tiling.h" #include "runtime/rt.h" -#include "build_task_utils.h" +#include "single_op/task/build_task_utils.h" namespace ge { namespace { @@ -89,6 +88,7 @@ Status OpTask::OpenDump(rtStream_t stream) { void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) { this->stub_name_ = name; this->stub_func_ = stub_func; + this->task_name_ = name; } void TbeOpTask::SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim, @@ -221,21 +221,27 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) { return SUCCESS; } -Status TbeOpTask::UpdateRunInfo() { - // invoke OpParaCalculate - GELOGD("Start to invoke OpParaCalculate."); - optiling::OpRunInfo run_info; - run_info.block_dim = 0; - auto ret = optiling::OpParaCalculate(*node_, run_info); +Status TbeOpTask::CalcTilingInfo(optiling::utils::OpRunInfo &run_info) { + auto ret = optiling::OpParaCalculateV2(*node_, run_info); if (ret != GRAPH_SUCCESS) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Invoke][OpParaCalculate] failed, ret = %u.", ret); REPORT_INNER_ERROR("E19999", "invoke OpParaCalculate failed, ret = %u.", ret); return ACL_ERROR_GE_INTERNAL_ERROR; } - block_dim_ = run_info.block_dim; - tiling_data_ = 
diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc
index e48677f8..83cb0529 100755
--- a/ge/single_op/task/op_task.cc
+++ b/ge/single_op/task/op_task.cc
@@ -27,9 +27,8 @@
 #include "common/formats/formats.h"
 #include "common/math/math_util.h"
 #include "framework/common/debug/log.h"
-#include "register/op_tiling.h"
 #include "runtime/rt.h"
-#include "build_task_utils.h"
+#include "single_op/task/build_task_utils.h"
 
 namespace ge {
 namespace {
@@ -89,6 +88,7 @@ Status OpTask::OpenDump(rtStream_t stream) {
 void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) {
   this->stub_name_ = name;
   this->stub_func_ = stub_func;
+  this->task_name_ = name;
 }
 
 void TbeOpTask::SetKernelArgs(std::unique_ptr<uint8_t[]> &&args, size_t arg_size, uint32_t block_dim,
@@ -221,21 +221,27 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) {
   return SUCCESS;
 }
 
-Status TbeOpTask::UpdateRunInfo() {
-  // invoke OpParaCalculate
-  GELOGD("Start to invoke OpParaCalculate.");
-  optiling::OpRunInfo run_info;
-  run_info.block_dim = 0;
-  auto ret = optiling::OpParaCalculate(*node_, run_info);
+Status TbeOpTask::CalcTilingInfo(optiling::utils::OpRunInfo &run_info) {
+  auto ret = optiling::OpParaCalculateV2(*node_, run_info);
   if (ret != GRAPH_SUCCESS) {
     GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Invoke][OpParaCalculate] failed, ret = %u.", ret);
     REPORT_INNER_ERROR("E19999", "invoke OpParaCalculate failed, ret = %u.", ret);
     return ACL_ERROR_GE_INTERNAL_ERROR;
   }
-  block_dim_ = run_info.block_dim;
-  tiling_data_ = run_info.tiling_data.str();
-  tiling_key_ = run_info.tiling_key;
-  run_info_workspaces_ = run_info.workspaces;
+  return SUCCESS;
+}
+
+Status TbeOpTask::UpdateRunInfo() {
+  // invoke OpParaCalculate
+  GELOGD("Start to invoke OpParaCalculate.");
+  optiling::utils::OpRunInfo run_info(0, true, 0);
+  GE_CHK_STATUS_RET(CalcTilingInfo(run_info), "[Calc][TilingInfo]failed.");
+
+  block_dim_ = run_info.GetBlockDim();
+  tiling_data_ = run_info.GetAllTilingData().str();
+  tiling_key_ = run_info.GetTilingKey();
+  clear_atomic_ = run_info.GetClearAtomic();
+  run_info.GetAllWorkspaces(run_info_workspaces_);
   GELOGD("Done invoking OpParaCalculate successfully. block_dim = %u, tiling size = %zu, tiling_key = %u", block_dim_,
          tiling_data_.size(), tiling_key_);
   return SUCCESS;
@@ -262,7 +268,6 @@ Status TbeOpTask::UpdateTensorDesc(const GeTensorDesc &src_tensor, GeTensorDesc
     dst_tensor.SetShape(GeShape(std::move(storage_shape)));
     dst_tensor.SetOriginShape(src_tensor.GetShape());
   }
-
   return SUCCESS;
 }
@@ -346,49 +351,108 @@ Status TbeOpTask::AllocateWorkspaces(const vector<int64_t> &workspace_sizes) {
   return SUCCESS;
 }
 
-Status TbeOpTask::LaunchKernel(const vector<GeTensorDesc> &input_desc,
-                               const vector<DataBuffer> &input_buffers,
-                               vector<GeTensorDesc> &output_desc,
-                               vector<DataBuffer> &output_buffers,
-                               rtStream_t stream) {
-  GELOGD("[%s] Start to launch kernel", node_->GetName().c_str());
-  GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc));
-  GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo());
-  GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed.");
-  std::vector<void *> args;
-  for (auto &buffer : input_buffers) {
-    args.emplace_back(buffer.data);
+Status TbeOpTask::CheckAndExecuteAtomic(const vector<GeTensorDesc> &input_desc,
+                                        const vector<DataBuffer> &input_buffers,
+                                        vector<GeTensorDesc> &output_desc,
+                                        vector<DataBuffer> &output_buffers,
+                                        rtStream_t stream) {
+  if (clear_atomic_ && atomic_task_ != nullptr) {
+    return atomic_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream);
   }
-  for (auto &buffer : output_buffers) {
-    args.emplace_back(buffer.data);
+  return SUCCESS;
+}
+
+Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) {
+  size_t args_size = input_num_ + output_num_ + workspaces_.size();
+  if (tiling_buffer_ != nullptr) {
+    args_size++;
+  }
+  size_t temp_size = args_size * sizeof(void *);
+  if (arg_size_ < temp_size) {
+    GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size);
+    std::unique_ptr<uint8_t[]> args(new (std::nothrow) uint8_t[temp_size]());
+    GE_CHECK_NOTNULL(args);
+    if (memcpy_s(args.get(), temp_size, args_.get(), arg_size_) != EOK) {
+      GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str());
+      REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str());
+      return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
+    }
+
+    args_ = std::move(args);
+    arg_size_ = temp_size;
   }
-  for (auto &buffer : workspaces_) {
-    args.emplace_back(buffer);
+
+  uintptr_t *arg_base = reinterpret_cast<uintptr_t *>(args_.get());
+  size_t arg_index = input_num_ + output_num_;
+  for (size_t i = 0; i < workspaces_.size(); ++i) {
+    arg_base[arg_index++] = reinterpret_cast<uintptr_t>(workspaces_[i]);
   }
 
   if (tiling_buffer_ != nullptr) {
     GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size());
     GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(),
                                 RT_MEMCPY_HOST_TO_DEVICE_EX, stream));
+    arg_base[arg_index] = reinterpret_cast<uintptr_t>(tiling_buffer_);
+  }
+
+  return SUCCESS;
+}
+
+Status TbeOpTask::SetArgIndex() {
+  const vector<bool> v_is_input_const = op_desc_->GetIsInputConst();
+  size_t input_index = 0;
+  for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) {
+    const GeTensorDescPtr tensor_desc = op_desc_->MutableInputDesc(static_cast<uint32_t>(i));
+    if (tensor_desc == nullptr) {
+      GELOGD("SingleOp: %s, Index: %zu, has no input", op_desc_->GetName().c_str(), i);
+      continue;
+    }
+    if (i < v_is_input_const.size() && v_is_input_const[i]) {
+      GELOGD("SingleOp: %s, Index: %zu, input is const", op_desc_->GetName().c_str(), i);
+      input_index++;
+      continue;
+    }
+    arg_index_.emplace_back(input_index);
+    input_index++;
+  }
+  return SUCCESS;
+}
 
-    args.emplace_back(tiling_buffer_);
+Status TbeOpTask::UpdateIoAddr(const vector<DataBuffer> &inputs, const vector<DataBuffer> &outputs) {
+  if (arg_index_.size() != inputs.size()) {
+    GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Args size is %zu, but get input size is %zu.",
+           arg_index_.size(), inputs.size());
+    REPORT_INNER_ERROR("E19999", "[Check][Size] Args size is %zu, but get input size is %zu.",
+                       arg_index_.size(), inputs.size());
+    return ACL_ERROR_GE_PARAM_INVALID;
   }
-  GELOGD("Dst size is %zu, src size is %zu.", arg_size_, args.size() * sizeof(void *));
-  // node with workspace: build can not get size of workspace, need to update arg_size_ when execute
-  if (arg_size_ < (args.size() * sizeof(void *))) {
-    size_t temp_size = args.size() * sizeof(void *);
-    GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size);
-    args_.reset(new(std::nothrow) uint8_t[temp_size]());
-    GE_CHECK_NOTNULL(args_);
-    arg_size_ = temp_size;
+  uintptr_t *arg_base = reinterpret_cast<uintptr_t *>(args_.get());
+  for (size_t i = 0; i < arg_index_.size(); ++i) {
+    arg_base[arg_index_[i]] = reinterpret_cast<uintptr_t>(inputs[i].data);
   }
-  if (memcpy_s(args_.get(), arg_size_, args.data(), args.size() * sizeof(void *)) != EOK) {
-    GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str());
-    REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str());
-    return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
+
+  for (size_t i = 0; i < op_desc_->GetOutputsSize(); ++i) {
+    arg_base[input_num_ + i] = reinterpret_cast<uintptr_t>(outputs[i].data);
   }
+  return SUCCESS;
+}
+
+Status TbeOpTask::LaunchKernel(const vector<GeTensorDesc> &input_desc,
+                               const vector<DataBuffer> &input_buffers,
+                               vector<GeTensorDesc> &output_desc,
+                               vector<DataBuffer> &output_buffers,
+                               rtStream_t stream) {
+  GELOGD("[%s] Start to launch kernel", node_->GetName().c_str());
+  GE_CHK_STATUS_RET(UpdateIoAddr(input_buffers, output_buffers), "[Update][IoAddr] failed.");
+  GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc));
+  GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo());
+  GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed.");
+  GE_CHK_STATUS_RET(CheckAndExecuteAtomic(input_desc, input_buffers, output_desc, output_buffers, stream),
+                    "[Execute][AtomicTask] failed.");
+  GE_CHK_STATUS_RET(UpdateTilingArgs(stream), "[Update][TilingArgs] failed.");
+
   GELOGD("[%s] Start to invoke rtKernelLaunch", node_->GetName().c_str());
   GE_CHK_STATUS_RET(DoLaunchKernel(stream), "Failed to do launch kernel.");
 
&arg_count) { } } +Status AtomicAddrCleanOpTask::UpdateNodeByShape(const vector &input_desc, + const vector &output_desc) { + return SUCCESS; +} + +Status AtomicAddrCleanOpTask::UpdateIoAddr(const vector &inputs, const vector &outputs) { + uintptr_t *arg_base = reinterpret_cast(args_.get()); + for (auto atomic_output_index : atomic_output_indices_) { + if (atomic_output_index >= static_cast(outputs.size())) { + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Update][Args] failed, atomic index must smaller then data size."); + REPORT_INNER_ERROR("E19999", "[Update][Args] failed, atomic index must smaller then data size."); + return ACL_ERROR_GE_PARAM_INVALID; + } + auto &output_buffer = outputs[atomic_output_index]; + *arg_base++ = reinterpret_cast(output_buffer.data); + + auto tensor_desc = op_desc_->MutableOutputDesc(atomic_output_index); + int64_t size = 0; + graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, size); + if (graph_status != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get tensor size in bytes failed!"); + GELOGE(graph_status, "[Get][TensorMemorySize] In Bytes failed!"); + return FAILED; + } + TensorUtils::SetSize(*tensor_desc, size); + } + return SUCCESS; +} + +Status AtomicAddrCleanOpTask::UpdateTilingArgs(rtStream_t stream) { + if (tiling_buffer_ != nullptr) { + GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); + GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(), + RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); + uintptr_t *arg_base = reinterpret_cast(args_.get()); + size_t idx = atomic_output_indices_.size(); + arg_base[idx] = reinterpret_cast(tiling_buffer_); + } + return SUCCESS; +} + +Status AtomicAddrCleanOpTask::CalcTilingInfo(optiling::utils::OpRunInfo &run_info) { + auto ret = optiling::OpAtomicCalculateV2(*node_, run_info); + if (ret != GRAPH_SUCCESS) { + GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Invoke][OpAtomicCalculate] failed, ret = %u.", ret); + REPORT_INNER_ERROR("E19999", "invoke OpAtomicCalculate failed, ret = %u.", ret); + return ACL_ERROR_GE_INTERNAL_ERROR; + } + return SUCCESS; +} + +Status AtomicAddrCleanOpTask::InitAtomicAddrCleanIndices() { + GELOGD("[%s] Start to setup AtomicAddrClean task.", op_desc_->GetName().c_str()); + std::vector atomic_output_indices; + (void) ge::AttrUtils::GetListInt(op_desc_, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); + if (atomic_output_indices.empty()) { + GELOGE(INTERNAL_ERROR, "[Check][Size][%s] atomic_output_indices must not be empty.", op_desc_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "[%s] atomic_output_indices must not be empty.", op_desc_->GetName().c_str()); + return INTERNAL_ERROR; + } + + size_t max_arg_size = tiling_buffer_ == nullptr ? arg_size_ : arg_size_ - 1; + if (atomic_output_indices.size() > max_arg_size) { + GELOGE(INTERNAL_ERROR, "[Check][Size][%s] atomic_output_indices invalid. atomic_output_indices size is %zu," + "arg size is %zu.", op_desc_->GetName().c_str(), atomic_output_indices.size(), arg_size_); + REPORT_INNER_ERROR("E19999", "[%s] atomic_output_indices invalid. 
atomic_output_indices size is %zu," + "arg size is %zu.", op_desc_->GetName().c_str(), atomic_output_indices.size(), arg_size_); + return INTERNAL_ERROR; + } + + for (auto output_index : atomic_output_indices) { + GELOGD("[%s] Adding output index [%ld]", op_desc_->GetName().c_str(), output_index); + GE_CHECK_GE(output_index, 0); + GE_CHECK_LE(output_index, INT32_MAX); + atomic_output_indices_.emplace_back(static_cast(output_index)); + } + return SUCCESS; +} + AiCpuBaseTask::~AiCpuBaseTask() { if (ext_info_addr_dev_ != nullptr) { (void)rtFree(ext_info_addr_dev_); } + if (rt_event_ != nullptr) { + (void)rtEventDestroy(rt_event_); + } +} + +Status AiCpuBaseTask::UpdateEventIdForBlockingAicpuOp() { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process"); + return SUCCESS; + } + uint32_t event_id = 0; + auto rt_ret = rtEventCreateWithFlag(&rt_event_, RT_EVENT_WITH_FLAG); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventCreateWithFlag failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtEventCreateWithFlag] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtGetEventID(rt_event_, &event_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetEventID failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetEventID] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + if (aicpu_ext_handle_->UpdateEventId(event_id) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update event id=%u failed.", event_id); + GELOGE(FAILED, "[Update][EventId] Update event id failed", event_id); + return FAILED; + } + GELOGI("Update event_id=%u success", event_id); + return SUCCESS; } Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint64_t kernel_id) { @@ -434,6 +612,9 @@ Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint GELOGD("Get unknown_type is %d.", unknown_shape_type_val); unknown_type_ = static_cast(unknown_shape_type_val); + AttrUtils::GetBool(op_desc_, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc_->GetName().c_str(), is_blocking_aicpu_op_); + aicpu_ext_handle_.reset(new(std::nothrow) ::ge::hybrid::AicpuExtInfoHandler(op_desc_->GetName(), num_inputs_, num_outputs_, @@ -451,7 +632,13 @@ Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateSessionInfo(ULLONG_MAX, kernel_id, false), "[Update][SessionInfo] failed."); - GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateExecuteMode(true), "[Update][ExecuteMode] failed."); + + if (is_blocking_aicpu_op_) { + if (UpdateEventIdForBlockingAicpuOp() != SUCCESS) { + GELOGE(FAILED, "[Call][UpdateEventIdForBlockingAicpuOp] Call UpdateEventIdForBlockingAicpuOp failed"); + return FAILED; + } + } GE_CHK_RT_RET(rtMalloc(&ext_info_addr_dev_, aicpu_ext_handle_->GetExtInfoLen(), RT_MEMORY_HBM)); GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_, aicpu_ext_handle_->GetExtInfoLen(), @@ -604,28 +791,91 @@ Status AiCpuBaseTask::UpdateIoAddr(const vector &inputs, const vecto GE_CHK_BOOL_RET_STATUS(non_const_index < inputs.size(), ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Input size is %zu, but get non_const_index is %zu", inputs.size(), 
@@ -604,28 +791,91 @@ Status AiCpuBaseTask::UpdateIoAddr(const vector<DataBuffer> &inputs, const vecto
     GE_CHK_BOOL_RET_STATUS(non_const_index < inputs.size(), ACL_ERROR_GE_PARAM_INVALID,
                            "[Check][Size] Input size is %zu, but get non_const_index is %zu",
                            inputs.size(), non_const_index);
     auto addr = inputs[non_const_index].data;
-    GE_CHECK_NOTNULL(addr);
-    GELOGD("AICpuTask input[%zu] addr = %p", input_index, addr);
+    uint64_t length = inputs[non_const_index].length;
+    if (length != 0 && addr == nullptr) {
+      GELOGE(PARAM_INVALID, "[Check][Addr]AiCpuTask input[%zu] addr is nullptr, length = %lu", input_index, length);
+      return PARAM_INVALID;
+    }
+    GELOGD("AICpuTask input[%zu] addr = %p, length = %lu.", input_index, addr, length);
     *arg_base++ = reinterpret_cast<uintptr_t>(addr);
     non_const_index++;
   }
 
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto addr = outputs[i].data;
-    GE_CHECK_NOTNULL(addr);
-    GELOGD("AICpuTask output[%zu] addr = %p", i, addr);
+    uint64_t length = outputs[i].length;
+    if (length != 0 && addr == nullptr) {
+      GELOGE(PARAM_INVALID, "[Check][Addr]AiCpuTask output[%zu] addr is nullptr, length = %lu", i, length);
+      return PARAM_INVALID;
+    }
+    GELOGD("AICpuTask output[%zu] addr = %p, length = %lu.", i, addr, length);
     *arg_base++ = reinterpret_cast<uintptr_t>(addr);
   }
 
   return SUCCESS;
 }
 
+Status AiCpuBaseTask::CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support) {
+  int32_t device_id = 0;
+  auto rt_ret = rtGetDevice(&device_id);
+  if (rt_ret != RT_ERROR_NONE) {
+    REPORT_CALL_ERROR("E19999", "Call rtGetDevice failed, ret:0x%X", rt_ret);
+    GELOGE(RT_FAILED, "[Call][rtGetDevice] failed, ret:0x%X", rt_ret);
+    return RT_ERROR_TO_GE_STATUS(rt_ret);
+  }
+  int32_t value = 0;
+  rt_ret = rtGetDeviceCapability(device_id, FEATURE_TYPE_BLOCKING_OPERATOR, RT_MODULE_TYPE_AICPU, &value);
+  if (rt_ret != RT_ERROR_NONE) {
+    REPORT_CALL_ERROR("E19999", "Call rtGetDeviceCapability failed, ret:0x%X", rt_ret);
+    GELOGE(RT_FAILED, "[Call][rtGetDeviceCapability] failed, ret:0x%X", rt_ret);
+    return RT_ERROR_TO_GE_STATUS(rt_ret);
+  }
+  if (value != RT_AICPU_BLOCKING_OP_NOT_SUPPORT && value != RT_AICPU_BLOCKING_OP_SUPPORT) {
+    REPORT_INNER_ERROR("E19999", "Value should be %d or %d but %d",
+                       RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value);
+    GELOGE(FAILED, "[Check][Value] Value should be %d or %d but %d",
+           RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value);
+    return FAILED;
+  }
+  is_support = (value == RT_AICPU_BLOCKING_OP_SUPPORT ? true : false);
+  return SUCCESS;
+}
+
+Status AiCpuBaseTask::DistributeWaitTaskForAicpuBlockingOp(rtStream_t stream) {
+  bool is_support = false;
+  if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) {
+    GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed");
+    return FAILED;
+  }
+  if (!is_support) {
+    GELOGD("Device not support blocking aicpu op process.");
+    return SUCCESS;
+  }
+  GELOGI("Distribute queue task begin");
+  if (rt_event_ == nullptr) {
+    REPORT_INNER_ERROR("E19999", "rt_event_ is nullptr");
+    GELOGE(FAILED, "[Check][rt_event_] rt_event_ is nullptr");
+    return FAILED;
+  }
+  auto rt_ret = rtStreamWaitEvent(stream, rt_event_);
+  if (rt_ret != RT_ERROR_NONE) {
+    REPORT_CALL_ERROR("E19999", "Call rtStreamWaitEvent failed, ret:0x%X", rt_ret);
+    GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret);
+    return RT_ERROR_TO_GE_STATUS(rt_ret);
+  }
+  rt_ret = rtEventReset(rt_event_, stream);
+  if (rt_ret != RT_ERROR_NONE) {
+    REPORT_CALL_ERROR("E19999", "Call rtEventReset failed, ret:0x%X", rt_ret);
+    GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret);
+    return RT_ERROR_TO_GE_STATUS(rt_ret);
+  }
+  return SUCCESS;
+}
+
 AiCpuTask::~AiCpuTask() {
   FreeHbm(args_);
   FreeHbm(io_addr_);
-  if (dynamic_flag_) {
-    FreeHbm(workspace_addr_);
-  }
+  FreeHbm(workspace_addr_);
   FreeHbm(copy_workspace_buf_);
   FreeHbm(copy_ioaddr_dev_);
   FreeHbm(copy_input_release_flag_dev_);
@@ -665,6 +915,14 @@ Status AiCpuTask::LaunchKernel(rtStream_t stream) {
   GELOGI("[TASK_INFO] %lu/%s", kernel_id_, op_type_.c_str());
   GELOGD("Done launch kernel successfully. task = %s", this->op_type_.c_str());
+
+  if (is_blocking_aicpu_op_) {
+    if (DistributeWaitTaskForAicpuBlockingOp(stream) != SUCCESS) {
+      GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed");
+      return FAILED;
+    }
+  }
+
   return SUCCESS;
 }
@@ -941,6 +1199,13 @@ Status AiCpuCCTask::LaunchKernel(rtStream_t stream) {
   }
   GELOGI("[TASK_INFO] %lu/%s", kernel_id_, op_type_.c_str());
   GELOGD("Invoke rtCpuKernelLaunch succeeded");
+
+  if (is_blocking_aicpu_op_) {
+    if (DistributeWaitTaskForAicpuBlockingOp(stream) != SUCCESS) {
+      GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed");
+      return FAILED;
+    }
+  }
   return SUCCESS;
 }
diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h
index ed6cf40f..adf51dba 100644
--- a/ge/single_op/task/op_task.h
+++ b/ge/single_op/task/op_task.h
@@ -23,7 +23,7 @@
 #include "common/dump/dump_op.h"
 #include "common/dump/dump_properties.h"
-#include "common/ge_inner_error_codes.h"
+#include "framework/common/ge_inner_error_codes.h"
 #include "graph/op_kernel_bin.h"
 #include "runtime/stream.h"
 #include "graph/node.h"
@@ -33,6 +33,10 @@
 #include "register/op_tiling.h"
 
 namespace ge {
+namespace {
+const int kAddressNum = 2;
+}  // namespace
+
 class StreamResource;
 struct SingleOpModelParam;
 class OpTask {
@@ -44,6 +48,7 @@ class OpTask {
   virtual Status UpdateArgTable(const SingleOpModelParam &param);
   void SetModelArgs(std::string model_name, uint32_t model_id);
   Status GetProfilingArgs(TaskDescInfo &task_desc_info, uint32_t &model_id);
+  const std::string &GetTaskName() const {return task_name_;}
   void SetOpDesc(const OpDescPtr &op_desc) {
     op_desc_ = op_desc;
   }
@@ -66,6 +71,7 @@ class OpTask {
   std::string model_name_;
   uint32_t model_id_ = 0;
   uint32_t block_dim_ = 1;
+  std::string task_name_;
 };
 
 class TbeOpTask : public OpTask {
@@ -83,8 +89,10 @@ class TbeOpTask : public OpTask {
   void SetKernelArgs(std::unique_ptr<uint8_t[]> &&args, size_t arg_size, uint32_t block_dim, const OpDescPtr &op_desc);
   void SetKernelWithHandleArgs(std::unique_ptr<uint8_t[]> &&args, size_t arg_size, uint32_t block_dim,
                                const OpDescPtr &op_desc, const domi::KernelDefWithHandle& kernel_def_with_handle);
+  void SetAtomicAddrCleanTask(OpTask *task) { atomic_task_.reset(task); }
 
   Status UpdateRunInfo() override;
+  Status SetArgIndex();
 
   const void *GetArgs() const;
   size_t GetArgSize() const;
@@ -93,33 +101,63 @@ class TbeOpTask : public OpTask {
   const std::string &GetTaskType() const override;
   void SetHandle(void *handle);
 
+ protected:
+  NodePtr node_;
+  std::unique_ptr<uint8_t[]> args_;
+  size_t arg_size_ = 0;
+  void *tiling_buffer_ = nullptr;
+  uint32_t max_tiling_size_ = 0;
+  std::string tiling_data_;
+  size_t input_num_;  // include const input
+  size_t output_num_;
+
  private:
   friend class SingleOpModel;
   friend class TbeTaskBuilder;
   static Status UpdateTensorDesc(const GeTensorDesc &src_tensor, GeTensorDesc &dst_tensor);
-  Status UpdateNodeByShape(const vector<GeTensorDesc> &input_desc,
-                           const vector<GeTensorDesc> &output_desc);
   Status AllocateWorkspaces(const std::vector<int64_t> &workspace_sizes);
   Status DoLaunchKernel(rtStream_t stream);
+  Status CheckAndExecuteAtomic(const vector<GeTensorDesc> &input_desc,
+                               const vector<DataBuffer> &input_buffers,
+                               vector<GeTensorDesc> &output_desc,
+                               vector<DataBuffer> &output_buffers,
+                               rtStream_t stream);
+  virtual Status UpdateNodeByShape(const vector<GeTensorDesc> &input_desc,
+                                   const vector<GeTensorDesc> &output_desc);
+  virtual Status UpdateTilingArgs(rtStream_t stream);
+  virtual Status UpdateIoAddr(const vector<DataBuffer> &inputs, const vector<DataBuffer> &outputs);
+  virtual Status CalcTilingInfo(optiling::utils::OpRunInfo &run_info);
 
   const void *stub_func_ = nullptr;
-  std::unique_ptr<uint8_t[]> args_;
-  size_t arg_size_ = 0;
   void *sm_desc_ = nullptr;
   std::string stub_name_;
-
   StreamResource *stream_resource_ = nullptr;
-  void *tiling_buffer_ = nullptr;
-  uint32_t max_tiling_size_ = 0;
-  std::string tiling_data_;
+
   std::vector<int64_t> run_info_workspaces_;
   std::vector<void *> workspaces_;
-  NodePtr node_;
   uint32_t tiling_key_ = 0;
+  bool clear_atomic_ = false;
   void* handle_ = nullptr;
   std::string original_kernel_key_;
   std::string node_info_;
+  std::vector<size_t> arg_index_;  // data index in args
+
+  std::unique_ptr<OpTask> atomic_task_;
+};
+
+class AtomicAddrCleanOpTask : public TbeOpTask {
+ public:
+  Status InitAtomicAddrCleanIndices();
+
+ private:
+  Status UpdateNodeByShape(const vector<GeTensorDesc> &input_desc,
+                           const vector<GeTensorDesc> &output_desc) override;
+  Status UpdateIoAddr(const vector<DataBuffer> &inputs, const vector<DataBuffer> &outputs) override;
+  Status UpdateTilingArgs(rtStream_t stream) override;
+  Status CalcTilingInfo(optiling::utils::OpRunInfo &run_info) override;
+
+  std::vector<int> atomic_output_indices_;
+};
 
 class AiCpuBaseTask : public OpTask {
@@ -140,6 +178,10 @@ class AiCpuBaseTask : public OpTask {
                             rtStream_t stream);
   Status UpdateOutputShape(vector<GeTensorDesc> &output_desc);
   Status UpdateShapeToOutputDesc(const GeShape &shape_new, GeTensorDesc &output_desc);
+  // for blocking aicpu op
+  Status DistributeWaitTaskForAicpuBlockingOp(rtStream_t stream);
+  Status UpdateEventIdForBlockingAicpuOp();
+  Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support);
 
  protected:
   size_t num_inputs_ = 0;
@@ -148,6 +190,9 @@ class AiCpuBaseTask : public OpTask {
   std::unique_ptr<ge::hybrid::AicpuExtInfoHandler> aicpu_ext_handle_;
   void *ext_info_addr_dev_ = nullptr;
   vector<bool> input_is_const_;
+  // for blocking aicpu op
+  bool is_blocking_aicpu_op_ = false;
+  rtEvent_t rt_event_ = nullptr;
 };
 
 class AiCpuTask : public AiCpuBaseTask {
@@ -192,7 +237,6 @@ class AiCpuTask : public AiCpuBaseTask {
   // host addr
   std::vector<void *> io_addr_host_;
 
-  bool dynamic_flag_ = false;
   // for copy task
   void *copy_task_args_buf_ = nullptr;
   void *copy_workspace_buf_ = nullptr;
@@ -257,7 +301,7 @@ class MemcpyAsyncTask : public OpTask {
   friend class SingleOpModel;
   friend class RtsKernelTaskBuilder;
 
-  uintptr_t addresses_[2];
+  uintptr_t addresses_[kAddressNum] = {0};
   size_t dst_max_;
   size_t count_;
   rtMemcpyKind_t kind_;
diff --git a/ge/single_op/task/rts_kernel_task_builder.cc b/ge/single_op/task/rts_kernel_task_builder.cc
index aad78fd9..07bcbd19 100644
--- a/ge/single_op/task/rts_kernel_task_builder.cc
+++ b/ge/single_op/task/rts_kernel_task_builder.cc
@@ -15,7 +15,7 @@
  */
 
 #include "single_op/task/rts_kernel_task_builder.h"
-#include "build_task_utils.h"
+#include "single_op/task/build_task_utils.h"
 
 namespace ge {
 namespace {
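The tbe_task_builder.cc changes that follow convert the file-local GetKernelName/GetTbeKernel helpers into virtual members, turning TbeTaskBuilder into a template-method base that AtomicAddrCleanTaskBuilder specializes by swapping attribute keys while reusing the whole build pipeline. Summary of the overrides introduced below, using only names that appear in this diff:

    // AtomicAddrCleanTaskBuilder overrides (vs. TbeTaskBuilder defaults):
    //   GetKeyForOpParamSize(): "atomic_op_para_size"      (default: "op_para_size")
    //   GetKeyForTvmMetaData(): ATOMIC_ATTR_TVM_METADATA   (default: TVM_ATTR_NAME_METADATA)
    //   GetKernelName():        "<op>_atomic_kernelname"   (default: "<op>_kernelname")
    //   GetTbeKernel():         EXT_ATTR_ATOMIC_TBE_KERNEL ext attr (default: OP_EXTATTR_NAME_TBE_KERNEL)
    //   InitKernelArgs():       no-op; atomic I/O addresses are filled at launch instead
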
"null" : meta_data.c_str()); if (!meta_data.empty()) { auto rt_ret = rtMetadataRegister(bin_handle, meta_data.c_str()); @@ -307,6 +308,15 @@ Status TbeTaskBuilder::GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m return SUCCESS; } +Status TbeTaskBuilder::InitKernelArgs(void *arg_addr, size_t arg_size, const SingleOpModelParam ¶m) { + // copy args + std::vector tensor_device_addr_vec = BuildTaskUtils::GetKernelArgs(op_desc_, param); + void *src_addr = reinterpret_cast(tensor_device_addr_vec.data()); + uint64_t src_len = sizeof(void *) * tensor_device_addr_vec.size(); + GE_CHK_RT_RET(rtMemcpy(arg_addr, arg_size, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST)); + return SUCCESS; +} + Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m, const OpDescPtr &op_desc) { auto task_type = static_cast(task_def_.type()); bool is_task_all_kernel = (task_type == RT_MODEL_TASK_ALL_KERNEL); @@ -331,12 +341,7 @@ Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam & kernel_def_with_handle_.context() : kernel_def_.context(); const auto *args_offset_tmp = reinterpret_cast(context.args_offset().data()); uint16_t offset = *args_offset_tmp; - - // copy args - std::vector tensor_device_addr_vec = BuildTaskUtils::GetKernelArgs(op_desc_, param); - void *src_addr = reinterpret_cast(tensor_device_addr_vec.data()); - uint64_t src_len = sizeof(void *) * tensor_device_addr_vec.size(); - GE_CHK_RT_RET(rtMemcpy(args.get() + offset, arg_size - offset, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST)); + GE_CHK_STATUS_RET_NOLOG(InitKernelArgs(args.get() + offset, arg_size - offset, param)); if (is_task_all_kernel) { task.SetKernelWithHandleArgs(std::move(args), arg_size, kernel_def_with_handle_.block_dim(), op_desc, @@ -367,8 +372,15 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ } auto task_type = static_cast(task_def_.type()); - ret = task_type == RT_MODEL_TASK_ALL_KERNEL ? 
+  if (task_type == RT_MODEL_TASK_ALL_KERNEL) {
+    stub_name_ = model_name_ + "/" + node_->GetName() + "_tvmbin";
+    ret = RegisterKernelWithHandle(task, param);
+  } else {
+    const domi::KernelDef &kernel_def = task_def_.kernel();
+    stub_name_ = model_name_ + "/" + kernel_def.stub_func() + "_tvmbin";
+    ret = RegisterKernel(task, param);
+  }
+
  task.SetHandle(handle_);
  if (ret != SUCCESS) {
    return ret;
@@ -387,6 +399,9 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam &param) {
    }
    task.SetStubFunc(stub_name_, stub_func);
  }
+  GE_CHK_STATUS_RET(task.SetArgIndex(), "[Set][ArgTable] failed.");
+  task.input_num_ = op_desc_->GetInputsSize();
+  task.output_num_ = op_desc_->GetOutputsSize();
 
  return SUCCESS;
 }
@@ -394,8 +409,8 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam &param) {
 Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) {
  GELOGD("Start alloc tiling data of node %s.", op_desc_->GetName().c_str());
  int64_t max_size = -1;
-  (void)AttrUtils::GetInt(op_desc_, kAttrOpParamSize, max_size);
-  GELOGD("Got op param size by key: %s, ret = %ld", kAttrOpParamSize, max_size);
+  (void)AttrUtils::GetInt(op_desc_, GetKeyForOpParamSize(), max_size);
+  GELOGD("Got op param size by key: %s, ret = %ld", GetKeyForOpParamSize().c_str(), max_size);
  if (max_size < 0) {
    GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Get][Int] %s Invalid op_param_size: %ld.",
           op_desc_->GetName().c_str(), max_size);
@@ -413,4 +428,55 @@ Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) {
  task.EnableDynamicSupport(node_, tiling_buffer, static_cast<uint32_t>(max_size));
  return SUCCESS;
 }
+
+Status TbeTaskBuilder::GetMagic(uint32_t &magic) const {
+  std::string json_string;
+  GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc_, TVM_ATTR_NAME_MAGIC, json_string),
+                  GELOGD("Get tvm magic: %s.", json_string.c_str()));
+  if (json_string == "RT_DEV_BINARY_MAGIC_ELF") {
+    magic = RT_DEV_BINARY_MAGIC_ELF;
+  } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") {
+    magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
+  } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") {
+    magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE;
+  } else {
+    REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s check invalid",
+                       TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(),
+                       op_desc_->GetType().c_str(), json_string.c_str());
+    GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s), value:%s check invalid",
+           TVM_ATTR_NAME_MAGIC.c_str(), op_desc_->GetName().c_str(),
+           op_desc_->GetType().c_str(), json_string.c_str());
+    return PARAM_INVALID;
+  }
+  return SUCCESS;
+}
+
+std::string TbeTaskBuilder::GetKeyForOpParamSize() const {
+  return kAttrOpParamSize;
+}
+
+std::string TbeTaskBuilder::GetKeyForTvmMetaData() const {
+  return TVM_ATTR_NAME_METADATA;
+}
+
+Status AtomicAddrCleanTaskBuilder::InitKernelArgs(void *args_addr, size_t arg_size, const SingleOpModelParam &param) {
+  return SUCCESS;
+}
+
+std::string AtomicAddrCleanTaskBuilder::GetKeyForOpParamSize() const {
+  return kAttrAtomicOpParamSize;
+}
+
+std::string AtomicAddrCleanTaskBuilder::GetKeyForTvmMetaData() const {
+  return ATOMIC_ATTR_TVM_METADATA;
+}
+
+void AtomicAddrCleanTaskBuilder::GetKernelName(const OpDescPtr &op_desc, std::string &kernel_name) const {
+  (void)AttrUtils::GetStr(op_desc, op_desc->GetName() + "_atomic_kernelname", kernel_name);
+}
+
+TBEKernelPtr AtomicAddrCleanTaskBuilder::GetTbeKernel(const OpDescPtr &op_desc) const {
+  return op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_TBE_KERNEL, TBEKernelPtr());
+}
+
 }  // namespace ge
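The virtual-key split above is what lets the atomic-clean path reuse BuildTask unchanged: only the attribute keys and kernel lookups differ between the two builders. A minimal sketch of how the two tasks might be wired together by a caller; the variable names and the surrounding loading plumbing are illustrative assumptions, not the repository's actual loader code:

    // Hypothetical wiring, assuming atomic_task_def describes an atomic-addr-clean kernel.
    TbeOpTask tbe_task;
    TbeTaskBuilder builder(model_name, node, task_def);
    GE_CHK_STATUS_RET(builder.BuildTask(tbe_task, param), "[Build][TbeTask] failed.");

    AtomicAddrCleanTaskBuilder atomic_builder(model_name, node, atomic_task_def);
    auto atomic_task = std::unique_ptr<AtomicAddrCleanOpTask>(new AtomicAddrCleanOpTask());
    GE_CHK_STATUS_RET(atomic_builder.BuildTask(*atomic_task, param), "[Build][AtomicTask] failed.");
    GE_CHK_STATUS_RET(atomic_task->InitAtomicAddrCleanIndices(), "[Init][AtomicIndices] failed.");
    tbe_task.SetAtomicAddrCleanTask(atomic_task.release());  // TbeOpTask takes ownership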
diff --git a/ge/single_op/task/tbe_task_builder.h b/ge/single_op/task/tbe_task_builder.h
index a202cbf1..06d17901 100755
--- a/ge/single_op/task/tbe_task_builder.h
+++ b/ge/single_op/task/tbe_task_builder.h
@@ -90,10 +90,17 @@ class HandleRegistry {
 
 class TbeTaskBuilder {
  public:
  TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::TaskDef &task_def);
-  ~TbeTaskBuilder() = default;
+  virtual ~TbeTaskBuilder() = default;
 
  Status BuildTask(TbeOpTask &task, const SingleOpModelParam &param);
 
+ protected:
+  virtual std::string GetKeyForOpParamSize() const;
+  virtual std::string GetKeyForTvmMetaData() const;
+  virtual TBEKernelPtr GetTbeKernel(const OpDescPtr &op_desc) const;
+  virtual void GetKernelName(const OpDescPtr &op_desc, std::string &kernel_name) const;
+  virtual Status InitKernelArgs(void *args_addr, size_t arg_size, const SingleOpModelParam &param);
+
 private:
  Status InitTilingInfo(TbeOpTask &task);
  Status SetKernelArgs(TbeOpTask &task, const SingleOpModelParam &param, const OpDescPtr &op_desc);
@@ -105,6 +112,7 @@ class TbeTaskBuilder {
                             const SingleOpModelParam &param);
  Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam &param) const;
  Status DoRegisterMeta(void *bin_handle);
+  Status GetMagic(uint32_t &magic) const;
 
  static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name);
 
@@ -113,9 +121,24 @@ class TbeTaskBuilder {
  const domi::TaskDef &task_def_;
  const domi::KernelDef &kernel_def_;
  const domi::KernelDefWithHandle &kernel_def_with_handle_;
-  const std::string stub_name_;
+  const std::string model_name_;
+  std::string stub_name_;
  void *handle_ = nullptr;
 };
+
+class AtomicAddrCleanTaskBuilder : public TbeTaskBuilder {
+ public:
+  AtomicAddrCleanTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::TaskDef &task_def)
+      : TbeTaskBuilder(model_name, node, task_def) {}
+  ~AtomicAddrCleanTaskBuilder() override = default;
+
+ protected:
+  std::string GetKeyForOpParamSize() const override;
+  std::string GetKeyForTvmMetaData() const override;
+  TBEKernelPtr GetTbeKernel(const OpDescPtr &op_desc) const override;
+  void GetKernelName(const OpDescPtr &op_desc, std::string &kernel_name) const override;
+  Status InitKernelArgs(void *args_addr, size_t arg_size, const SingleOpModelParam &param) override;
+};
 }  // namespace ge
 
 #endif  // GE_SINGLE_OP_TASK_TBE_TASK_BUILDER_H_
diff --git a/inc/external/OWNERS b/inc/external/OWNERS
new file mode 100644
index 00000000..934272a6
--- /dev/null
+++ b/inc/external/OWNERS
@@ -0,0 +1,10 @@
+approvers:
+- gegenhua
+reviewers:
+- wqtshg
+- ji_chen
+- xchu42
+- sheng-nan
+- wangxiaotian22
+- zhangxiaokun9
+- tangqunzhang
diff --git a/inc/external/acl/acl.h b/inc/external/acl/acl.h
index 8d261201..a5194472 100644
--- a/inc/external/acl/acl.h
+++ b/inc/external/acl/acl.h
@@ -25,9 +25,9 @@
 extern "C" {
 #endif
 
-// Current version is 1.0.0
+// Current version is 1.1.0
 #define ACL_MAJOR_VERSION 1
-#define ACL_MINOR_VERSION 0
+#define ACL_MINOR_VERSION 1
 #define ACL_PATCH_VERSION 0
 
 /**
diff --git a/inc/external/acl/acl_base.h b/inc/external/acl/acl_base.h
index 64d4bd81..90da8b8f 100644
--- a/inc/external/acl/acl_base.h
+++ b/inc/external/acl/acl_base.h
@@ -150,6 +150,8 @@ typedef enum {
  ACL_DOUBLE = 11,
  ACL_BOOL = 12,
  ACL_STRING = 13,
+  ACL_COMPLEX64 = 16,
+  ACL_COMPLEX128 = 17
 } aclDataType;
 
 typedef enum {
diff --git a/inc/external/acl/acl_mdl.h b/inc/external/acl/acl_mdl.h
index 2bf85e29..522dbd38 100644
--- a/inc/external/acl/acl_mdl.h
+++ b/inc/external/acl/acl_mdl.h
@@ -297,9 +321,21 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetDatasetTensorDesc(aclmdlDataset *dataset,
 
 /**
  * @ingroup AscendCL
+ * @brief Get aclTensorDesc from aclmdlDataset
+ *
+ * @param dataset [IN] aclmdlDataset pointer;
+ * @param index [IN] index of tensorDesc
+ *
+ * @retval Get address of aclTensorDesc when executed successfully.
+ * @retval Failure return NULL
+ */
+ACL_FUNC_VISIBILITY aclTensorDesc *aclmdlGetDatasetTensorDesc(const aclmdlDataset *dataset, size_t index);
+
+/**
+ * @ingroup AscendCL
  * @brief Get the number of aclDataBuffer in aclmdlDataset
 *
- * @param dataset [IN] aclmdlDataset poiter
+ * @param dataset [IN] aclmdlDataset pointer
 *
 * @retval the number of aclDataBuffer
 */
@@ -309,7 +321,7 @@ ACL_FUNC_VISIBILITY size_t aclmdlGetDatasetNumBuffers(const aclmdlDataset *datas
 * @ingroup AscendCL
 * @brief Get the aclDataBuffer in aclmdlDataset by index
 *
- * @param dataset [IN] aclmdlDataset poiter
+ * @param dataset [IN] aclmdlDataset pointer
 * @param index [IN] the index of aclDataBuffer
 *
 * @retval Get successfully, return the address of aclDataBuffer
diff --git a/inc/external/acl/acl_op.h b/inc/external/acl/acl_op.h
index d2e59bfb..f340b6bc 100644
--- a/inc/external/acl/acl_op.h
+++ b/inc/external/acl/acl_op.h
@@ -137,6 +137,34 @@ ACL_FUNC_VISIBILITY aclError aclopSetAttrString(aclopAttr *attr, const char *att
 
 /**
  * @ingroup AscendCL
+ * @brief set an attribute. the type of the attribute is aclDataType
+ *
+ * @param attr [OUT] pointer to the instance of aclopAttr
+ * @param attrName [IN] attribute name
+ * @param attrValue [IN] attribute value
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ */
+ACL_FUNC_VISIBILITY aclError aclopSetAttrDataType(aclopAttr *attr, const char *attrName, aclDataType attrValue);
+
+/**
+ * @ingroup AscendCL
+ * @brief set an attribute. the type of the attribute is list of aclDataType
+ *
+ * @param attr [OUT] pointer to the instance of aclopAttr
+ * @param attrName [IN] attribute name
+ * @param numValues [IN] number of values
+ * @param values [IN] pointer to values
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ */
+ACL_FUNC_VISIBILITY aclError aclopSetAttrListDataType(aclopAttr *attr, const char *attrName, int numValues,
+                                                      const aclDataType values[]);
+
+/**
+ * @ingroup AscendCL
 * @brief set an attribute. the type of the attribute is list of bools
 *
 * @param attr [OUT] pointer to the instance of aclopAttr
diff --git a/inc/external/acl/acl_op_compiler.h b/inc/external/acl/acl_op_compiler.h
index d9d1b3da..b64b2bad 100644
--- a/inc/external/acl/acl_op_compiler.h
+++ b/inc/external/acl/acl_op_compiler.h
@@ -86,9 +86,9 @@ ACL_FUNC_VISIBILITY aclError aclopCompile(const char *opType, int numInputs, con
 * @retval OtherValues Failure
 */
 ACL_FUNC_VISIBILITY aclError aclopCompileAndExecute(
-  const char *opType, int numInputs, const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[],
-  int numOutputs, const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], const aclopAttr *attr,
-  aclopEngineType engineType, aclopCompileType compileFlag, const char *opPath, aclrtStream stream);
+    const char *opType, int numInputs, const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[],
+    int numOutputs, const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], const aclopAttr *attr,
+    aclopEngineType engineType, aclopCompileType compileFlag, const char *opPath, aclrtStream stream);
 
 /**
  * @ingroup AscendCL
diff --git a/inc/external/acl/acl_prof.h b/inc/external/acl/acl_prof.h
index 3784d8c6..a93374b0 100644
--- a/inc/external/acl/acl_prof.h
+++ b/inc/external/acl/acl_prof.h
@@ -40,13 +40,20 @@ typedef enum {
  ACL_AICORE_MEMORY_BANDWIDTH = 2,
  ACL_AICORE_L0B_AND_WIDTH = 3,
  ACL_AICORE_RESOURCE_CONFLICT_RATIO = 4,
+  ACL_AICORE_MEMORY_UB = 5,
  ACL_AICORE_NONE = 0xFF
 } aclprofAicoreMetrics;
 
+typedef enum {
+  ACL_STEP_START = 0,  // step start
+  ACL_STEP_END = 1     // step end
+} aclprofStepTag;
+
 typedef struct aclprofConfig aclprofConfig;
 typedef struct aclprofStopConfig aclprofStopConfig;
 typedef struct aclprofAicoreEvents aclprofAicoreEvents;
 typedef struct aclprofSubscribeConfig aclprofSubscribeConfig;
+typedef struct aclprofStepInfo aclprofStepInfo;
 
 /**
  * @ingroup AscendCL
@@ -322,6 +329,36 @@ ACL_FUNC_VISIBILITY uint64_t aclprofGetOpDuration(const void *opInfo, size_t opI
 */
 ACL_FUNC_VISIBILITY size_t aclprofGetModelId(const void *opInfo, size_t opInfoLen, uint32_t index);
 
+/**
+ * @ingroup AscendCL
+ * @brief get the timestamp of a profiling step (step start or step end)
+ *
+ * @param stepInfo [IN] pointer to stepInfo data
+ * @param tag [IN] start or end flag
+ * @param stream [IN] stream info
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ */
+ACL_FUNC_VISIBILITY aclError aclprofGetStepTimestamp(aclprofStepInfo *stepInfo, aclprofStepTag tag, aclrtStream stream);
+
+/**
+ * @ingroup AscendCL
+ * @brief create pointer to aclprofStepInfo data
+ *
+ * @retval aclprofStepInfo pointer
+ */
+ACL_FUNC_VISIBILITY aclprofStepInfo *aclprofCreateStepInfo();
+
+/**
+ * @ingroup AscendCL
+ * @brief destroy aclprofStepInfo pointer
+ *
+ * @retval void
+ */
+ACL_FUNC_VISIBILITY void aclprofDestroyStepInfo(aclprofStepInfo *stepinfo);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/inc/external/acl/acl_rt.h b/inc/external/acl/acl_rt.h
index 5ee70724..50dbc34d 100644
--- a/inc/external/acl/acl_rt.h
+++ b/inc/external/acl/acl_rt.h
@@ -44,6 +44,12 @@ typedef enum aclrtEventStatus {
  ACL_EVENT_STATUS_RESERVED = 2,
 } aclrtEventStatus;
 
+typedef enum aclrtEventWaitStatus {
+  ACL_EVENT_WAIT_STATUS_COMPLETE = 0,
+  ACL_EVENT_WAIT_STATUS_NOT_READY = 1,
+  ACL_EVENT_WAIT_STATUS_RESERVED = 0xffff,
+} aclrtEventWaitStatus;
+
 typedef enum aclrtCallbackBlockType {
  ACL_CALLBACK_NO_BLOCK,
  ACL_CALLBACK_BLOCK,
 } aclrtCallbackBlockType;
@@ -501,6 +507,18 @@ ACL_FUNC_VISIBILITY aclError aclrtQueryEvent(aclrtEvent event, aclrtEventStatus
 
 /**
  * @ingroup AscendCL
+ * @brief Queries an event's wait-status
+ *
+ * @param event [IN] event to query
+ * @param status [OUT] event wait-status
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ */
+ACL_FUNC_VISIBILITY aclError aclrtQueryEventWaitStatus(aclrtEvent event, aclrtEventWaitStatus *status);
+
+/**
+ * @ingroup AscendCL
 * @brief Block Host Running, wait event to be complete
 *
 * @param event [IN] event to wait
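A hedged sketch of how a host thread might poll the new wait-status query; the create/record/wait calls are the pre-existing event APIs, and busy-polling here is only to keep the example short:

    aclrtEvent event = nullptr;
    aclrtEventWaitStatus wait_status = ACL_EVENT_WAIT_STATUS_RESERVED;
    if (aclrtCreateEvent(&event) == ACL_SUCCESS) {
      // ... aclrtRecordEvent(event, stream) and aclrtStreamWaitEvent(other_stream, event) issued elsewhere ...
      while (aclrtQueryEventWaitStatus(event, &wait_status) == ACL_SUCCESS &&
             wait_status == ACL_EVENT_WAIT_STATUS_NOT_READY) {
        // the waiting stream has not consumed the event yet
      }
      (void)aclrtDestroyEvent(event);
    }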
diff --git a/inc/external/acl/ops/acl_dvpp.h b/inc/external/acl/ops/acl_dvpp.h
index dcaa3936..5418ebd3 100644
--- a/inc/external/acl/ops/acl_dvpp.h
+++ b/inc/external/acl/ops/acl_dvpp.h
@@ -158,6 +158,20 @@ enum acldvppJpegFormat {
  ACL_JPEG_CSS_UNKNOWN = 1000
 };
 
+enum acldvppChannelDescParamType { ACL_DVPP_CSC_MATRIX_UINT32 = 0 };
+
+enum aclvdecChannelDescParamType { ACL_VDEC_CSC_MATRIX_UINT32 = 0 };
+
+// Csc Matrix can be used both for acldvppChannelDescParamType and aclvdecChannelDescParamType
+enum acldvppCscMatrix {
+  ACL_DVPP_CSC_MATRIX_BT601_WIDE = 0,
+  ACL_DVPP_CSC_MATRIX_BT601_NARROW,
+  ACL_DVPP_CSC_MATRIX_BT709_WIDE,
+  ACL_DVPP_CSC_MATRIX_BT709_NARROW,
+  ACL_DVPP_CSC_MATRIX_BT2020_WIDE,
+  ACL_DVPP_CSC_MATRIX_BT2020_NARROW
+};
+
 /**
 * @ingroup AscendCL
 * @brief alloc device memory for dvpp.
@@ -1910,9 +1924,9 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropAndPasteAsync(acldvppChannelDesc
 * @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig
 */
 ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizePasteAsync(
-  acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, uint32_t size,
-  acldvppBatchPicDesc *dstBatchPicDescs, acldvppRoiConfig *cropAreas[], acldvppRoiConfig *pasteAreas[],
-  acldvppResizeConfig *resizeConfig, aclrtStream stream);
+    acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, uint32_t size,
+    acldvppBatchPicDesc *dstBatchPicDescs, acldvppRoiConfig *cropAreas[], acldvppRoiConfig *pasteAreas[],
+    acldvppResizeConfig *resizeConfig, aclrtStream stream);
 
 /**
 * @ingroup AscendCL
@@ -2557,10 +2571,93 @@ ACL_FUNC_VISIBILITY aclError acldvppClearHist(acldvppHist *hist);
 * @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig
 */
 ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizeMakeBorderAsync(
-  acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, uint32_t size,
-  acldvppBatchPicDesc *dstBatchPicDescs, acldvppRoiConfig *cropAreas[], acldvppBorderConfig *borderCfgs[],
-  acldvppResizeConfig *resizeConfig, aclrtStream stream);
+    acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, uint32_t size,
+    acldvppBatchPicDesc *dstBatchPicDescs, acldvppRoiConfig *cropAreas[], acldvppBorderConfig *borderCfgs[],
+    acldvppResizeConfig *resizeConfig, aclrtStream stream);
+
+/**
+ * @ingroup AscendCL
+ * @brief set param for dvpp channel desc
+ *
+ * @par Function
+ * set attribution in dvpp channelDesc for specified type
+ *
+ * @param channelDesc [OUT] the channel description
+ * @param paramType [IN] specified param type
+ * @param length [IN] mem length of param
+ * @param param [IN] pointer to param
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ *
+ * @see acldvppGetChannelDescParam | acldvppCreateChannelDesc | acldvppDestroyChannelDesc
+ */
+ACL_FUNC_VISIBILITY aclError acldvppSetChannelDescParam(acldvppChannelDesc *channelDesc,
+                                                        acldvppChannelDescParamType paramType, size_t length,
+                                                        const void *param);
+
+/**
+ * @ingroup AscendCL
+ * @brief get param of dvpp channel desc
+ *
+ * @par Function
+ * get attribution value in dvpp channelDesc for specified type
+ *
+ * @param channelDesc [IN] the channel description
+ * @param paramType [IN] specified param type
+ * @param length [IN] mem length allocated for output param
+ * @param paramRetSize [OUT] mem length of output param
+ * @param param [OUT] pointer to output param
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ *
+ * @see acldvppSetChannelDescParam | acldvppCreateChannelDesc | acldvppDestroyChannelDesc
+ */
+ACL_FUNC_VISIBILITY aclError acldvppGetChannelDescParam(const acldvppChannelDesc *channelDesc,
+                                                        acldvppChannelDescParamType paramType, size_t length,
+                                                        size_t *paramRetSize, void *param);
+/**
+ * @ingroup AscendCL
+ * @brief set param for vdec channel desc
+ *
+ * @par Function
+ * set attribution in channelDesc for specified type
+ *
+ * @param channelDesc [OUT] the vdec channel description
+ * @param paramType [IN] specified param type
+ * @param length [IN] mem length of param
+ * @param param [IN] pointer to param
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ *
+ * @see aclvdecGetChannelDescParam | aclvdecCreateChannelDesc | aclvdecDestroyChannelDesc
+ */
+ACL_FUNC_VISIBILITY aclError aclvdecSetChannelDescParam(aclvdecChannelDesc *channelDesc,
+                                                        aclvdecChannelDescParamType paramType, size_t length,
+                                                        const void *param);
+/**
+ * @ingroup AscendCL
+ * @brief get param of vdec channel desc
+ *
+ * @par Function
+ * get attribution value in channelDesc for specified type
+ *
+ * @param channelDesc [IN] the vdec channel description
+ * @param paramType [IN] specified param type
+ * @param length [IN] mem length allocated for output param
+ * @param paramRetSize [OUT] mem length of output param
+ * @param param [OUT] pointer to output param
+ *
+ * @retval ACL_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure
+ *
+ * @see aclvdecSetChannelDescParam | aclvdecCreateChannelDesc | aclvdecDestroyChannelDesc
+ */
+ACL_FUNC_VISIBILITY aclError aclvdecGetChannelDescParam(const aclvdecChannelDesc *channelDesc,
+                                                        aclvdecChannelDescParamType paramType, size_t length,
+                                                        size_t *paramRetSize, void *param);
 
 #ifdef __cplusplus
 }
 #endif
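A minimal sketch of the new channel-desc parameter accessors, using the CSC-matrix key they ship with; the BT.709 choice is arbitrary and error handling is reduced to a single check:

    acldvppChannelDesc *channel_desc = acldvppCreateChannelDesc();  // existing dvpp API
    uint32_t csc_matrix = ACL_DVPP_CSC_MATRIX_BT709_NARROW;
    aclError ret = acldvppSetChannelDescParam(channel_desc, ACL_DVPP_CSC_MATRIX_UINT32,
                                              sizeof(csc_matrix), &csc_matrix);
    uint32_t read_back = 0;
    size_t ret_size = 0;
    if (ret == ACL_SUCCESS) {
      ret = acldvppGetChannelDescParam(channel_desc, ACL_DVPP_CSC_MATRIX_UINT32,
                                       sizeof(read_back), &ret_size, &read_back);
    }
    (void)acldvppDestroyChannelDesc(channel_desc);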
diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h
index fbd6c020..6f5bbfbf 100644
--- a/inc/external/ge/ge_api_types.h
+++ b/inc/external/ge/ge_api_types.h
@@ -113,6 +113,7 @@ const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";
 const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel";
 const char *const PERFORMANCE_MODE = "ge.performance_mode";
 const char *const MODIFY_MIXLIST = "ge.exec.modify_mixlist";
+const char *const OP_PRECISION_MODE = "ge.exec.op_precision_mode";
 }  // namespace configure_option
 // Configure stream num by Session constructor options param,
 // its value should be int32_t type, default value is "1"
@@ -326,6 +327,8 @@ const std::string PERFORMANCE_MODE = "ge.performance_mode";
 
 const std::string MODIFY_MIXLIST = "ge.exec.modify_mixlist";
 
+const std::string OP_PRECISION_MODE = "ge.exec.op_precision_mode";
+
 // Graph run mode
 enum GraphRunMode { PREDICTION = 0, TRAIN };
 
@@ -405,6 +408,7 @@ static const char *const OP_BANK_UPDATE = ge::OP_BANK_UPDATE_FLAG.c_str();
 static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str();
 static const char *const PERFORMANCE_MODE = ge::PERFORMANCE_MODE.c_str();
 static const char *const MODIFY_MIXLIST = ge::MODIFY_MIXLIST.c_str();
+static const char *const OP_PRECISION_MODE = ge::OP_PRECISION_MODE.c_str();
 
 // for interface: aclgrphBuildModel
 #ifdef __GNUC__
@@ -416,6 +420,7 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
                                                              DYNAMIC_IMAGE_SIZE,
                                                              DYNAMIC_DIMS,
                                                              INSERT_OP_FILE,
+                                                             OP_PRECISION_MODE,
                                                              PRECISION_MODE,
                                                              TUNE_DEVICE_IDS,
                                                              EXEC_DISABLE_REUSED_MEMORY,
diff --git a/inc/external/ge/ge_ir_build.h b/inc/external/ge/ge_ir_build.h
index 04e059a1..729685a9 100644
--- a/inc/external/ge/ge_ir_build.h
+++ b/inc/external/ge/ge_ir_build.h
@@ -1,18 +1,18 @@
 /**
-* Copyright 2020 Huawei Technologies Co., Ltd
-
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-
-* http://www.apache.org/licenses/LICENSE-2.0
-
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 #ifndef INC_EXTERNAL_GE_IR_BUILD_H_
 #define INC_EXTERNAL_GE_IR_BUILD_H_
diff --git a/inc/external/hccl/hccl.h b/inc/external/hccl/hccl.h
index 8261adc4..c24b5374 100644
--- a/inc/external/hccl/hccl.h
+++ b/inc/external/hccl/hccl.h
@@ -145,6 +145,33 @@ extern HcclResult HcclGetRankId(HcclComm comm, uint32_t *rank);
 extern HcclResult HcclBarrier(HcclComm comm, aclrtStream stream);
 
 /**
+ * @brief Send operator.
+ *
+ * @param sendBuf A pointer identifying the input data address of the operator.
+ * @param count An integer(u64) identifying the number of the send data.
+ * @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
+ * @param destRank An integer identifying the destination rank.
+ * @param comm A pointer identifying the communication resource based on.
+ * @param stream A pointer identifying the stream information.
+ * @return HcclResult
+ */
+extern HcclResult HcclSend(void *sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank, HcclComm comm,
+                           aclrtStream stream);
+/**
+ * @brief Receive operator.
+ *
+ * @param recvBuf A pointer identifying the output data address of the operator.
+ * @param count An integer(u64) identifying the number of the receive data.
+ * @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
+ * @param srcRank An integer identifying the source rank.
+ * @param comm A pointer identifying the communication resource based on.
+ * @param stream A pointer identifying the stream information.
+ * @return HcclResult
+ */
+extern HcclResult HcclRecv(void *recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank, HcclComm comm,
+                           aclrtStream stream);
+
+/**
 * @brief Destroy HCCL comm
 *
 * @param comm A pointer identifying the communication resource targeting
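A point-to-point sketch pairing the two new primitives; ranks, buffers, comm, and stream are assumed to be set up elsewhere (e.g. via HcclCommInitRootInfo), so the snippet is illustrative rather than part of the patch:

    // rank 0: send `count` fp16 elements to rank 1
    HcclResult res = HcclSend(send_buf, count, HCCL_DATA_TYPE_FP16, 1 /* destRank */, comm, stream);
    // rank 1: the matching receive from rank 0
    res = HcclRecv(recv_buf, count, HCCL_DATA_TYPE_FP16, 0 /* srcRank */, comm, stream);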
diff --git a/inc/framework/common/debug/ge_log.h b/inc/framework/common/debug/ge_log.h
index 754712f3..3e646440 100644
--- a/inc/framework/common/debug/ge_log.h
+++ b/inc/framework/common/debug/ge_log.h
@@ -84,9 +84,10 @@ inline bool IsLogEnable(int module_name, int log_level) {
                ##__VA_ARGS__);                                                                    \
  } while (0)
 
-#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...)                                              \
-  dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GeLog::GetTid(), __FUNCTION__, ERROR_CODE, \
-             ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__)
+#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...)                                                          \
+  dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) %s" fmt, GeLog::GetTid(), __FUNCTION__, ERROR_CODE,           \
+             ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ErrorManager::GetInstance().GetLogHeader().c_str(),  \
+             ##__VA_ARGS__)
 
 // print memory when it is greater than 1KB.
 #define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE) \
diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h
index e25d5d6f..2a63291c 100644
--- a/inc/framework/common/helper/model_helper.h
+++ b/inc/framework/common/helper/model_helper.h
@@ -22,10 +22,10 @@
 
 #include "common/fmk_types.h"
 #include "common/helper/om_file_helper.h"
+#include "common/model/ge_model.h"
+#include "common/model/ge_root_model.h"
 #include "common/types.h"
 #include "graph/model.h"
-#include "model/ge_model.h"
-#include "model/ge_root_model.h"
 
 namespace ge {
 class GE_FUNC_VISIBILITY ModelHelper {
@@ -42,13 +42,21 @@ class GE_FUNC_VISIBILITY ModelHelper {
  Status LoadRootModel(const ge::ModelData &model_data);
  Status GetModelBufferData(ge::ModelBufferData &model);
 
-  const ModelFileHeader *GetFileHeader() const { return file_header_; }
+  const ModelFileHeader *GetFileHeader() const {
+    return file_header_;
+  }
 
  GeModelPtr GetGeModel();
  GeRootModelPtr GetGeRootModel();
-  void SetSaveMode(bool val) { is_offline_ = val; }
-  bool GetSaveMode(void) const { return is_offline_; }
-  bool GetModelType() const { return is_unknown_shape_model_; };
+  void SetSaveMode(bool val) {
+    is_offline_ = val;
+  }
+  bool GetSaveMode(void) const {
+    return is_offline_;
+  }
+  bool GetModelType() const {
+    return is_unknown_shape_model_;
+  };
 
  Status GetBaseNameFromFileName(const std::string &file_name, std::string &base_name);
  Status GetModelNameFromMergedGraphName(const std::string &graph_name, std::string &model_name);
diff --git a/inc/framework/common/profiling/ge_profiling.h b/inc/framework/common/profiling/ge_profiling.h
index a8de56a8..c87c082c 100644
--- a/inc/framework/common/profiling/ge_profiling.h
+++ b/inc/framework/common/profiling/ge_profiling.h
@@ -43,6 +43,13 @@ GE_FUNC_VISIBILITY ge::Status RegProfCtrlCallback(MsprofCtrlCallback func);
 GE_FUNC_VISIBILITY ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func);
 GE_FUNC_VISIBILITY ge::Status RegProfReporterCallback(MsprofReporterCallback func);
 GE_FUNC_VISIBILITY ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len);
+
+///
+/// @brief Output the profiling data of a single operator in PyTorch; multithreading is not supported
+/// @return Status result
+///
 GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream);
+GE_FUNC_VISIBILITY ge::Status ProfGetDeviceFormGraphId(uint32_t graph_id, uint32_t &device_id);
+
 #endif  // INC_FRAMEWORK_COMMON_GE_PROFILING_H_
diff --git a/inc/framework/common/profiling/ge_runner_profiling.h b/inc/framework/common/profiling/ge_runner_profiling.h
index 011797a3..27e19bce 100644
--- a/inc/framework/common/profiling/ge_runner_profiling.h
+++ b/inc/framework/common/profiling/ge_runner_profiling.h
@@ -17,7 +17,7 @@
 #ifndef INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_
 #define INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_
 
-#include "profiling/ge_profiling.h"
+#include "framework/common/profiling/ge_profiling.h"
 
 GE_FUNC_VISIBILITY bool IsInitialize();
 
diff --git a/inc/framework/executor/ge_executor.h b/inc/framework/executor/ge_executor.h
index fcca561c..ce7c82ac 100644
--- a/inc/framework/executor/ge_executor.h
+++ b/inc/framework/executor/ge_executor.h
@@ -50,14 +50,30 @@ class GE_FUNC_VISIBILITY GeExecutor {
 public:
  GeExecutor();
  ~GeExecutor() = default;
-  ge::Status Initialize();
-  ge::Status Finalize();
-  ge::Status UnloadModel(uint32_t modelId);
+  Status Initialize();
+  Status Finalize();
+
+  ///
+  /// @ingroup ge
+  /// @brief 
Initialize global execute environment. + /// @param [in] options: environment variables. + /// @return init result + /// + static Status Initialize(const std::map &options); + + /// + /// @ingroup ge + /// @brief Finalize global execute environment. + /// @return execute result + /// + static Status FinalizeEx(); + + Status UnloadModel(uint32_t modelId); // Get input and output descriptor - ge::Status GetModelDescInfo(uint32_t model_id, std::vector &input_desc, - std::vector &output_desc, bool new_model_desc = false); + Status GetModelDescInfo(uint32_t model_id, std::vector &input_desc, std::vector &output_desc, + bool new_model_desc = false); /// /// @ingroup ge @@ -68,7 +84,7 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [in] batch_size: batch size entered by user in dynamic multi-batch scenario /// @return execute result /// - ge::Status SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size); + Status SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size); /// /// @ingroup ge @@ -80,8 +96,8 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [in] image_width: image width entered by user in dynamic multi-resolution scenario /// @return execute result /// - ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, - uint64_t image_width); + Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, + uint64_t image_width); /// /// @ingroup ge @@ -93,8 +109,8 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [in] dynamic_dims: array of dynamic dimensions /// @return execute result /// - ge::Status SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, uint64_t length, - const std::vector &dynamic_dims); + Status SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, uint64_t length, + const std::vector &dynamic_dims); /// /// @ingroup ge @@ -104,8 +120,8 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] cur_dynamic_dims: current dynamic dims /// @return execute result /// - ge::Status GetCurDynamicDims(uint32_t model_id, const std::vector &dynamic_dims, - std::vector &cur_dynamic_dims); + Status GetCurDynamicDims(uint32_t model_id, const std::vector &dynamic_dims, + std::vector &cur_dynamic_dims); /// /// @ingroup ge @@ -115,8 +131,7 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] dynamic_type /// @return execute result /// - ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info, - int32_t &dynamic_type); + Status GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info, int32_t &dynamic_type); /// /// @ingroup ge @@ -125,7 +140,7 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] batch_info /// @return execute result /// - ge::Status GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info); + Status GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info); /// /// @ingroup ge @@ -134,9 +149,9 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] user_designate_shape_order /// @return execute result /// - ge::Status GetUserDesignateShapeOrder(uint32_t model_id, std::vector &user_designate_shape_order); + Status GetUserDesignateShapeOrder(uint32_t model_id, std::vector &user_designate_shape_order); - ge::Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); + Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); /// 
/// @ingroup ge @@ -148,22 +163,22 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [in] aippParms: kAippDynamicPara by user in dynamic aipp /// @return execute result /// - ge::Status SetDynamicAippData(uint32_t model_id, void *dynamic_input_addr, uint64_t length, - const std::vector &aippBatchPara, - const kAippDynamicPara &aippParms); + Status SetDynamicAippData(uint32_t model_id, void *dynamic_input_addr, uint64_t length, + const std::vector &aipp_batch_para, + const kAippDynamicPara &aippParms); - ge::Status GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info); + Status GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info); - ge::Status GetOpAttr(uint32_t model_id, const std::string &op_name, const std::string &attr_name, - std::string &attr_value); + Status GetOpAttr(uint32_t model_id, const std::string &op_name, const std::string &attr_name, + std::string &attr_value); - ge::Status GetModelAttr(uint32_t model_id, std::vector &dynamic_output_shape_info); + Status GetModelAttr(uint32_t model_id, std::vector &dynamic_output_shape_info); - ge::Status GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index); + Status GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index); - ge::Status CommandHandle(const ge::Command &command); + Status CommandHandle(const Command &command); - ge::Status SetDump(const DumpConfig &dump_config); + Status SetDump(const DumpConfig &dump_config); /// /// @ingroup ge @@ -173,7 +188,7 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @return SUCCESS /// @return FAILED /// - ge::Status GetMaxUsedMemory(uint32_t model_id, uint32_t &max_size); + Status GetMaxUsedMemory(uint32_t model_id, uint32_t &max_size); /// /// @ingroup ge @@ -182,7 +197,7 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] ModelData &model_data: Offline model memory data /// @return SUCCESS handle successfully / others handle failed /// - ge::Status LoadDataFromFile(const std::string &path, ge::ModelData &model_data); + Status LoadDataFromFile(const std::string &path, ModelData &model_data); /// /// @ingroup ge @@ -195,8 +210,8 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] uint32_t &model_id: Corresponding identification after model loading /// @return SUCCESS handle successfully / others handle failed /// - ge::Status LoadModelFromData(uint32_t &model_id, const ge::ModelData &model_data, void *dev_ptr, size_t mem_size, - void *weight_ptr, size_t weight_size); + Status LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr, size_t mem_size, + void *weight_ptr, size_t weight_size); /// /// @ingroup ge @@ -207,9 +222,8 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [in] output_queue_ids: input queue ids create from user. 
/// @return: 0 for success / others for fail /// - ge::Status LoadModelWithQ(uint32_t &model_id, const ge::ModelData &model_data, - const std::vector &input_queue_ids, - const std::vector &output_queue_ids); + Status LoadModelWithQ(uint32_t &model_id, const ModelData &model_data, const std::vector &input_queue_ids, + const std::vector &output_queue_ids); /// /// @ingroup ge @@ -221,8 +235,8 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] domi::OutputData *output_data: Model output data /// @return SUCCESS handle successfully / others handle failed /// - ge::Status ExecModel(uint32_t model_id, void *stream, const ge::RunModelData &input_data, - ge::RunModelData &output_data, bool async_mode = false); + Status ExecModel(uint32_t model_id, void *stream, const RunModelData &input_data, RunModelData &output_data, + bool async_mode = false); /// /// @ingroup ge @@ -236,9 +250,9 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] std::vector &output_desc: description of model output data /// @return SUCCESS handle successfully / others handle failed /// - ge::Status ExecModel(uint32_t model_id, void *stream, const ge::RunModelData &run_input_data, - const std::vector &input_desc, ge::RunModelData &run_output_data, - std::vector &output_desc, bool async_mode = false); + Status ExecModel(uint32_t model_id, void *stream, const RunModelData &run_input_data, + const std::vector &input_desc, RunModelData &run_output_data, + std::vector &output_desc, bool async_mode = false); /// /// @ingroup ge @@ -248,7 +262,7 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] size_t &weight_size Weight memory space size /// @return SUCCESS handle successfully / others handle failed /// - ge::Status GetMemAndWeightSize(const std::string &path, size_t &mem_size, size_t &weight_size); + Status GetMemAndWeightSize(const std::string &path, size_t &mem_size, size_t &weight_size); /// /// @ingroup ge @@ -259,39 +273,39 @@ class GE_FUNC_VISIBILITY GeExecutor { /// @param [out] size_t &weight_size Weight memory space size /// @return SUCCESS handle successfully / others handle failed /// - ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); + Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); - static ge::Status LoadSingleOp(const std::string &modelName, const ge::ModelData &modelData, void *stream, - SingleOp **single_op); + static Status LoadSingleOp(const std::string &modelName, const ModelData &modelData, void *stream, + SingleOp **single_op); - static ge::Status LoadSingleOpV2(const std::string &modelName, const ge::ModelData &modelData, void *stream, - SingleOp **single_op, const uint64_t model_id); + static Status LoadSingleOpV2(const std::string &modelName, const ModelData &modelData, void *stream, + SingleOp **single_op, const uint64_t model_id); - static ge::Status ExecuteAsync(SingleOp *executor, const std::vector &inputs, - std::vector &outputs); + static Status ExecuteAsync(SingleOp *executor, const std::vector &inputs, + std::vector &outputs); - static ge::Status LoadDynamicSingleOp(const std::string &model_name, const ge::ModelData &modelData, void *stream, - DynamicSingleOp **single_op); + static Status LoadDynamicSingleOp(const std::string &model_name, const ModelData &modelData, void *stream, + DynamicSingleOp **single_op); - static ge::Status LoadDynamicSingleOpV2(const std::string &model_name, const ge::ModelData &modelData, void *stream, - DynamicSingleOp 
**single_op, const uint64_t model_id); + static Status LoadDynamicSingleOpV2(const std::string &model_name, const ModelData &modelData, void *stream, + DynamicSingleOp **single_op, const uint64_t model_id); - static ge::Status ExecuteAsync(DynamicSingleOp *executor, const std::vector &input_desc, - const std::vector &inputs, std::vector &output_desc, - std::vector &outputs); + static Status ExecuteAsync(DynamicSingleOp *executor, const std::vector &input_desc, + const std::vector &inputs, std::vector &output_desc, + std::vector &outputs); - static ge::Status ReleaseSingleOpResource(void *stream); + static Status ReleaseSingleOpResource(void *stream); - static ge::Status GetDeviceIdByModelId(uint32_t model_id, uint32_t &device_id); + static Status GetDeviceIdByModelId(uint32_t model_id, uint32_t &device_id); - ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count); - ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); - ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims, - std::vector &output_dims); - ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); + Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count); + Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); + Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims, + std::vector &output_dims); + Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); private: - static bool isInit_; + static std::atomic_bool is_inited_; }; } // namespace ge diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index 4530bff7..abc4783d 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -50,10 +50,18 @@ enum TaskInfoType { class TaskInfo { public: virtual ~TaskInfo() {} - uint32_t stream_id() const { return stream_id_; } - TaskInfoType type() const { return type_; } - std::string op_name() const { return op_name_; } - bool dump_flag() const { return dump_flag_; } + uint32_t stream_id() const { + return stream_id_; + } + TaskInfoType type() const { + return type_; + } + std::string op_name() const { + return op_name_; + } + bool dump_flag() const { + return dump_flag_; + } protected: TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) @@ -84,15 +92,33 @@ class CceTaskInfo : public TaskInfo { is_flowtable_(is_flowtable) {} ~CceTaskInfo() override {} - cce::ccOpContext cc_context() const { return ctx_; } - std::string stub_func() const { return stub_func_; } - uint32_t block_dim() const { return block_dim_; } - const std::vector &args() const { return args_; } - uint32_t args_size() const { return args_size_; } - const std::vector &sm_desc() const { return sm_desc_; } - const std::vector &flow_table() const { return flow_table_; } - const std::vector &args_offset() const { return args_offset_; } - bool is_flowtable() const { return is_flowtable_; } + cce::ccOpContext cc_context() const { + return ctx_; + } + std::string stub_func() const { + return stub_func_; + } + uint32_t block_dim() const { + return block_dim_; + } + const std::vector &args() const { + return args_; + } + uint32_t args_size() const { + return args_size_; + } + const std::vector &sm_desc() const { + return sm_desc_; + } + const std::vector &flow_table() const { + return 
flow_table_; + } + const std::vector &args_offset() const { + return args_offset_; + } + bool is_flowtable() const { + return is_flowtable_; + } private: cce::ccOpContext ctx_; @@ -126,17 +152,39 @@ class TbeTaskInfo : public TaskInfo { workspace_addrs_(workspace_addrs) {} ~TbeTaskInfo() override {} - const std::string &stub_func() const { return stub_func_; } - uint32_t block_dim() const { return block_dim_; } - const std::vector &args() const { return args_; } - uint32_t args_size() const { return args_size_; } - const std::vector &sm_desc() const { return sm_desc_; } - void *binary() const { return binary_; } - uint32_t binary_size() const { return binary_size_; } - const std::vector &meta_data() const { return meta_data_; } - const std::vector &input_data_addrs() const { return input_data_addrs_; } - const std::vector &output_data_addrs() const { return output_data_addrs_; } - const std::vector &workspace_addrs() const { return workspace_addrs_; } + const std::string &stub_func() const { + return stub_func_; + } + uint32_t block_dim() const { + return block_dim_; + } + const std::vector &args() const { + return args_; + } + uint32_t args_size() const { + return args_size_; + } + const std::vector &sm_desc() const { + return sm_desc_; + } + void *binary() const { + return binary_; + } + uint32_t binary_size() const { + return binary_size_; + } + const std::vector &meta_data() const { + return meta_data_; + } + const std::vector &input_data_addrs() const { + return input_data_addrs_; + } + const std::vector &output_data_addrs() const { + return output_data_addrs_; + } + const std::vector &workspace_addrs() const { + return workspace_addrs_; + } void SetBinary(void *binary, uint32_t binary_size) { binary_ = binary; @@ -171,12 +219,24 @@ class AicpuTaskInfo : public TaskInfo { output_data_addrs_(output_data_addrs) {} ~AicpuTaskInfo() override {} - const std::string &so_name() const { return so_name_; } - const std::string &kernel_name() const { return kernel_name_; } - const std::string &node_def() const { return node_def_; } - const std::vector &input_data_addrs() const { return input_data_addrs_; } - const std::vector &output_data_addrs() const { return output_data_addrs_; } - const std::string &ext_info() const { return ext_info_; } + const std::string &so_name() const { + return so_name_; + } + const std::string &kernel_name() const { + return kernel_name_; + } + const std::string &node_def() const { + return node_def_; + } + const std::vector &input_data_addrs() const { + return input_data_addrs_; + } + const std::vector &output_data_addrs() const { + return output_data_addrs_; + } + const std::string &ext_info() const { + return ext_info_; + } private: std::string so_name_; @@ -192,7 +252,9 @@ class LabelSetTaskInfo : public TaskInfo { LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} ~LabelSetTaskInfo() override {} - uint32_t label_id() const { return label_id_; } + uint32_t label_id() const { + return label_id_; + } private: uint32_t label_id_; @@ -203,7 +265,9 @@ class LabelGotoTaskInfo : public TaskInfo { LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} ~LabelGotoTaskInfo() override {} - uint32_t label_id() const { return label_id_; } + uint32_t label_id() const { + return label_id_; + } private: uint32_t label_id_; @@ -218,9 +282,15 @@ 
class LabelSwitchTaskInfo : public TaskInfo { label_list_(label_list), cond_(cond) {} ~LabelSwitchTaskInfo() override {} - uint32_t label_size() const { return label_size_; } - const std::vector &label_list() const { return label_list_; } - void *cond() const { return cond_; } + uint32_t label_size() const { + return label_size_; + } + const std::vector &label_list() const { + return label_list_; + } + void *cond() const { + return cond_; + } private: uint32_t label_size_; @@ -230,7 +300,9 @@ class LabelSwitchTaskInfo : public TaskInfo { class EventTaskInfo : public TaskInfo { public: - uint32_t event_id() const { return event_id_; } + uint32_t event_id() const { + return event_id_; + } protected: EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) @@ -271,14 +343,13 @@ class FusionEndTaskInfo : public TaskInfo { class HcclTaskInfo : public TaskInfo { public: HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, - void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, + void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num, const std::vector &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag) : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), hccl_type_(hccl_type), input_data_addr_(input_data_addr), output_data_addr_(output_data_addr), - workspace_addr_(workspace_addr), workspace_size_(workspace_size), hccl_stream_num_(hccl_stream_num), private_def_(private_def), @@ -290,25 +361,47 @@ class HcclTaskInfo : public TaskInfo { group_(group) {} ~HcclTaskInfo() override {} - const std::string &hccl_type() const { return hccl_type_; } - void *input_data_addr() const { return input_data_addr_; } - void *output_data_addr() const { return output_data_addr_; } - void *workspace_addr() const { return workspace_addr_; } - int64_t workspace_size() const { return workspace_size_; } - int64_t hccl_stream_num() const { return hccl_stream_num_; } - const std::vector &private_def() const { return private_def_; } - void *ops_kernel_store() const { return ops_kernel_store_; } - int32_t count() const { return count_; } - int64_t root_id() const { return root_id_; } - int64_t op_type() const { return op_type_; } - int64_t data_type() const { return data_type_; } - const std::string &group() const { return group_; } + const std::string &hccl_type() const { + return hccl_type_; + } + void *input_data_addr() const { + return input_data_addr_; + } + void *output_data_addr() const { + return output_data_addr_; + } + int64_t workspace_size() const { + return workspace_size_; + } + int64_t hccl_stream_num() const { + return hccl_stream_num_; + } + const std::vector &private_def() const { + return private_def_; + } + void *ops_kernel_store() const { + return ops_kernel_store_; + } + int32_t count() const { + return count_; + } + int64_t root_id() const { + return root_id_; + } + int64_t op_type() const { + return op_type_; + } + int64_t data_type() const { + return data_type_; + } + const std::string &group() const { + return group_; + } private: std::string hccl_type_; void *input_data_addr_; void *output_data_addr_; - void *workspace_addr_; int64_t workspace_size_; int64_t hccl_stream_num_; std::vector private_def_; @@ -329,9 +422,15 @@ class ProfilerTraceTaskInfo : public TaskInfo { flat_(flat) {} ~ProfilerTraceTaskInfo() override {} - uint64_t 
log_id() const { return log_id_; } - bool notify() const { return notify_; } - uint32_t flat() const { return flat_; } + uint64_t log_id() const { + return log_id_; + } + bool notify() const { + return notify_; + } + uint32_t flat() const { + return flat_; + } private: uint64_t log_id_; @@ -351,11 +450,21 @@ class MemcpyAsyncTaskInfo : public TaskInfo { kind_(kind) {} ~MemcpyAsyncTaskInfo() override {} - void *dst() const { return dst_; } - uint64_t dst_max() const { return dst_max_; } - void *src() const { return src_; } - uint64_t count() const { return count_; } - uint32_t kind() const { return kind_; } + void *dst() const { + return dst_; + } + uint64_t dst_max() const { + return dst_max_; + } + void *src() const { + return src_; + } + uint64_t count() const { + return count_; + } + uint32_t kind() const { + return kind_; + } private: void *dst_; @@ -377,11 +486,21 @@ class StreamSwitchTaskInfo : public TaskInfo { data_type_(data_type) {} ~StreamSwitchTaskInfo() override {} - int64_t true_stream_id() const { return true_stream_id_; } - void *input_addr() const { return input_addr_; } - void *value_addr() const { return value_addr_; } - int64_t cond() const { return cond_; } - int64_t data_type() const { return data_type_; } + int64_t true_stream_id() const { + return true_stream_id_; + } + void *input_addr() const { + return input_addr_; + } + void *value_addr() const { + return value_addr_; + } + int64_t cond() const { + return cond_; + } + int64_t data_type() const { + return data_type_; + } private: int64_t true_stream_id_; @@ -397,7 +516,9 @@ class StreamActiveTaskInfo : public TaskInfo { : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} ~StreamActiveTaskInfo() override {} - uint32_t active_stream_id() const { return active_stream_id_; } + uint32_t active_stream_id() const { + return active_stream_id_; + } private: uint32_t active_stream_id_; diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index ee51d29d..5da5a593 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -106,7 +106,7 @@ class GE_FUNC_VISIBILITY GeGenerator { bool CheckNoAicore(const ComputeGraphPtr &graph); void RemoveConst(const vector &inputs, vector &outputs); Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); - Status InferFormatForSingleOp(OpDescPtr &op_desc); + Status InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph); using GeRootModelPtr = std::shared_ptr; Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); diff --git a/metadef b/metadef index 00c0c12e..21178899 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 00c0c12eede6c7bce93a1eda5f0bb437ae80a7ec +Subproject commit 211788997dcc9aa63527541a44d511388c06bce5 diff --git a/parser b/parser index e75eda62..7a2daaa2 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit e75eda62de2b51a0bded5481ca81eb8fc7bf376e +Subproject commit 7a2daaa2625505e1a15e1faa46c90df1a23dd6fa diff --git a/scripts/env/Dockerfile b/scripts/env/Dockerfile index af02f7bb..923a1453 100755 --- a/scripts/env/Dockerfile +++ b/scripts/env/Dockerfile @@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l ENV PROJECT_HOME=/code/Turing/graphEngine +RUN mkdir /var/run/sshd +RUN echo "root:root" | chpasswd +RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config +RUN sed 
's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd
+
+ENV NOTVISIBLE "in users profile"
+RUN echo "export VISIBLE=now" >> /etc/profile
+
+EXPOSE 22 7777
+
+RUN useradd -ms /bin/bash debugger
+RUN echo "debugger:ge123" | chpasswd
+
+CMD ["/usr/sbin/sshd", "-D"]
+
 RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc
diff --git a/scripts/env/ge_env.sh b/scripts/env/ge_env.sh
index 18c6aa5d..10ca810f 100755
--- a/scripts/env/ge_env.sh
+++ b/scripts/env/ge_env.sh
@@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd)
 
 DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/}
 DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_}
-DOCKER_IMAGE_TAG=ge_build_env.1.0.6
+DOCKER_IMAGE_TAG=ge_build_env.1.0.9
 DOCKER_IAMGE_NAME=joycode2art/turing
 DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG}
 
@@ -61,7 +61,7 @@ function enter_docker_env(){
    if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then
        echo "please run 'ge env --pull' to download images first!"
    elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then
-        $docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir}
+        $docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir}
    elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then
        $docker_cmd start ${DOCKER_BUILD_ENV_NAME}
        $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir}
diff --git a/scripts/update/ge_update.sh b/scripts/update/ge_update.sh
index d6bcd043..57266d06 100755
--- a/scripts/update/ge_update.sh
+++ b/scripts/update/ge_update.sh
@@ -38,7 +38,7 @@ function extract_deps_so_community()
 {
  echo "begin to extract .run file ........."
  chmod +x ./${DRIVER_RUN_NAME_C}
-  chmod +X ./${PACKAGE_NAME_C}
+  chmod +x ./${PACKAGE_NAME_C}
  [ -n "${DEP_TMP_DIR}" ] && rm -rf "${DEP_TMP_DIR}"
  ./${DRIVER_RUN_NAME_C} --noexec --extract=${DEP_TMP_DIR}/driver
  ./${PACKAGE_NAME_C} --noexec --extract=${DEP_TMP_DIR}/Packages_tmp
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index cfad36e1..f5dab366 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -15,13 +15,14 @@ project(tests CXX C)
 
 find_package(Threads)
 
-add_subdirectory(depends/cce)
+
 add_subdirectory(depends/slog)
 add_subdirectory(depends/mmpa)
 add_subdirectory(depends/runtime)
 add_subdirectory(depends/hccl)
 add_subdirectory(depends/profiler)
 add_subdirectory(depends/error_manager)
+add_subdirectory(depends/opt_info)
 
 if (ENABLE_GE_COV OR ENABLE_GE_UT)
    add_subdirectory(ut)
diff --git a/tests/depends/cce/CMakeLists.txt b/tests/depends/cce/CMakeLists.txt
deleted file mode 100644
index 7550c63f..00000000
--- a/tests/depends/cce/CMakeLists.txt
+++ /dev/null
@@ -1,97 +0,0 @@
-# Copyright 2019-2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -#cmake_minimum_required(VERSION 2.8) - -project(STUB_CCE) - -set(CMAKE_CXX_STANDARD 11) - -include_directories(${GE_CODE_DIR}/inc) -include_directories(${GE_CODE_DIR}/inc/framework) -include_directories(${GE_CODE_DIR}/metadef/inc/graph) -include_directories(${GE_CODE_DIR}/inc/external) -include_directories(${GE_CODE_DIR}/metadef/inc/external) -include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) -include_directories(${GE_CODE_DIR}/metadef) -include_directories(${GE_CODE_DIR}/metadef/inc) -include_directories(${GE_CODE_DIR}/metadef/graph) -include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) -include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) -set(PROTO_LIST - "${GE_CODE_DIR}/metadef/proto/om.proto" - "${GE_CODE_DIR}/metadef/proto/ge_ir.proto" - "${GE_CODE_DIR}/metadef/proto/task.proto" -) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -set(SRCS - "${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" - "${GE_CODE_DIR}/metadef/graph/anchor.cc" - "${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" - "${GE_CODE_DIR}/metadef/graph/buffer.cc" - "${GE_CODE_DIR}/metadef/graph/aligned_ptr.cc" - "${GE_CODE_DIR}/metadef/graph/compute_graph.cc" - "${GE_CODE_DIR}/metadef/graph/graph.cc" - "${GE_CODE_DIR}/metadef/graph/model.cc" - "${GE_CODE_DIR}/metadef/graph/model_serialize.cc" - "${GE_CODE_DIR}/metadef/graph/node.cc" - "${GE_CODE_DIR}/metadef/graph/op_desc.cc" - "${GE_CODE_DIR}/metadef/graph/operator.cc" - "${GE_CODE_DIR}/metadef/graph/operator_factory.cc" - "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" - "${GE_CODE_DIR}/metadef/graph/tensor.cc" - "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" - "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" - "${GE_CODE_DIR}/metadef/ops/op_imp.cpp" - "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" - "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" - "${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" -) -add_library(cce_ge_stub SHARED src/cce_stub.cc ${PROTO_SRCS} ${PROTO_HDRS}) - -target_compile_definitions(cce_ge_stub PRIVATE - google=ascend_private -) - -target_link_libraries(cce_ge_stub - $ - -Wl,--no-as-needed - ascend_protobuf - -Wl,--as-needed - c_sec -) - -add_library(cce_stub SHARED ${SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) - -target_compile_definitions(cce_stub PRIVATE - google=ascend_private -) - -target_link_libraries(cce_stub PRIVATE - $ - -Wl,--no-as-needed - ascend_protobuf - -Wl,--as-needed - c_sec -) diff --git a/tests/depends/cce/src/cce_stub.cc b/tests/depends/cce/src/cce_stub.cc deleted file mode 100644 index 03df3d0c..00000000 --- a/tests/depends/cce/src/cce_stub.cc +++ /dev/null @@ -1,576 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include - -#include "cce/optimizer/fusion_engine.h" -#include "common/op/attr_value_util.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/graph_utils.h" - -using namespace cce; -using namespace std; -using namespace ge; -using namespace fusion; - -uint64_t global_mem_base = 0; - -namespace cce { -#define DIM_MAX_SIZE 8 -static const uint32_t C0 = 16; -struct tagCcPad {}; -struct tagCcConvolution {}; - -struct tagCcLRN {}; - -struct tagCcFasterRcnnProposal {}; -struct tagCcRoiAlign {}; -struct tagCcBatchNorm {}; -struct tagCcDetectpostprocess {}; - -struct tagCcSsdDetectionOutput {}; - -struct tagCcRefinedetDetectionOutput {}; - -struct tagCcMsrGenerateRpnProposals {}; - -struct tagCcFilter { - vector dims; -}; - -struct tagCcTensor { - ccTensorFormat_t format; - ccDataType_t data_type; - uint32_t dim_cnt; - int32_t real_dim_cnt; - uint32_t data_size; - int32_t dim_buf[DIM_MAX_SIZE]; - int32_t stride_buf[DIM_MAX_SIZE]; -}; - -typedef struct tagCcPooling { - ccPoolingMode_t mode; - ccPaddingMode_t pad_mode; - ccNanPropagation_t max_pooling_nan_opt; - uint32_t dim_cnt; - int32_t window_dim[6]; - int32_t padding[6]; - int32_t stride[6]; -} ccPooling_t; - -struct tagCcActivation {}; - -struct tagCcFasterRcnnDetectionOutput {}; -struct tagCcSpatialTransformer {}; - -struct tagCcPower {}; -struct tagCcResizeBilinear {}; -struct tagCcSsdNormalize {}; -struct tagCcSsdPostProcessor {}; -struct tagCcSsdPriorBox {}; -struct tagCcPsRoiPooling {}; - -struct tagMsrFastRcnnPredictions {}; -struct tagCcPRelu {}; -struct tagCcStridedSlice {}; - -struct tagCcStridedSliceAttrs {}; - -struct tagCcRnn {}; - -struct tagCcArgmaxmin {}; - -typedef struct tagCcLog { - ccDataType_t data_type; - uint32_t param_cnt; -} ccLog_t; -typedef struct tagCcLog *ccLogDescriptor_t; - -struct tagCcPadV2 {}; - -ccStatus_t ccGetPadV2OutputDim(const ccTensorDescriptor_t x_desc, const ccPadV2Descriptor_t pad_desc, int32_t *dim_cnt, - int32_t dim[], int32_t dim_len) { - *dim_cnt = 4; - dim[0] = 1; - dim[1] = 2; - dim[2] = 2; - dim[3] = 3; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccPadV2Forward(ccHandle_t handle, const ccPadV2Descriptor_t pad_desc, const void *alpha, - const ccTensorDescriptor_t x_desc, const void *x, const void *beta, - const ccTensorDescriptor_t output_desc, void *output) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccCreatePadV2Descriptor(ccPadV2Descriptor_t *pad_desc) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccDestroyPadV2Descriptor(ccPadV2Descriptor_t *pad_desc) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccSetKernelOpMap(ccHandle_t handle) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccDataDumpForward(ccHandle_t handle, const void *buffer, const uint64_t buf_len, const uint32_t task_index) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetPadV2Descriptor(ccPadV2Descriptor_t pad_desc, const int32_t pad_shape_cnt, - const int32_t pad_shape_low[], const int32_t pad_shape_high[], - const ccPadMode_t pad_mode, const void *pad_value, const ccDataType_t pad_value_type) { - return CC_STATUS_SUCCESS; -} - -struct 
tagCcYoloDetectionOutput { - ccYoloVersion_t yolo_version; - uint32_t net_h; - uint32_t net_w; - uint32_t post_top_k; - uint32_t classes; - float nms_threshold; - float iou_thre_decay; - float coor_scale_factor; - bool relative; - float obj_threshold; - float cls_threshold; - uint32_t bias_num; - float *bias; -}; - -struct tagCcYoloRegion {}; - -struct tagCcEltwise {}; - -struct tagCcHashTableLookup {}; - -struct tagCcEmbeddingAttnDecoder {}; -struct tagNonMaxSuppression {}; - -struct tagCcArcSinCos {}; -struct tagCcPow {}; -struct tagCcConcatFive2Four_t {}; -struct tagCcConcatFour2Five_t {}; - -ccStatus_t ccCreatePowDescriptor(ccPowDescriptor_t *pow_desc) { - *pow_desc = new tagCcPow(); - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetPowDescriptor(ccPowDescriptor_t pow_desc, ccDataType_t data_type, uint32_t param_cnt) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccDestroyPowDescriptor(ccPowDescriptor_t *pow_desc) { - if (nullptr == pow_desc) { - return CC_STATUS_BAD_PARAM; - } - - delete *pow_desc; - *pow_desc = 0; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccPowForward(ccHandle_t handle, const ccPowDescriptor_t pow_desc, const void *pow_param, const void *alpha, - const ccTensorDescriptor_t x_desc, const void *x, const ccTensorDescriptor_t y_desc, - const void *y, const void *beta, const ccTensorDescriptor_t z_desc, void *z) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccLogicalOrForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t x_desc, const void *x, - const ccTensorDescriptor_t y_desc, const void *y, const void *beta, - const ccTensorDescriptor_t output_desc, void *output) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccCompareForward(ccHandle_t handle, ccCompareType_t compare_type, const void *alpha, - const ccTensorDescriptor_t x_desc, const void *x, const ccTensorDescriptor_t y_desc, - const void *y, const void *beta, const ccTensorDescriptor_t output_desc, void *output) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccGetCompareOutputDim(const ccTensorDescriptor_t x_desc, const ccTensorDescriptor_t y_desc, int32_t *dim_cnt, - int32_t *dim, int32_t dim_len) { - *dim_cnt = 4; - dim[0] = 1; - dim[1] = 1; - dim[2] = 1; - dim[3] = 1; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccArcTanForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t x_desc, const void *x, - const void *beta, const ccTensorDescriptor_t y_desc, void *y) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccAtanhForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t x_desc, const void *x, - const void *beta, const ccTensorDescriptor_t y_desc, void *y) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccIsDepthwiseHighPerformance(int32_t input_n, int32_t input_c, int32_t input_h, int32_t input_w, - int32_t filter_n, int32_t filter_c, int32_t filter_h, int32_t filter_w, - int32_t dilation_h, int32_t dilation_w, int32_t pad_h_head, int32_t pad_h_tail, - int32_t pad_w_head, int32_t pad_w_tail, int32_t stride_h, int32_t stride_w, - int32_t group_num, bool &is_high_performance, bool is_quant, - ccDataType_t input_data_type, ccDataType_t output_data_type) { - is_high_performance = true; - return CC_STATUS_SUCCESS; -} - -struct tagCcSpaceToBatch {}; - -struct tagCcBatchToSpace {}; - -struct tagCcResizeNearestNeighbor {}; - -ccStatus_t ccGetStream(ccHandle_t handle, rtStream_t *stream_id) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccGetRtVersion(uint32_t *count) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccDestroyTensorDescriptor(ccTensorDescriptor_t *tensor_desc) 
{ - if (nullptr == tensor_desc) { - return CC_STATUS_BAD_PARAM; - } - delete *tensor_desc; - *tensor_desc = 0; - return CC_STATUS_SUCCESS; -} -ccStatus_t ccDestroyFilterDescriptor(ccFilterDescriptor_t *filter_desc) { - delete *filter_desc; - *filter_desc = 0; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccGetFilterSizeInBytes(const ccFilterDescriptor_t filter_desc, uint32_t *size) { - *size = filter_desc->dims[0] * filter_desc->dims[1] * filter_desc->dims[2] * filter_desc->dims[3] * sizeof(float); - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccTransFilter(const ccFilterDescriptor_t w_desc, const void *w, ccFilterDescriptor_t y_desc, void *y, - uint32_t y_size_in_bytes) { - y = const_cast(w); - - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccCreateTensorDescriptor(ccTensorDescriptor_t *tensor_desc) { - *tensor_desc = new tagCcTensor(); - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetTensor4dDescriptor(ccTensorDescriptor_t tensor_desc, ccTensorFormat_t format, ccDataType_t data_type, - int32_t n, int32_t c, int32_t h, int32_t w) { - if (CC_TENSOR_NHWC == format) { - tensor_desc->dim_buf[0] = n; - tensor_desc->dim_buf[1] = h; - tensor_desc->dim_buf[2] = w; - tensor_desc->dim_buf[3] = c; - } else { - tensor_desc->dim_buf[0] = n; - tensor_desc->dim_buf[1] = c; - tensor_desc->dim_buf[2] = h; - tensor_desc->dim_buf[3] = w; - } - tensor_desc->dim_cnt = 4; - tensor_desc->data_type = data_type; - tensor_desc->format = format; - tensor_desc->data_size = n * c * h * w * sizeof(data_type); - return CC_STATUS_SUCCESS; -} -ccStatus_t ccGetTensorSizeInBytes(const ccTensorDescriptor_t tensor_desc, uint32_t *size) { - if ((NULL == tensor_desc) || (NULL == size)) { - return CC_STATUS_BAD_PARAM; - } - *size = tensor_desc->data_size; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccGetTensorMemorySizeInBytes(const ccTensorDescriptor_t tensor_desc, uint32_t *size) { - *size = tensor_desc->data_size; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccCreateFilterDescriptor(ccFilterDescriptor_t *filter_desc) { - *filter_desc = new tagCcFilter(); - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetFilter4dDescriptor(ccFilterDescriptor_t filter_desc, ccTensorFormat_t format, ccDataType_t data_type, - int32_t k, int32_t c, int32_t h, int32_t w) { - filter_desc->dims.push_back(k); - filter_desc->dims.push_back(c); - filter_desc->dims.push_back(h); - filter_desc->dims.push_back(w); - - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetFilterFractalDescriptor(ccFilterDescriptor_t filter_desc, ccTensorFormat_t format, - ccDataType_t data_type, int32_t k, int32_t c, int32_t h, int32_t w) { - filter_desc->dims.push_back(k); - filter_desc->dims.push_back(c); - filter_desc->dims.push_back(h); - filter_desc->dims.push_back(w); - - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetStream(ccHandle_t handle, rtStream_t stream_id) { return CC_STATUS_SUCCESS; } -ccStatus_t ccCreatePoolingMaskDescriptor(ccTensorDescriptor_t *pooling_mask_desc) { - *pooling_mask_desc = new tagCcTensor(); - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetPoolingMaskTensorDescriptor(ccTensorDescriptor_t tensor_desc, ccTensorFormat_t format, - ccDataType_t data_type, int32_t n, int32_t c, int32_t h, int32_t w, - int32_t window_h, int32_t window_w) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetFilter6dDescriptor(ccTensorDescriptor_t filter_desc, ccTensorFormat_t format, ccDataType_t data_type, - int32_t c1, int32_t h, int32_t w, int32_t n, int32_t co, int32_t c0) { - return CC_STATUS_SUCCESS; -} - -/// @ingroup dnn -/// @brief get the format and dimcnt of 
GeTensor -/// @param [in] tensor_desc descriptor of tensor -/// @param [in|out] format point to format -/// @return ccStatus_t -ccStatus_t ccGetTensorFormat(const ccTensorDescriptor_t tensor_desc, ccTensorFormat_t *format) { - *format = tensor_desc->format; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccTransTensor(const ccTensorDescriptor_t x_desc, const void *x, const ccTensorDescriptor_t y_desc, void *y, - uint32_t y_size_in_bytes) { - return CC_STATUS_SUCCESS; -} -void cceSysInit() {} - -bool compilerStubFree() { return true; } - -bool compilerStubInit() { return true; } - -ccStatus_t ccSetInt8Filter4dDescriptor(ccFilterDescriptor_t filter_desc, ccTensorFormat_t format, - ccDataType_t data_type, int32_t k, int32_t c, int32_t h, int32_t w, - ccDataType_t output_data_type) { - filter_desc->dims.push_back(k); - filter_desc->dims.push_back(c); - filter_desc->dims.push_back(h); - filter_desc->dims.push_back(w); - - return CC_STATUS_SUCCESS; -} -ccStatus_t ccSetTensorNdDescriptor(ccTensorDescriptor_t tensor_desc, ccDataType_t data_type, int32_t dim_cnt, - int32_t dimA[]) { - tensor_desc->data_type = data_type; - tensor_desc->data_size = sizeof(data_type); - for (int32_t i = 0; i < dim_cnt; i++) { - tensor_desc->data_size = tensor_desc->data_size * dimA[i]; - } - tensor_desc->format = CC_TENSOR_ND; - return CC_STATUS_SUCCESS; -} - -ccStatus_t CceProfilingConfig(const char *target, const char *job_ctx, uint32_t flag) { return CC_STATUS_SUCCESS; } -ccStatus_t ccSetTensorRealDimCnt(ccTensorDescriptor_t tensor_desc, int32_t real_dim_cnt) { - if (tensor_desc != NULL && tensor_desc != nullptr) { - tensor_desc->real_dim_cnt = real_dim_cnt; - } - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccGetTensorRealDimCnt(ccTensorDescriptor_t tensor_desc, int32_t *real_dim_cnt) { - *real_dim_cnt = tensor_desc->real_dim_cnt; - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetQuantizeFactors(ccQuantizeDescriptor_t quantize_info, ccScaleValueMode_t scale_val_mode, - const uint16_t *scale, const uint16_t *offset, const uint8_t *offset_pad) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetReQuantizeFactors(ccQuantizeDescriptor_t quantize_info, ccScaleValueMode_t scale_val_mode, - const uint16_t *scale_rq, const uint16_t *next_layer_offset, - const int32_t *offset_w) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetDeQuantizeFactors(ccQuantizeDescriptor_t quantize_info, ccScaleValueMode_t scale_val_mode, - const uint16_t *scale_dq, const int32_t *offset_w) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetQuantizeAlgoAndScaleType(ccQuantizeDescriptor_t quantize_info, ccQuantizeAlgo_t quant_algo, - ccScaleType_t scale_type, bool relu_flag) { - return CC_STATUS_SUCCESS; -} -ccStatus_t ccPrintTimeStat() { return CC_STATUS_SUCCESS; } -ccStatus_t ccSetModelId(ccHandle_t handle, uint32_t model_id) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccGetKernelContext(rtStream_t stream_id, ccOpContext &op_context) { - if (stream_id == nullptr) { - op_context.kernelType = ccKernelType::TE; - } else { - op_context.kernelType = ccKernelType::CCE_AI_CORE; - op_context.opId = 1; - op_context.kernelFuncId = 1; - op_context.isFlowtable = true; - op_context.opCount = 1; - op_context.opIndex2[0] = 0; - } - - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccUpdateKernelArgs(ccOpContext &op_context, uint64_t data_base_addr, uint64_t weight_base_addr, - uint64_t variable_base_addr, void *args_addr, uint64_t args_size, void *l2ctrl_addr) { - return CC_STATUS_SUCCESS; -} -ccStatus_t ccGetKernelArgsAddrs(ccOpContext &op_context, void 
*args_addr, uint64_t args_size, void *l2ctrl_addr, - std::vector &op_addrs_info) { - // cce - ccOpAddrsInfo tmp_op_addrs_info; - uint64_t tmp_input = (uint64_t)global_mem_base; - tmp_op_addrs_info.addrPos = &tmp_input; - tmp_op_addrs_info.addrData = tmp_input; - op_addrs_info.push_back(tmp_op_addrs_info); - - uint64_t tmp_output = (uint64_t)(global_mem_base + 5476352); - tmp_op_addrs_info.addrPos = &tmp_output; - tmp_op_addrs_info.addrData = tmp_output; - op_addrs_info.push_back(tmp_op_addrs_info); - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccSetKernelArgs(std::vector &date_info) { return CC_STATUS_SUCCESS; } -} // namespace cce -// ccFusion no namespace -ccStatus_t ccFusionStart(ccHandle_t handle, uint32_t graph_id, uint32_t init_flag, CceFusionMemCfg_t mem_cfg) { - return CC_STATUS_SUCCESS; -} - -//???ccFusion ????namespace cce?? -ccStatus_t ccFusionStart(ccHandle_t handle, uint32_t graph_id, uint32_t init_flag, uint32_t addr_change_flag) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t ccFusionEnd(ccHandle_t handle, uint32_t graph_id) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccFusionTaskEnd(ccHandle_t handle, uint32_t graph_id) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccKernelLaunchRepeat(ccHandle_t handle) { return CC_STATUS_SUCCESS; } - -ccStatus_t ccKernelDelete(ccHandle_t handle) { return CC_STATUS_SUCCESS; } - -ccStatus_t cce::ccSetTensorFormat(cce::tagCcTensor *, cce::tagCcTensorFormat) { return CC_STATUS_SUCCESS; } - -namespace fusion { -uint32_t BufferFusion(std::shared_ptr, std::shared_ptr, bool) { return 0; } - -uint32_t BufferFusionTrain(std::shared_ptr, std::shared_ptr) { return 0; } - -uint32_t GraphFusionTrain(ge::ComputeGraphPtr orig_graph, ge::ComputeGraphPtr fusion_graph) { return 0; } -} // namespace fusion -namespace fusion { -using namespace ge; - -uint32_t Fusion(ComputeGraphPtr model_graph, ComputeGraphPtr fusion_graph, kScopeNodeMap_t &te_fusion_map) { - OpDescPtr op_def_a = std::make_shared(); - op_def_a->SetName("reduction_nd"); - op_def_a->SetType("reduction_nd"); - - GeTensorDescPtr v_input_desc = std::make_shared(); - op_def_a->AddInputDesc(*v_input_desc); - - vector v_input; - v_input.push_back(0); - op_def_a->SetInputOffset(v_input); - - GeTensorDesc input_desc = op_def_a->GetInputDesc(0); - input_desc.SetFormat(FORMAT_NCHW); - input_desc.SetDataType(DT_FLOAT); - input_desc.SetShape(GeShape({1, 3, 5, 5})); - ge::TensorUtils::SetSize(input_desc, 192); - ge::TensorUtils::SetRealDimCnt(input_desc, 4); - - GeTensorDescPtr output_desc = std::make_shared(); - op_def_a->AddOutputDesc(*output_desc); - - output_desc->SetFormat(FORMAT_NCHW); - output_desc->SetDataType(DT_FLOAT); - output_desc->SetShape(GeShape({1, 3, 5})); - ge::TensorUtils::SetSize(*output_desc, 96); - ge::TensorUtils::SetRealDimCnt(*output_desc, 3); - - OpDescPtr op_def_b = std::make_shared(); - op_def_b->SetName("transdata_1"); - op_def_b->SetType("TransData"); - - int stream_num = 1; - int flag = 0; - - NodePtr node_a = fusion_graph->AddNode(op_def_a); - NodePtr node_b = fusion_graph->AddNode(op_def_b); - - GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - int32_t a = 1; - int32_t b = 2; - - AttrUtils::SetInt(op_def_a, "fusion_scope", a); - AttrUtils::SetInt(op_def_b, "fusion_scope", b); - - vector node_list1; - node_list1.push_back(node_a); - vector node_list2; - node_list2.push_back(node_b); - te_fusion_map[1] = node_list1; - te_fusion_map[2] = node_list2; - - return FUSION_STATUS_SUCCESS; -} - -uint32_t FusionTaskBuild(cce::ccHandle_t cc_handle, 
ge::ComputeGraphPtr fusion_graph, ge::Buffer &buffer, - ModelRes &model_res, std::vector &task_def_list_) { - TaskDef task_def_temp; - task_def_list_.push_back(task_def_temp); - - return FUSION_STATUS_SUCCESS; -} -uint32_t GraphFusion(ge::ComputeGraphPtr orig_graph, ge::ComputeGraphPtr fusion_graph) { - *fusion_graph = *orig_graph; - return FUSION_STATUS_SUCCESS; -} - -void FusionTaskBuildComplete(std::vector cc_handle_list) { return; } - -} // namespace fusion - -ccStatus_t cce::ccSetTensorDescriptorQuantizeParam(ccTensorDescriptor_t tensor_desc, - const ccVecQuantizePara_t *vec_quantize_para) { - return CC_STATUS_SUCCESS; -} - -ccStatus_t cce::ccSetAllOffsetQuantizeFactors(ccQuantizeDescriptor_t quantize_info, const uint8_t *offset_w, - const uint8_t *offset_d, const uint16_t *scale_req, - const uint16_t *offset_d_next) { - return CC_STATUS_SUCCESS; -} diff --git a/tests/depends/cce/src/op_kernel_registry.cc b/tests/depends/cce/src/op_kernel_registry.cc deleted file mode 100644 index 5ccd1391..00000000 --- a/tests/depends/cce/src/op_kernel_registry.cc +++ /dev/null @@ -1,29 +0,0 @@ -#include "register/op_kernel_registry.h" - -namespace ge { -class OpKernelRegistry::OpKernelRegistryImpl { - -}; - -OpKernelRegistry::OpKernelRegistry() { -} - -OpKernelRegistry::~OpKernelRegistry() { - -} - -bool OpKernelRegistry::IsRegistered(const std::string &op_type) { - return false; -} - -std::unique_ptr OpKernelRegistry::CreateHostCpuOp(const std::string &op_type) { - return nullptr; -} - -void OpKernelRegistry::RegisterHostCpuOp(const std::string &op_type, CreateFn create_fn) { -} - -HostCpuOpRegistrar::HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()) { - -} -} // namespace ge \ No newline at end of file diff --git a/tests/depends/mmpa/src/mmpa_stub.cc b/tests/depends/mmpa/src/mmpa_stub.cc index a82621ef..8801aacd 100644 --- a/tests/depends/mmpa/src/mmpa_stub.cc +++ b/tests/depends/mmpa/src/mmpa_stub.cc @@ -220,6 +220,13 @@ VOID mmScandirFree(mmDirent **entryList, INT32 count) INT32 mmAccess2(const CHAR *pathName, INT32 mode) { + if (pathName == NULL) { + return EN_INVALID_PARAM; + } + INT32 ret = access(pathName, mode); + if (ret != EN_OK) { + return EN_ERROR; + } return 0; } @@ -338,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) { + const char *env = getenv(name); + if (env != nullptr) { + strcpy(value, env); + } return 0; } @@ -370,3 +381,8 @@ INT32 mmGetPid() { return (INT32)getpid(); } + +INT32 mmSetCurrentThreadName(const CHAR *name) +{ + return EN_OK; +} \ No newline at end of file diff --git a/tests/depends/opt_info/CMakeLists.txt b/tests/depends/opt_info/CMakeLists.txt new file mode 100644 index 00000000..9148dfd8 --- /dev/null +++ b/tests/depends/opt_info/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+#cmake_minimum_required(VERSION 2.8)
+
+project(opt_feature_stub)
+
+file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR}
+    "src/opt_info_stub.cc"
+)
+
+include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info)
+
+add_library(opt_feature_stub SHARED ${SRCS})
+
+target_compile_options(opt_feature_stub PRIVATE
+    -g
+)
+
+target_link_libraries(opt_feature_stub PRIVATE
+    $
+    c_sec
+)
+
+target_include_directories(opt_feature_stub INTERFACE ${CMAKE_CURRENT_LIST_DIR}/src)
diff --git a/tests/depends/opt_info/src/opt_info_stub.cc b/tests/depends/opt_info/src/opt_info_stub.cc
new file mode 100644
index 00000000..df518c4b
--- /dev/null
+++ b/tests/depends/opt_info/src/opt_info_stub.cc
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "opt_info.h"
+#include <string>
+#include <map>
+#include <vector>
+#include <algorithm>
+
+namespace gelc {
+namespace {
+const std::vector<std::string> kSocVersions = {"Ascend910"};
+}
+
+void SetAllOptInfo(std::map<std::string, std::string> &opt_infos) {
+  opt_infos.emplace("opt_module.fe", "all");
+  opt_infos.emplace("opt_module.pass", "all");
+  opt_infos.emplace("opt_module.op_tune", "all");
+  opt_infos.emplace("opt_module.rl_tune", "all");
+  opt_infos.emplace("opt_module.aoe", "all");
+}
+
+Status GetOptInfo(WorkMode mode, const std::string &soc_ver,
+                  std::map<std::string, std::string> &opt_infos) {
+  if (std::find(kSocVersions.begin(), kSocVersions.end(), soc_ver) == kSocVersions.end()) {
+    SetAllOptInfo(opt_infos);
+    return SUCCESS;
+  }
+  opt_infos.emplace("opt_module.fe", "all");
+  opt_infos.emplace("opt_module.pass", "all");
+  opt_infos.emplace("opt_module.op_tune", "all");
+  return SUCCESS;
+}
+}  // namespace gelc
diff --git a/tests/depends/profiler/src/profiler_stub.cc b/tests/depends/profiler/src/profiler_stub.cc
index 1ed49fd8..0b8eaa88 100644
--- a/tests/depends/profiler/src/profiler_stub.cc
+++ b/tests/depends/profiler/src/profiler_stub.cc
@@ -16,6 +16,7 @@
 
 #include "toolchain/prof_engine.h"
 #include "toolchain/prof_mgr_core.h"
+#include "runtime/base.h"
 
 void * ProfMgrStartUp(const ProfMgrCfg *cfg)
 {
@@ -32,3 +33,10 @@ int Msprof::Engine::RegisterEngine(const std::string& module, const Msprof::Engi
   return 0;
 }
 
+rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {
+  return 0;
+}
+
+rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {
+  return 0;
+}
diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc
index 2b1af23c..32df7552 100644
--- a/tests/depends/runtime/src/runtime_stub.cc
+++ b/tests/depends/runtime/src/runtime_stub.cc
@@ -16,12 +16,94 @@
 
 #include 
 #include 
+#include "runtime_stub.h"
+#include "runtime/rt.h"
+
+#define ADD_STUB_RETURN_VALUE(FUNC, TYPE) std::vector<TYPE> g_Stub_##FUNC##_RETURN
+
+#define GET_STUB_RETURN_VALUE(FUNC, TYPE, DEFAULT) ({ \
+  TYPE result = DEFAULT; \
+  if (!g_Stub_##FUNC##_RETURN.empty()) { \
+    result = g_Stub_##FUNC##_RETURN.back(); \
+    g_Stub_##FUNC##_RETURN.pop_back(); \
+  } \
+  result; \
+})
+
+#define DEL_STUB_RETURN_VALUE(FUNC, TYPE) \
+do { \
+  extern std::vector<TYPE> g_Stub_##FUNC##_RETURN; \
+  g_Stub_##FUNC##_RETURN.clear(); \
+} while (0)
+
+
+#define ADD_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME) std::vector<TYPE> g_Stub_##FUNC##_OUT_##NAME
+
+#define GET_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME, DEFAULT) ({ \
+  TYPE value; \
+  if (!g_Stub_##FUNC##_OUT_##NAME.empty()) { \
+    value = g_Stub_##FUNC##_OUT_##NAME.back(); \
+    g_Stub_##FUNC##_OUT_##NAME.pop_back(); \
+  } else { \
+    value = DEFAULT; \
+  } \
+  value; \
+})
+
+#define DEL_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME) \
+do { \
+  extern std::vector<TYPE> g_Stub_##FUNC##_OUT_##NAME; \
+  g_Stub_##FUNC##_OUT_##NAME.clear(); \
+} while (0)
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 #define EVENT_LENTH 10
 
+void rtStubTearDown() {
+  DEL_STUB_RETURN_VALUE(rtGetDevice, rtError_t);
+  DEL_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t);
+  DEL_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t);
+  DEL_STUB_RETURN_VALUE(rtEventReset, rtError_t);
+  DEL_STUB_RETURN_VALUE(rtEventCreate, rtError_t);
+  DEL_STUB_RETURN_VALUE(rtGetEventID, rtError_t);
+}
+
+ADD_STUB_RETURN_VALUE(rtGetDevice, rtError_t);
+rtError_t rtGetDevice(int32_t *device) {
+  return GET_STUB_RETURN_VALUE(rtGetDevice, rtError_t, RT_ERROR_NONE);
+}
+
+ADD_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t);
+ADD_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value);
+rtError_t rtGetDeviceCapability(int32_t device, int32_t moduleType, int32_t featureType, int32_t *value) {
+  *value = GET_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT);
+  return GET_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE);
+}
+
+ADD_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t);
+rtError_t rtStreamWaitEvent(rtStream_t stream, rtEvent_t event) {
+  return GET_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, RT_ERROR_NONE);
+}
+
+ADD_STUB_RETURN_VALUE(rtEventReset, rtError_t);
+rtError_t rtEventReset(rtEvent_t event, rtStream_t stream) {
+  return GET_STUB_RETURN_VALUE(rtEventReset, rtError_t, RT_ERROR_NONE);
+}
+
+ADD_STUB_RETURN_VALUE(rtEventCreate, rtError_t);
+rtError_t rtEventCreate(rtEvent_t *event) {
+  *event = new int[EVENT_LENTH];
+  return GET_STUB_RETURN_VALUE(rtEventCreate, rtError_t, RT_ERROR_NONE);
+}
+
+ADD_STUB_RETURN_VALUE(rtGetEventID, rtError_t);
+rtError_t rtGetEventID(rtEvent_t event, uint32_t *event_id) {
+  *event_id = 0;
+  return GET_STUB_RETURN_VALUE(rtGetEventID, rtError_t, RT_ERROR_NONE);
+}
+
 rtError_t rtCtxSetCurrent(rtContext_t ctx) { return RT_ERROR_NONE; }
 
 rtError_t rtGetStreamId(rtStream_t stream, int32_t *stream_id) {
@@ -42,11 +124,6 @@ rtError_t rtEventGetTimeStamp(uint64_t *time, rtEvent_t event) {
   return RT_ERROR_NONE;
 }
 
-rtError_t rtEventCreate(rtEvent_t *event) {
-  *event = new int[EVENT_LENTH];
-  return RT_ERROR_NONE;
-}
-
 rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag) {
   return rtEventCreate(event);
 }
@@ -112,8 +189,6 @@ rtError_t rtMemcpyAsync(void *dst, uint64_t dest_max, const void *src, uint64_t
   return RT_ERROR_NONE;
 }
 
-rtError_t rtStreamWaitEvent(rtStream_t stream, rtEvent_t event) { return RT_ERROR_NONE; }
-
 rtError_t rtSetTSDevice(uint32_t tsId) {
   return RT_ERROR_NONE;
 }
@@ -193,6 +268,12 @@ rtError_t rtMemGetInfo(size_t *free, size_t *total) {
   return RT_ERROR_NONE;
 }
 
+rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) {
+  *free = 512UL * 1024UL * 1024UL;
+  *total = 1024UL * 1024UL * 1024UL;
+  return RT_ERROR_NONE;
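+  // The capacities reported above are fixed stub values (512 MiB free out of a
+  // 1 GiB total), chosen only so that memory-budget queries in the tests see a
+  // plausible, non-zero configuration; they do not model any real device.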
+} + rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag) { return RT_ERROR_NONE; } rtError_t rtMemFreeManaged(void *ptr) { return RT_ERROR_NONE; } @@ -341,10 +422,6 @@ rtError_t rtStreamSwitchEx(void *ptr, rtCondition_t condition, void *value_ptr, rtError_t rtStreamActive(rtStream_t active_stream, rtStream_t stream) { return RT_ERROR_NONE; } -rtError_t rtEventReset(rtEvent_t event, rtStream_t stream) { return RT_ERROR_NONE; } - -rtError_t rtGetDevice(int32_t *device) { return RT_ERROR_NONE; } - rtError_t rtDatadumpInfoLoad(const void *dump_info, uint32_t length) { return RT_ERROR_NONE; } rtError_t rtKernelLaunchWithFlag(const void *stub_func, uint32_t block_dim, void *args, uint32_t args_size, @@ -456,6 +533,25 @@ rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { return RT_ERROR_NONE; } + +rtError_t rtFftsTaskLaunch(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stream) { + return RT_ERROR_NONE; +} + +rtError_t rtKernelLaunchFwk(const char *opName, void *args, uint32_t argSize, uint32_t flags, rtStream_t rtStream) { + return RT_ERROR_NONE; +} + +rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim, const void *args, + uint32_t argSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) { + return RT_ERROR_NONE; +} + +rtError_t rtAicpuKernelLaunch(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim, const void *args, + uint32_t argSize, rtSmDesc_t *smDesc, rtStream_t stream) { + return RT_ERROR_NONE; +} + #ifdef __cplusplus } #endif diff --git a/tests/depends/runtime/src/runtime_stub.h b/tests/depends/runtime/src/runtime_stub.h new file mode 100644 index 00000000..b693b9ea --- /dev/null +++ b/tests/depends/runtime/src/runtime_stub.h @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef __INC_LLT_RUNTIME_STUB_H
+#define __INC_LLT_RUNTIME_STUB_H
+
+#include <vector>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void rtStubTearDown();
+
+#define RTS_STUB_SETUP() \
+do { \
+  rtStubTearDown(); \
+} while (0)
+
+#define RTS_STUB_TEARDOWN() \
+do { \
+  rtStubTearDown(); \
+} while (0)
+
+#define RTS_STUB_RETURN_VALUE(FUNC, TYPE, VALUE) \
+do { \
+  g_Stub_##FUNC##_RETURN.emplace(g_Stub_##FUNC##_RETURN.begin(), VALUE); \
+} while (0)
+
+#define RTS_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME, VALUE) \
+do { \
+  g_Stub_##FUNC##_OUT_##NAME.emplace(g_Stub_##FUNC##_OUT_##NAME.begin(), VALUE); \
+} while (0)
+
+
+#define RTS_STUB_RETURN_EXTERN(FUNC, TYPE) extern std::vector<TYPE> g_Stub_##FUNC##_RETURN;
+#define RTS_STUB_OUTBOUND_EXTERN(FUNC, TYPE, NAME) extern std::vector<TYPE> g_Stub_##FUNC##_OUT_##NAME;
+
+RTS_STUB_RETURN_EXTERN(rtGetDevice, rtError_t);
+RTS_STUB_OUTBOUND_EXTERN(rtGetDevice, int32_t, device);
+
+RTS_STUB_RETURN_EXTERN(rtGetDeviceCapability, rtError_t);
+RTS_STUB_OUTBOUND_EXTERN(rtGetDeviceCapability, int32_t, value);
+
+RTS_STUB_RETURN_EXTERN(rtStreamWaitEvent, rtError_t);
+
+RTS_STUB_RETURN_EXTERN(rtEventReset, rtError_t);
+
+RTS_STUB_RETURN_EXTERN(rtEventCreate, rtError_t);
+RTS_STUB_OUTBOUND_EXTERN(rtEventCreate, rtEvent_t, event);
+
+RTS_STUB_RETURN_EXTERN(rtGetEventID, rtError_t);
+RTS_STUB_OUTBOUND_EXTERN(rtGetEventID, uint32_t, event_id);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // __INC_LLT_RUNTIME_STUB_H
diff --git a/tests/depends/slog/src/slog_stub.cc b/tests/depends/slog/src/slog_stub.cc
index d0eb49c5..238a6b37 100644
--- a/tests/depends/slog/src/slog_stub.cc
+++ b/tests/depends/slog/src/slog_stub.cc
@@ -23,13 +23,46 @@
 
 void dav_log(int module_id, const char *fmt, ...) {}
 
-void DlogErrorInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
+static int log_level = DLOG_ERROR;
+
+#define __DO_PRINT()                                  \
+  do {                                                \
+    const int FMT_BUFF_SIZE = 1024;                   \
+    char fmt_buff[FMT_BUFF_SIZE] = {0};               \
+    va_list valist;                                   \
+    va_start(valist, fmt);                            \
+    vsnprintf(fmt_buff, FMT_BUFF_SIZE, fmt, valist);  \
+    va_end(valist);                                   \
+    printf("%s \n", fmt_buff);                        \
+  } while (0)
+
+void DlogErrorInner(int module_id, const char *fmt, ...) {
+  if (log_level > DLOG_ERROR) {
+    return;
+  }
+  __DO_PRINT();
+}
 
-void DlogWarnInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
+void DlogWarnInner(int module_id, const char *fmt, ...) {
+  if (log_level > DLOG_WARN) {
+    return;
+  }
+  __DO_PRINT();
+}
 
-void DlogInfoInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
+void DlogInfoInner(int module_id, const char *fmt, ...) {
+  if (log_level > DLOG_INFO) {
+    return;
+  }
+  __DO_PRINT();
+}
 
-void DlogDebugInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
+void DlogDebugInner(int module_id, const char *fmt, ...) {
+  if (log_level > DLOG_DEBUG) {
+    return;
+  }
+  __DO_PRINT();
+}
 
 void DlogEventInner(int module_id, const char *fmt, ...) 
{ dav_log(module_id, fmt); } @@ -39,30 +72,25 @@ void DlogWithKVInner(int module_id, int level, KeyValue *pst_kv_array, int kv_nu dav_log(module_id, fmt); } -int dlog_setlevel(int module_id, int level, int enable_event) { return DLOG_DEBUG; } +int dlog_setlevel(int module_id, int level, int enable_event) { + log_level = level; + return log_level; +} -int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; } +int dlog_getlevel(int module_id, int *enable_event) { return log_level; } -int CheckLogLevel(int moduleId, int logLevel) -{ - return 1; -} +int CheckLogLevel(int moduleId, int log_level_check) { return log_level >= log_level_check; } /** * @ingroup plog * @brief DlogReportInitialize: init log in service process before all device setting. * @return: 0: SUCCEED, others: FAILED */ -int DlogReportInitialize() { - return 0; -} +int DlogReportInitialize() { return 0; } /** * @ingroup plog * @brief DlogReportFinalize: release log resource in service process after all device reset. * @return: 0: SUCCEED, others: FAILED */ -int DlogReportFinalize() { - return 0; -} - +int DlogReportFinalize() { return 0; } diff --git a/tests/framework/CMakeLists.txt b/tests/framework/CMakeLists.txt index d7c806a6..bbab454b 100644 --- a/tests/framework/CMakeLists.txt +++ b/tests/framework/CMakeLists.txt @@ -15,18 +15,5 @@ include(cmake/graphengine.cmake) add_subdirectory(easy_graph) -add_subdirectory(stub_engine) add_subdirectory(ge_graph_dsl) - -file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS - "utils/*.cc" - ) - -add_library(framework STATIC ${UTILS_SRC}) - -target_include_directories(framework - PUBLIC utils/ -) - -set_target_properties(framework PROPERTIES CXX_STANDARD 11) -target_link_libraries(framework PUBLIC ge_graph_dsl graphengine fe) +add_subdirectory(ge_running_env) diff --git a/tests/framework/cmake/graphengine.cmake b/tests/framework/cmake/graphengine.cmake index 81aa00cc..d83203b4 100644 --- a/tests/framework/cmake/graphengine.cmake +++ b/tests/framework/cmake/graphengine.cmake @@ -103,6 +103,7 @@ list(APPEND INCLUDE_DIRECTORIES "${GE_CODE_DIR}/third_party/fwkacllib/inc/cce" "${GE_CODE_DIR}/third_party/fwkacllib/inc/ops" "${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain" + "${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info" "${GE_CODE_DIR}/tests/ut/ge" "${GE_CODE_DIR}/tests/ut/common" "${CMAKE_BINARY_DIR}" @@ -117,6 +118,7 @@ list(APPEND STUB_LIBS runtime_stub profiler_stub hccl_stub + opt_feature_stub error_manager_stub ascend_protobuf json @@ -150,7 +152,7 @@ set_target_properties(metadef_graph PROPERTIES CXX_STANDARD 11) # ---- Target : Local engine ---- -add_library(ge_local_engine SHARED ${LOCAL_ENGINE_SRC} ${METADEF_REGISTER_SRCS}) +add_library(ge_local_engine SHARED ${LOCAL_ENGINE_SRC}) target_include_directories(ge_local_engine PUBLIC @@ -169,38 +171,11 @@ target_compile_options(ge_local_engine PRIVATE target_link_libraries(ge_local_engine PUBLIC $ ${STUB_LIBS} - metadef_graph -lrt -ldl -lpthread -lgcov ) set_target_properties(ge_local_engine PROPERTIES CXX_STANDARD 11) -# ---- Target : Host engine ---- - -add_library(host_cpu_engine SHARED ${HOST_ENGINE_SRC}) - -target_include_directories(host_cpu_engine - PUBLIC - "${INCLUDE_DIRECTORIES}" - "${GE_CODE_DIR}/ge/host_cpu_engine" -) - -target_compile_definitions(host_cpu_engine PRIVATE - google=ascend_private - FMK_SUPPORT_DUMP -) - -target_compile_options(host_cpu_engine PRIVATE - -g --coverage -fprofile-arcs -ftest-coverage - -Werror=format -) - -target_link_libraries(host_cpu_engine PUBLIC - $ ${STUB_LIBS} metadef_graph 
-lrt -ldl -lpthread -lgcov
-)
-
-set_target_properties(host_cpu_engine PROPERTIES CXX_STANDARD 11)
-
 # ---- Target : engine plugin----
 #
@@ -273,4 +248,4 @@ target_link_libraries(graphengine PUBLIC
 )
 set_target_properties(graphengine PROPERTIES CXX_STANDARD 11)
-add_dependencies(graphengine host_cpu_engine ge_local_engine nnengine engine_conf.json optimizer_priority.pbtxt)
+add_dependencies(graphengine ge_local_engine nnengine engine_conf.json optimizer_priority.pbtxt)
diff --git a/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h b/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h
index 4d430983..46bfe324 100644
--- a/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h
+++ b/tests/framework/easy_graph/include/easy_graph/builder/graph_dsl.h
@@ -26,16 +26,32 @@ EG_NS_BEGIN
 
 ////////////////////////////////////////////////////////////////
 namespace detail {
-template<typename GRAPH_BUILDER>
+template <typename GRAPH_BUILDER>
 Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) {
   GraphBuilder builder(name);
   builderInDSL(builder);
   return std::move(*builder);
 }
+
+struct GraphDefiner {
+  GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) {
+    name = specifiedName ? specifiedName : defaultName;
+  }
+
+  template <typename USER_BUILDER>
+  auto operator|(USER_BUILDER &&userBuilder) {
+    GraphBuilder graphBuilder{name};
+    std::forward<USER_BUILDER>(userBuilder)(graphBuilder);
+    return *graphBuilder;
+  }
+
+ private:
+  const char *name;
+};
+
 }  // namespace detail
 
-#define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__)
-#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER)
+#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER)
 #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__
 #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__
 #define CHAIN(...) DATA_CHAIN(__VA_ARGS__)
diff --git a/tests/framework/easy_graph/src/layout/graph_layout.cc b/tests/framework/easy_graph/src/layout/graph_layout.cc
index 340acf67..716bed8a 100644
--- a/tests/framework/easy_graph/src/layout/graph_layout.cc
+++ b/tests/framework/easy_graph/src/layout/graph_layout.cc
@@ -16,10 +16,15 @@
 
 #include "easy_graph/layout/graph_layout.h"
 #include "easy_graph/layout/layout_executor.h"
+#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
 #include "easy_graph/graph/graph.h"
 
 EG_NS_BEGIN
 
+namespace {
+GraphEasyExecutor default_executor;
+}
+
 void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) {
   this->executor_ = &executor;
   options_ = opts;
@@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) {
 
 Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) {
   const LayoutOption *options = opts ? opts : this->options_;
-  if (!executor_)
-    return EG_UNIMPLEMENTED;
+  if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options);
   return executor_->Layout(graph, options);
 }
diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h
new file mode 100644
index 00000000..7f5d5086
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/assert_error.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef D52AA06185E34BBFB714FFBCDAB0D53A
+#define D52AA06185E34BBFB714FFBCDAB0D53A
+
+#include "ge_graph_dsl/ge.h"
+#include <exception>
+#include <string>
+
+GE_NS_BEGIN
+
+struct AssertError : std::exception {
+  AssertError(const char *file, int line, const std::string &info);
+
+ private:
+  const char *what() const noexcept override;
+
+ private:
+  std::string info;
+};
+
+GE_NS_END
+
+#endif
\ No newline at end of file
diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h
new file mode 100644
index 00000000..fa0ae783
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/check_utils.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef INC_31309AA0A4E44C009C22AD9351BF3410
+#define INC_31309AA0A4E44C009C22AD9351BF3410
+
+#include "ge_graph_dsl/ge.h"
+#include "graph/compute_graph.h"
+
+GE_NS_BEGIN
+
+using GraphCheckFun = std::function<void(const ComputeGraphPtr &)>;
+struct CheckUtils {
+  static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun);
+  static void init();
+};
+
+GE_NS_END
+
+#endif
\ No newline at end of file
diff --git a/ge/graph/common/local_context.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h
similarity index 63%
rename from ge/graph/common/local_context.h
rename to tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h
index 83367766..a208c02e 100644
--- a/ge/graph/common/local_context.h
+++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/filter_scope_guard.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
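For orientation before the next hunk: the header introduced below declares FilterScopeGuard, the RAII guard behind the DUMP_GRAPH_WHEN / CHECK_GRAPH macros added in graph_assert.h further down. A minimal usage sketch follows; the workload call and the assertion are hypothetical, and only the two macros plus the default "PreRunAfterBuild" phase come from this change set:

    DUMP_GRAPH_WHEN("PreRunAfterBuild")   // installs a FilterScopeGuard for the listed phases
    BuildAndRunTestGraph();               // hypothetical workload that triggers a graph dump
    CHECK_GRAPH(PreRunAfterBuild) {       // fails the test if no graph was dumped in this phase
      ASSERT_EQ(graph->GetName(), "g1");  // gtest-style assertions over the ComputeGraphPtr
    };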
@@ -13,14 +13,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#ifndef C8B32320BD4943D588594B82FFBF2685
+#define C8B32320BD4943D588594B82FFBF2685
 
-#ifndef GE_GRAPH_COMMON_LOCAL_CONTEXT_H_
-#define GE_GRAPH_COMMON_LOCAL_CONTEXT_H_
+#include <vector>
+#include <string>
+#include "ge_graph_dsl/ge.h"
 
-#include "omg/omg_inner_types.h"
+GE_NS_BEGIN
 
-namespace ge {
-void SetLocalOmgContext(OmgContext &context);
-OmgContext &GetLocalOmgContext();
-}  // namespace ge
-#endif  // GE_GRAPH_COMMON_LOCAL_CONTEXT_H_
+struct FilterScopeGuard {
+  FilterScopeGuard(const std::vector<std::string> &);
+  ~FilterScopeGuard();
+};
+
+GE_NS_END
+
+#endif
diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h
new file mode 100644
index 00000000..663907a0
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/assert/graph_assert.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6
+#define AD954C4ADF5B44F5B1CC8BCD72EE9ED6
+
+#include "ge_graph_dsl/ge.h"
+#include "ge_graph_dsl/assert/check_utils.h"
+#include "ge_graph_dsl/assert/assert_error.h"
+#include "ge_graph_dsl/assert/filter_scope_guard.h"
+
+GE_NS_BEGIN
+
+#ifdef GTEST_MESSAGE_AT_
+#define GRAPH_CHECK_MESSAGE(file, line, message) \
+  GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure)
+#else
+#define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message)
+#endif
+
+namespace detail {
+struct GraphAssert {
+  GraphAssert(const char *file, unsigned int line, const std::string &phase_id)
+      : file_(file), line_(line), phase_id_(phase_id) {}
+
+  void operator|(const ::GE_NS::GraphCheckFun &check_fun) {
+    bool ret = ::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun);
+    if (!ret) {
+      auto message = "expect dump graph in phase: [" + phase_id_ + "], but the dump graph was not found!";
+      GRAPH_CHECK_MESSAGE(file_, line_, message.c_str());
+    }
+  }
+
+ private:
+  const char *file_;
+  unsigned int line_;
+  const std::string phase_id_;
+};
+}  // namespace detail
+
+#define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__});
+#define CHECK_GRAPH(phase_id) \
+  ::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph)
+
+GE_NS_END
+
+#endif
\ No newline at end of file
diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h
index bb2326ec..99eafa7f 100644
--- a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h
+++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg.h
@@ -33,14 +33,12 @@ struct OpDescCfg {
     std::vector<int64_t> shape_;
   };
 
-  OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW,
+  OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW,
             DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224})
       : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {}
 
 protected:
-  OpType GetType() const {
-    return type_;
-  }
+  OpType GetType() const { return type_; }
   OpType type_;
   int in_cnt_;
   int out_cnt_;
diff --git a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg_box.h b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg_box.h
index af3a1971..2be05972 100644
--- a/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg_box.h
+++ b/tests/framework/ge_graph_dsl/include/ge_graph_dsl/op_desc/op_desc_cfg_box.h
@@ -21,6 +21,7 @@
 #include "ge_graph_dsl/ge.h"
 #include "ge_graph_dsl/op_desc/op_box.h"
 #include "ge_graph_dsl/op_desc/op_desc_cfg.h"
+#include "graph/ge_attr_value.h"
 #include "graph/op_desc.h"
 
 GE_NS_BEGIN
@@ -29,19 +30,32 @@ struct OpDescCfgBox : OpBox, private OpDescCfg {
   OpDescCfgBox(const OpType &opType);
   OpDescCfgBox &InCnt(int in_cnt);
   OpDescCfgBox &OutCnt(int out_cnt);
+  OpDescCfgBox &ParentNodeIndex(int node_index);
   OpDescCfgBox &TensorDesc(Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
-                          std::vector<int64_t> shape = {1, 1, 224, 224});
-  template <typename Type>
-  OpDescCfgBox& Attr(const std::string &name, Type value) {
-    auto attrvalue = ge::GeAttrValue::CreateFrom<Type>(value);
-    attrs_.emplace(std::make_pair(name, attrvalue));
-    return *this;
-  }
+                           std::vector<int64_t> shape = {1, 1, 224, 224});
+  OpDescCfgBox &Weight(GeTensorPtr &);
 
- private:
+  template <typename Type>
+  OpDescCfgBox &Attr(const std::string &name, Type &&value) {
+    auto attrvalue = ge::GeAttrValue::CreateFrom<Type>(std::forward<Type>(value));
+    attrs_.emplace(std::make_pair(name, attrvalue));
+    return *this;
+  }
+
+  template <typename Type>
+  OpDescCfgBox &Attr(const std::string &name, Type &value) {
+    auto attrvalue = ge::GeAttrValue::CreateFrom<Type>(value);
+    attrs_.emplace(std::make_pair(name, attrvalue));
+    return *this;
+  }
+
+  OpDescCfgBox &Attr(const std::string &name, int value);
+  OpDescCfgBox &Attr(const std::string &name, const char *value);
   OpDescPtr Build(const ::EG_NS::NodeId &id) const override;
-  void UpdateAttrs(OpDescPtr&) const;
-  std::map<std::string, GeAttrValue> attrs_;
+
+ private:
+  void UpdateAttrs(OpDescPtr &) const;
+  std::map<std::string, GeAttrValue> attrs_;
 };
 
 #define OP_CFG(optype) ::GE_NS::OpDescCfgBox(optype)
diff --git a/tests/framework/ge_graph_dsl/src/assert/assert_error.cc b/tests/framework/ge_graph_dsl/src/assert/assert_error.cc
new file mode 100644
index 00000000..5b74d852
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/src/assert/assert_error.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ge_graph_dsl/assert/assert_error.h"
+
+GE_NS_BEGIN
+
+AssertError::AssertError(const char *file, int line, const std::string &info) {
+  this->info = std::string(file) + ":" + std::to_string(line) + "\n" + info;
+}
+
+const char *AssertError::what() const noexcept { return info.c_str(); }
+
+GE_NS_END
diff --git a/tests/framework/ge_graph_dsl/src/assert/check_utils.cc b/tests/framework/ge_graph_dsl/src/assert/check_utils.cc
new file mode 100644
index 00000000..56bc6e81
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/src/assert/check_utils.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ge_graph_dsl/assert/check_utils.h"
+#include "graph/utils/dumper/ge_graph_dumper.h"
+#include "ge_graph_default_checker.h"
+#include "ge_graph_check_dumper.h"
+
+GE_NS_BEGIN
+
+bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) {
+  auto &dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper());
+  return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun));
+}
+
+void CheckUtils::init() {
+  static GeGraphCheckDumper checkDumper;
+  GraphDumperRegistry::Register(checkDumper);
+}
+
+GE_NS_END
diff --git a/tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc b/tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc
new file mode 100644
index 00000000..4aa4795d
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/src/assert/filter_scope_guard.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ge_graph_dsl/assert/filter_scope_guard.h"
+#include "graph/utils/dumper/ge_graph_dumper.h"
+#include "ge_dump_filter.h"
+
+GE_NS_BEGIN
+
+namespace {
+GeDumpFilter &GetDumpFilter() { return dynamic_cast<GeDumpFilter &>(GraphDumperRegistry::GetDumper()); }
+}  // namespace
+
+FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) { GetDumpFilter().Update(filter); }
+
+FilterScopeGuard::~FilterScopeGuard() { GetDumpFilter().Reset(); }
+
+GE_NS_END
\ No newline at end of file
diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h b/tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h
new file mode 100644
index 00000000..47967c91
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/src/assert/ge_dump_filter.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6
+#define INC_4C6224E8F7474EF89B18CCB0E4B19FD6
+
+#include <vector>
+#include <string>
+#include "ge_graph_dsl/ge.h"
+#include "easy_graph/infra/keywords.h"
+
+GE_NS_BEGIN
+
+INTERFACE(GeDumpFilter) {
+  ABSTRACT(void Update(const std::vector<std::string> &));
+  ABSTRACT(void Reset());
+};
+
+GE_NS_END
+
+#endif
diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc
new file mode 100644
index 00000000..ba72cf86
--- /dev/null
+++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.cc
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "ge_graph_check_dumper.h" +#include "graph/model.h" +#include "graph/buffer.h" +#include "graph/utils/graph_utils.h" +#include "ge_graph_default_checker.h" + +GE_NS_BEGIN + +GeGraphCheckDumper::GeGraphCheckDumper() { Reset(); } + +bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const { + auto iter = std::find(suffixes_.begin(), suffixes_.end(), suffix); + return (iter != suffixes_.end()); +} + +void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) { + if (!IsNeedDump(suffix)) { + return; + } + auto iter = buffers_.find(suffix); + if (iter != buffers_.end()) { + DumpGraph(graph, iter->second); + } else { + buffers_[suffix] = Buffer(); + DumpGraph(graph, buffers_.at(suffix)); + } +} + +bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) { + auto iter = buffers_.find(checker.PhaseId()); + if (iter == buffers_.end()) { + return false; + } + DoCheck(checker, iter->second); + return true; +} + +void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) { + Model model("", ""); + Model::Load(buffer.GetData(), buffer.GetSize(), model); + auto load_graph = model.GetGraph(); + checker.Check(GraphUtils::GetComputeGraph(load_graph)); +} + +void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) { + Model model("", ""); + buffer.clear(); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); + model.Save(buffer, true); +} + +void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) { + suffixes_ = new_suffixes_; + buffers_.clear(); +} + +void GeGraphCheckDumper::Reset() { + static std::vector<std::string> default_suffixes_{"PreRunAfterBuild"}; + suffixes_ = default_suffixes_; + buffers_.clear(); +} + +GE_NS_END \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h new file mode 100644 index 00000000..5eda52ea --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_check_dumper.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef INC_8EFED0015C27464897BF64531355C810 +#define INC_8EFED0015C27464897BF64531355C810 + +#include "ge_graph_dsl/ge.h" +#include "graph/utils/dumper/ge_graph_dumper.h" +#include "ge_dump_filter.h" +#include <map> + +GE_NS_BEGIN + +struct GeGraphChecker; + +struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter { + GeGraphCheckDumper(); + virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix); + bool CheckFor(const GeGraphChecker &checker); + + private: + void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer); + void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer); + + private: + void Update(const std::vector<std::string> &) override; + void Reset() override; + bool IsNeedDump(const std::string &suffix) const; + + private: + std::map<std::string, ::GE_NS::Buffer> buffers_; + std::vector<std::string> suffixes_; +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h b/tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h new file mode 100644 index 00000000..c6b25b65 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_checker.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_5960A8F437324904BEE0690271258762 +#define INC_5960A8F437324904BEE0690271258762 + +#include "ge_graph_dsl/ge.h" +#include "easy_graph/infra/keywords.h" +#include "graph/compute_graph.h" + +GE_NS_BEGIN + +INTERFACE(GeGraphChecker) { + ABSTRACT(const std::string &PhaseId() const); + ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const); +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc new file mode 100644 index 00000000..4aa48ac6 --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
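`GeGraphCheckDumper` keys its buffers by phase id and replays each buffer through whatever `GeGraphChecker` it is handed, so checkers other than the lambda-based default that follows are possible. A sketch of a handwritten checker, assuming the `INTERFACE`/`ABSTRACT` macros from easy_graph expand to ordinary virtual functions (the class and its node-count check are illustrative):

```cpp
#include <cassert>
#include <string>
#include "ge_graph_checker.h"  // src-internal header, assumed on the include path

// Illustrative checker: pins one phase to an exact node count.
struct NodeCountChecker : ::GE_NS::GeGraphChecker {
  NodeCountChecker(const std::string &phase, size_t expect) : phase_(phase), expect_(expect) {}

 private:
  const std::string &PhaseId() const override { return phase_; }
  void Check(const ge::ComputeGraphPtr &graph) const override { assert(graph->GetAllNodesSize() == expect_); }

  const std::string phase_;
  const size_t expect_;
};
```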
+ */ + +#include "ge_graph_default_checker.h" + +GE_NS_BEGIN + +GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun) + : phase_id_(phase_id), check_fun_(check_fun) {} + +const std::string &GeGraphDefaultChecker::PhaseId() const { return phase_id_; } + +void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { return check_fun_(graph); } + +GE_NS_END \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h new file mode 100644 index 00000000..af8f3fbe --- /dev/null +++ b/tests/framework/ge_graph_dsl/src/assert/ge_graph_default_checker.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BCF4D96BE9FC48938DE7B7E93B551C54 +#define BCF4D96BE9FC48938DE7B7E93B551C54 + +#include "ge_graph_dsl/ge.h" +#include "ge_graph_checker.h" +#include "graph/compute_graph.h" + +GE_NS_BEGIN + +using GraphCheckFun = std::function<void(const ge::ComputeGraphPtr &)>; + +struct GeGraphDefaultChecker : GeGraphChecker { + GeGraphDefaultChecker(const std::string &, const GraphCheckFun &); + + private: + const std::string &PhaseId() const override; + void Check(const ge::ComputeGraphPtr &graph) const override; + + private: + const std::string phase_id_; + const GraphCheckFun check_fun_; +}; + +GE_NS_END + +#endif \ No newline at end of file diff --git a/tests/framework/ge_graph_dsl/src/op_desc_cfg_box.cc b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc similarity index 71% rename from tests/framework/ge_graph_dsl/src/op_desc_cfg_box.cc rename to tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc index fc2a6c1c..be7cd831 100644 --- a/tests/framework/ge_graph_dsl/src/op_desc_cfg_box.cc +++ b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_box.cc @@ -17,8 +17,8 @@ #include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" #include "easy_graph/infra/status.h" #include "ge_graph_dsl/op_desc/op_desc_cfg_repo.h" -#include "ge_graph_dsl/op_desc/op_desc_cfg.h" #include "external/graph/gnode.h" +#include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" using ::EG_NS::Status; @@ -44,6 +44,26 @@ OpDescCfgBox &OpDescCfgBox::OutCnt(int out_cnt) { return *this; } +OpDescCfgBox &OpDescCfgBox::ParentNodeIndex(int node_index) { + this->Attr(ATTR_NAME_PARENT_NODE_INDEX, node_index); + return *this; +} + +OpDescCfgBox &OpDescCfgBox::Attr(const std::string &name, int value) { + this->Attr(name, (int64_t)value); + return *this; +} + +OpDescCfgBox &OpDescCfgBox::Attr(const std::string &name, const char *value) { + this->Attr(name, std::string(value)); + return *this; +} + +OpDescCfgBox &OpDescCfgBox::Weight(GeTensorPtr &tensor_ptr) { + this->Attr(ATTR_NAME_WEIGHTS, tensor_ptr); + return *this; +} + OpDescCfgBox &OpDescCfgBox::TensorDesc(Format format, DataType data_type, std::vector<int64_t> shape) { default_tensor_.format_ = format; default_tensor_.data_type_ = data_type; @@ 
-51,10 +71,9 @@ OpDescCfgBox &OpDescCfgBox::TensorDesc(Format format, DataType data_type, std::v return *this; } -void OpDescCfgBox::UpdateAttrs(OpDescPtr& op_desc) const { - std::for_each(attrs_.begin(), attrs_.end(), [&op_desc](const auto &attr){ - op_desc->SetAttr(attr.first, attr.second); - }); +void OpDescCfgBox::UpdateAttrs(OpDescPtr &op_desc) const { + std::for_each(attrs_.begin(), attrs_.end(), + [&op_desc](const auto &attr) { op_desc->SetAttr(attr.first, attr.second); }); } OpDescPtr OpDescCfgBox::Build(const ::EG_NS::NodeId &id) const { diff --git a/tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc similarity index 53% rename from tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc rename to tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc index e7fa018f..19dfa4a5 100644 --- a/tests/framework/ge_graph_dsl/src/op_desc_cfg_repo.cc +++ b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc @@ -23,15 +23,22 @@ GE_NS_BEGIN namespace { -#define OP_CFG(optype, ...) \ - { \ - optype, OpDescCfg { \ - optype, __VA_ARGS__ \ - } \ +#define OP_CFG(optype, ...) \ + { \ + optype, OpDescCfg { optype, __VA_ARGS__ } \ } static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}), + OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(VARIABLE, 1, 1)}; } // namespace diff --git a/tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc similarity index 97% rename from tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc rename to tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc index 23d4773c..1564e019 100644 --- a/tests/framework/ge_graph_dsl/src/op_desc_ptr_box.cc +++ b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_ptr_box.cc @@ -19,6 +19,4 @@ USING_GE_NS -OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { - return op_; -} +OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; } diff --git a/tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc b/tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc similarity index 89% rename from tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc rename to tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc index d8bc2aab..c1dca646 100644 --- a/tests/framework/ge_graph_dsl/src/ge_graph_vistor.cc +++ b/tests/framework/ge_graph_dsl/src/vistor/ge_graph_visitor.cc @@ -36,17 +36,11 @@ GE_NS_BEGIN GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} -void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { - build_graph_ = graph; -} +void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; } -Graph GeGraphVisitor::BuildGeGraph() const { - return GraphUtils::CreateGraphFromComputeGraph(build_graph_); -} +Graph GeGraphVisitor::BuildGeGraph() const { return 
GraphUtils::CreateGraphFromComputeGraph(build_graph_); } -ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { - return build_graph_; -} +ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; } Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { build_graph_->SetName(graph.GetName()); diff --git a/tests/framework/ge_graph_dsl/src/ge_subgraph_vistor.cc b/tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc similarity index 100% rename from tests/framework/ge_graph_dsl/src/ge_subgraph_vistor.cc rename to tests/framework/ge_graph_dsl/src/vistor/ge_subgraph_vistor.cc diff --git a/tests/framework/ge_graph_dsl/src/graph_dsl.cc b/tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc similarity index 100% rename from tests/framework/ge_graph_dsl/src/graph_dsl.cc rename to tests/framework/ge_graph_dsl/src/vistor/graph_dsl.cc diff --git a/tests/framework/ge_graph_dsl/tests/CMakeLists.txt b/tests/framework/ge_graph_dsl/tests/CMakeLists.txt index 40097d8b..65482679 100644 --- a/tests/framework/ge_graph_dsl/tests/CMakeLists.txt +++ b/tests/framework/ge_graph_dsl/tests/CMakeLists.txt @@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE ) set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) -target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl) +target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl) include(CTest) enable_testing() diff --git a/tests/framework/ge_graph_dsl/tests/check_graph_test.cc b/tests/framework/ge_graph_dsl/tests/check_graph_test.cc new file mode 100644 index 00000000..731b7eed --- /dev/null +++ b/tests/framework/ge_graph_dsl/tests/check_graph_test.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "gtest/gtest.h" +#include "easy_graph/layout/graph_layout.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" +#include "ge_graph_dsl/graph_dsl.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/dumper/ge_graph_dumper.h" +#include "framework/common/types.h" +#include "ge_graph_dsl/assert/graph_assert.h" +#include "graph/model.h" +#include "graph/buffer.h" + +USING_GE_NS + +class CheckGraphTest : public testing::Test { + private: + EG_NS::GraphEasyExecutor executor; + + protected: + void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } + void TearDown() {} +}; + +TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + + DUMP_GRAPH_WHEN("after_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); + + CHECK_GRAPH(after_build) { + ASSERT_EQ(graph->GetName(), "g1"); + ASSERT_EQ(graph->GetAllNodesSize(), 2); + }; +} + +TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + DEF_GRAPH(g2) { + CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); + CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); + }; + + DUMP_GRAPH_WHEN("before_build", "after_build"); + + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build"); + + CHECK_GRAPH(before_build) { + ASSERT_EQ(graph->GetName(), "g1"); + ASSERT_EQ(graph->GetAllNodesSize(), 2); + }; + + CHECK_GRAPH(after_build) { + ASSERT_EQ(graph->GetName(), "g2"); + ASSERT_EQ(graph->GetAllNodesSize(), 3); + }; +} + +TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + DEF_GRAPH(g2) { + CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); + CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); + }; + + DUMP_GRAPH_WHEN("before_build") + + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build"); + + CHECK_GRAPH(before_build) { + ASSERT_EQ(graph->GetName(), "g2"); + ASSERT_EQ(graph->GetAllNodesSize(), 3); + }; +} + +TEST_F(CheckGraphTest, test_check_phases_is_work) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + + DUMP_GRAPH_WHEN("before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); + auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {}); + ASSERT_FALSE(ret); +} + +TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + + DUMP_GRAPH_WHEN("before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); + GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); + + CHECK_GRAPH(before_build) { + ASSERT_EQ(graph->GetName(), "g1"); + ASSERT_EQ(graph->GetAllNodesSize(), 2); + }; +} + +TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) { + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; + auto ge_graph = ToGeGraph(g1); + + ge::Model model("", ""); + model.SetGraph(ge_graph); + Buffer buffer; + model.Save(buffer, true); + + ge::Model loadModel("", ""); + Model::Load(buffer.GetData(), buffer.GetSize(), loadModel); + 
auto load_graph = loadModel.GetGraph(); + + ASSERT_EQ(load_graph.GetName(), "g1"); + ASSERT_EQ(load_graph.GetAllNodes().size(), 2); +} diff --git a/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc b/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc index f7e55e3d..a8240b32 100644 --- a/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc +++ b/tests/framework/ge_graph_dsl/tests/graph_dsl_test.cc @@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test { EG_NS::GraphEasyExecutor executor; protected: - void SetUp() { - EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); - } + void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } void TearDown() {} }; TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { - DEF_GRAPH(g1) { - CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); auto computeGraph = ToComputeGraph(g1); @@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { } TEST_F(GraphDslTest, test_build_graph_with_name) { - DEF_GRAPH(g1, "sample_graph") { - CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); @@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) { auto data = std::make_shared("data1", DATA); auto add = std::make_shared("Add", ADD); CHAIN(NODE(data)->NODE(add)); - }); + }; auto geGraph = ToGeGraph(g1); @@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); - }); + }; auto geGraph = ToGeGraph(g1); @@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { } TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { - DEF_GRAPH(g1) { - CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); - }); + DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); }; auto geGraph = ToGeGraph(g1); @@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { } TEST_F(GraphDslTest, test_build_from_control_chain) { - DEF_GRAPH(g1) { - CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); @@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) { } TEST_F(GraphDslTest, test_build_from_data_chain) { - DEF_GRAPH(g1) { - DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); - }); + DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; auto geGraph = ToGeGraph(g1); @@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) { DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) { DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) { DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); CHAIN(NODE("data2", DATA)->NODE("add")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -168,7 +156,7 @@ 
TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) { ->NODE("Add4") ->NODE("Add5") ->NODE("net_output", NETOUTPUT)); - }); + }; auto geGraph = ToGeGraph(g1); @@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) { CHAIN(NODE("add2")->NODE("net_output")); CHAIN(NODE("add3")->NODE("net_output")); CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); - }); + }; auto geGraph = ToGeGraph(g1); @@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) { DEF_GRAPH(sub_1) { CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); - }); + }; DEF_GRAPH(sub_2) { CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); - }); + }; DEF_GRAPH(g1) { CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); CHAIN(NODE("data_i", DATA)->NODE("while")); - }); + }; sub_1.Layout(); sub_2.Layout(); diff --git a/tests/framework/ge_graph_dsl/tests/op_desc_config_test.cc b/tests/framework/ge_graph_dsl/tests/op_desc_config_test.cc new file mode 100644 index 00000000..eee5d7c2 --- /dev/null +++ b/tests/framework/ge_graph_dsl/tests/op_desc_config_test.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "gtest/gtest.h" +#include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" +#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" +#include "graph/ge_tensor.h" +#include "graph/utils/attr_utils.h" +GE_NS_BEGIN + +class OpDescCfgTest : public testing::Test {}; + +TEST_F(OpDescCfgTest, test_attr_set_string_success) { + auto op_ptr = OP_CFG(DATA).Attr(ENTER_ATTR_FRAME_NAME, "1").Build("data1"); + + ge::GeAttrValue ret; + op_ptr->GetAttr(ENTER_ATTR_FRAME_NAME, ret); + std::string value; + ret.GetValue<std::string>(value); + + ASSERT_EQ(value, "1"); +} + +TEST_F(OpDescCfgTest, test_attr_set_int_success) { + auto op_ptr = OP_CFG(DATA).Attr(ENTER_ATTR_FRAME_NAME, 2).Build("data1"); + + ge::GeAttrValue ret; + op_ptr->GetAttr(ENTER_ATTR_FRAME_NAME, ret); + int64_t value; + ret.GetValue<int64_t>(value); + + ASSERT_EQ(value, 2); +} + +TEST_F(OpDescCfgTest, test_attr_set_parent_node_index_success) { + auto op_ptr = OP_CFG(DATA).ParentNodeIndex(2).Build("data1"); + + ge::GeAttrValue ret; + op_ptr->GetAttr(ATTR_NAME_PARENT_NODE_INDEX, ret); + int64_t value; + ret.GetValue<int64_t>(value); + + ASSERT_EQ(value, 2); +} + +TEST_F(OpDescCfgTest, test_attr_set_weight_success) { + int64_t dims_size = 1; + vector<int64_t> data_vec = {5}; + for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); + vector<int32_t> data_value_vec(dims_size, 1); + GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); + GeTensorPtr data_tensor = std::make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), + data_value_vec.size() * sizeof(int32_t)); + + auto op_ptr = OP_CFG(CONSTANT).Weight(data_tensor).Build("const1"); + + ConstGeTensorPtr tensor_value; + ASSERT_TRUE(AttrUtils::GetTensor(op_ptr, ge::ATTR_NAME_WEIGHTS, tensor_value)); + ASSERT_EQ(tensor_value->GetTensorDesc().GetDataType(), DT_INT32); +} + +GE_NS_END diff --git a/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc b/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc index b83d68fc..533e8198 100644 --- a/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc +++ b/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc @@ -23,11 +23,18 @@ GE_NS_BEGIN REGISTER_OPTYPE_DEFINE(DATA, "Data"); REGISTER_OPTYPE_DEFINE(HCOMALLGATHER, "HcomAllGather"); REGISTER_OPTYPE_DEFINE(VARIABLE, "Variable"); +REGISTER_OPTYPE_DEFINE(CONSTANT, "Const"); REGISTER_OPTYPE_DEFINE(CONSTANTOP, "Constant"); REGISTER_OPTYPE_DEFINE(LESS, "Less"); REGISTER_OPTYPE_DEFINE(MUL, "Mul"); REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); REGISTER_OPTYPE_DEFINE(ADD, "Add"); REGISTER_OPTYPE_DEFINE(WHILE, "While"); +REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); +REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); +REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); +REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); +REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); +REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); GE_NS_END diff --git a/tests/framework/utils/builder/tensor_builder_utils.h b/tests/framework/ge_graph_dsl/tests/test_main.cc similarity index 73% rename from tests/framework/utils/builder/tensor_builder_utils.h rename to tests/framework/ge_graph_dsl/tests/test_main.cc index 73656e4a..eb6112f2 100644 --- a/tests/framework/utils/builder/tensor_builder_utils.h +++ b/tests/framework/ge_graph_dsl/tests/test_main.cc @@ -1,22 +1,25 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
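The attribute tests above exercise the fluent `OpDescCfgBox` setters one at a time; in practice the calls chain into a single expression. A sketch under the same constants the tests use (the names `data_0` and `frame_0` are illustrative):

```cpp
#include "framework/common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h"

// Illustrative only: a DATA op configured the way a subgraph input usually is.
void BuildSampleOpDesc() {
  auto data_op = OP_CFG(DATA)
                     .InCnt(1)
                     .OutCnt(1)
                     .ParentNodeIndex(0)                       // sets ATTR_NAME_PARENT_NODE_INDEX
                     .Attr(ENTER_ATTR_FRAME_NAME, "frame_0")   // const char * overload stores a string attr
                     .Build("data_0");                         // yields a ready-to-use OpDescPtr
}
```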
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H -#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H - -class tensor_builder_utils {}; - -#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include "ge_graph_dsl/assert/check_utils.h" + +int main(int argc, char **argv) { + ::GE_NS::CheckUtils::init(); + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + return ret; +} diff --git a/tests/framework/ge_running_env/CMakeLists.txt b/tests/framework/ge_running_env/CMakeLists.txt new file mode 100644 index 00000000..deac4e03 --- /dev/null +++ b/tests/framework/ge_running_env/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +add_subdirectory(include) +add_subdirectory(src) +add_subdirectory(tests) \ No newline at end of file diff --git a/tests/framework/ge_running_env/include/CMakeLists.txt b/tests/framework/ge_running_env/include/CMakeLists.txt new file mode 100644 index 00000000..b71b0578 --- /dev/null +++ b/tests/framework/ge_running_env/include/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ + +add_library(ge_running_env_inc INTERFACE) +target_include_directories(ge_running_env_inc INTERFACE ./) diff --git a/tests/framework/ge_running_env/include/ge_running_env/env_installer.h b/tests/framework/ge_running_env/include/ge_running_env/env_installer.h new file mode 100644 index 00000000..79b65137 --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/env_installer.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef H1D9F4FDE_BB21_4DE4_AC7E_751920B45039 +#define H1D9F4FDE_BB21_4DE4_AC7E_751920B45039 + +#include "fake_ns.h" +#include "opskernel_manager/ops_kernel_manager.h" +#include "register/ops_kernel_builder_registry.h" + +FAKE_NS_BEGIN + +struct EnvInstaller { + virtual void InstallTo(std::map<std::string, OpsKernelInfoStorePtr>&) const {} + virtual void InstallTo(std::map<std::string, GraphOptimizerPtr>&) const {} + virtual void InstallTo(std::map<std::string, OpsKernelBuilderPtr>&) const {} + virtual void Install() const {} +}; + +FAKE_NS_END + +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_engine.h b/tests/framework/ge_running_env/include/ge_running_env/fake_engine.h new file mode 100644 index 00000000..c4207223 --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_engine.h @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
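`EnvInstaller` is deliberately a set of no-op hooks: each overload receives one of GE's global registries, and an installer overrides only the maps it cares about. A minimal sketch of a bespoke installer (the store name is an example, not a real library):

```cpp
#include <map>
#include <string>
#include "ge_running_env/env_installer.h"

// Illustrative installer that hides one kernel-info store from the fake env.
struct HideStoreInstaller : ge::EnvInstaller {
  void InstallTo(std::map<std::string, ge::OpsKernelInfoStorePtr> &stores) const override {
    stores.erase("AicpuLib");  // example name only
  }
};
```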
+ */ +#ifndef HAF5E9BF2_752F_4E03_B0A5_E1B912A5FA24 +#define HAF5E9BF2_752F_4E03_B0A5_E1B912A5FA24 + +#include <set> +#include "fake_ns.h" +#include "ge_running_env/env_installer.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "opskernel_manager/ops_kernel_manager.h" +#include "register/ops_kernel_builder_registry.h" +#include "fake_ops_kernel_builder.h" +#include "fake_ops_kernel_info_store.h" + +FAKE_NS_BEGIN + +using FakeOpsKernelBuilderPtr = std::shared_ptr<FakeOpsKernelBuilder>; +using FakeOpsKernelInfoStorePtr = std::shared_ptr<FakeOpsKernelInfoStore>; + +struct FakeEngine : EnvInstaller { + FakeEngine(const std::string& engine_name); + FakeEngine& KernelBuilder(FakeOpsKernelBuilderPtr); + FakeEngine& KernelInfoStore(FakeOpsKernelInfoStorePtr); + FakeEngine& KernelInfoStore(const std::string&); + + private: + void InstallTo(std::map<std::string, OpsKernelInfoStorePtr>&) const override; + void InstallTo(std::map<std::string, OpsKernelBuilderPtr>&) const override; + + private: + template <typename BasePtr, typename SubClass> + void InstallFor(std::map<std::string, BasePtr>& maps, const std::map<std::string, std::shared_ptr<SubClass>>&) const; + + private: + std::string engine_name_; + std::set<std::string> info_store_names_; + std::map<std::string, FakeOpsKernelBuilderPtr> custom_builders_; + std::map<std::string, FakeOpsKernelInfoStorePtr> custom_info_stores_; +}; + +FAKE_NS_END + +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_ns.h b/tests/framework/ge_running_env/include/ge_running_env/fake_ns.h new file mode 100644 index 00000000..c802e109 --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_ns.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef H7AEFF0EA_9FDE_487F_8562_2917A2D48EA2 +#define H7AEFF0EA_9FDE_487F_8562_2917A2D48EA2 + +#define FAKE_NS ge +#define FAKE_NS_BEGIN namespace FAKE_NS { +#define FAKE_NS_END } +#define USING_STUB_NS using namespace FAKE_NS; +#define FWD_DECL_STUB(type) \ + namespace FAKE_NS { \ + struct type; \ + } + +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_op.h b/tests/framework/ge_running_env/include/ge_running_env/fake_op.h new file mode 100644 index 00000000..cc442cdb --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_op.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
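`FakeEngine` synthesizes a default info store per engine name, but the `KernelInfoStore`/`KernelBuilder` overloads let a test hand in a prefabricated fake instead. A short sketch, with illustrative names throughout (`GeRunningEnvFaker` is the consumer declared later in this patch):

```cpp
#include <memory>
#include "ge_running_env/fake_engine.h"
#include "ge_running_env/fake_ops_kernel_info_store.h"
#include "ge_running_env/ge_running_env_faker.h"

// Hypothetical helper: registers an engine backed by a handcrafted store.
void InstallCustomEngine(ge::GeRunningEnvFaker &ge_env) {
  auto store = std::make_shared<ge::FakeOpsKernelInfoStore>("MyLib");
  store->RegistOp("MyOp");  // InfoStoreHolder::RegistOp, inherited by the fake store
  ge_env.Install(ge::FakeEngine("MyEngine").KernelInfoStore(store));
}
```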
+ */ +#ifndef H737AD661_27C0_400F_8B08_29701308C5D0 +#define H737AD661_27C0_400F_8B08_29701308C5D0 + +#include <string> +#include <vector> +#include "fake_ns.h" +#include "ge_running_env/env_installer.h" +#include "graph/operator_factory.h" + +FAKE_NS_BEGIN + +struct FakeOp : EnvInstaller { + FakeOp(const std::string& op_type); + + FakeOp& Inputs(const std::vector<std::string>&); + FakeOp& Outputs(const std::vector<std::string>&); + FakeOp& InferShape(InferShapeFunc); + FakeOp& InfoStoreAndBuilder(const std::string&); + + private: + void Install() const override; + void InstallTo(std::map<std::string, OpsKernelInfoStorePtr>&) const override; + + private: + const std::string op_type_; + std::vector<std::string> inputs_; + std::vector<std::string> outputs_; + InferShapeFunc info_fun_; + std::set<std::string> info_store_names_; +}; + +FAKE_NS_END + +#endif /* H737AD661_27C0_400F_8B08_29701308C5D0 */ diff --git a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_builder.h b/tests/framework/ge_running_env/include/ge_running_env/fake_ops_kernel_builder.h similarity index 59% rename from tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_builder.h rename to tests/framework/ge_running_env/include/ge_running_env/fake_ops_kernel_builder.h index 62dab542..acfe5e41 100644 --- a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_builder.h +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_ops_kernel_builder.h @@ -13,39 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef H39E4E719_91F4_4D0F_BA4F_6BA56CB1E20D +#define H39E4E719_91F4_4D0F_BA4F_6BA56CB1E20D -#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ -#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ +#include "fake_ns.h" +#include "common/opskernel/ops_kernel_builder.h" +#include "info_store_holder.h" -#if defined(_MSC_VER) -#ifdef FUNC_VISIBILITY -#define GE_FUNC_VISIBILITY _declspec(dllexport) -#else -#define GE_FUNC_VISIBILITY -#endif -#else -#ifdef FUNC_VISIBILITY -#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_VISIBILITY -#endif -#endif +FAKE_NS_BEGIN -#include "common/opskernel/ops_kernel_builder.h" +struct FakeOpsKernelBuilder : OpsKernelBuilder, InfoStoreHolder { + FakeOpsKernelBuilder(const std::string &kernel_lib_name); + FakeOpsKernelBuilder(); -namespace ge { -namespace st { -class GE_FUNC_VISIBILITY StubOpsKernelBuilder : public OpsKernelBuilder { - public: + private: Status Initialize(const map<std::string, std::string> &options) override; - Status Finalize() override; - Status CalcOpRunningParam(Node &node) override; - Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override; }; -} // namespace st -} // namespace ge -#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ +FAKE_NS_END + +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_ops_kernel_info_store.h b/tests/framework/ge_running_env/include/ge_running_env/fake_ops_kernel_info_store.h new file mode 100644 index 00000000..4a8ab9dc --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_ops_kernel_info_store.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef H1EBABA85_7056_48F0_B496_E4DB68E5FED3 +#define H1EBABA85_7056_48F0_B496_E4DB68E5FED3 + +#include "fake_ns.h" +#include "common/opskernel/ops_kernel_info_store.h" +#include "ge/ge_api_types.h" +#include "info_store_holder.h" + +FAKE_NS_BEGIN + +struct FakeOpsKernelInfoStore : OpsKernelInfoStore, InfoStoreHolder { + FakeOpsKernelInfoStore(const std::string &kernel_lib_name); + FakeOpsKernelInfoStore(); + + private: + Status Initialize(const std::map<std::string, std::string> &options) override; + Status Finalize() override; + bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override; + void GetAllOpsKernelInfo(std::map<std::string, OpInfo> &infos) const override; +}; + +FAKE_NS_END + +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/ge_running_env_faker.h b/tests/framework/ge_running_env/include/ge_running_env/ge_running_env_faker.h new file mode 100644 index 00000000..6d325c6a --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/ge_running_env_faker.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef H99C11FC4_700E_4D4D_B073_7808FA88BEBC +#define H99C11FC4_700E_4D4D_B073_7808FA88BEBC + +#include "ge_running_env/fake_engine.h" +#include "fake_ns.h" +#include "opskernel_manager/ops_kernel_manager.h" +#include "register/ops_kernel_builder_registry.h" + +FAKE_NS_BEGIN + +struct GeRunningEnvFaker { + GeRunningEnvFaker(); + GeRunningEnvFaker &Reset(); + GeRunningEnvFaker &Install(const EnvInstaller &); + GeRunningEnvFaker &InstallDefault(); + static void BackupEnv(); + + private: + void flush(); + + private: + std::map<std::string, std::vector<OpInfo>> &op_kernel_info_; + std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_info_stores_; + std::map<std::string, GraphOptimizerPtr> &ops_kernel_optimizers_; + std::map<std::string, OpsKernelBuilderPtr> &ops_kernel_builders_; +}; + +FAKE_NS_END + +#endif /* H99C11FC4_700E_4D4D_B073_7808FA88BEBC */ diff --git a/tests/framework/ge_running_env/include/ge_running_env/info_store_holder.h b/tests/framework/ge_running_env/include/ge_running_env/info_store_holder.h new file mode 100644 index 00000000..85b6c75f --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/info_store_holder.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
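`GeRunningEnvFaker` captures references to GE's four registries and rewrites them in place. A typical wiring in a test fixture, assuming `GELib` is already initialized so the constructor can reach the registries (sketch only; the engine name is invented):

```cpp
#include "ge_running_env/ge_running_env_faker.h"

// Hypothetical fixture hook: reset to the backed-up state, then add fakes.
void PrepareRunningEnv() {
  ge::GeRunningEnvFaker ge_env;                  // constructor calls Reset()
  ge_env.InstallDefault();                       // stock engines and common ops
  ge_env.Install(ge::FakeEngine("DNN_CUSTOM"));  // optional per-test extras
}
```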
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef H7992249B_058D_40A1_94EA_52BBCB76434E +#define H7992249B_058D_40A1_94EA_52BBCB76434E + +#include "fake_ns.h" +#include "common/opskernel/ops_kernel_info_types.h" + +FAKE_NS_BEGIN + +struct InfoStoreHolder { + InfoStoreHolder(); + InfoStoreHolder(const std::string&); + void EngineName(std::string engine_name); + void RegistOp(std::string op_type); + std::string GetLibName(); + + protected: + std::map<std::string, OpInfo> op_info_map_; + std::string kernel_lib_name_; + std::string engine_name_; +}; + +FAKE_NS_END + +#endif diff --git a/tests/framework/ge_running_env/src/CMakeLists.txt b/tests/framework/ge_running_env/src/CMakeLists.txt new file mode 100644 index 00000000..ae068bd3 --- /dev/null +++ b/tests/framework/ge_running_env/src/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++") + +# ---- Target : stub Host engine ---- +add_library(ge_with_env STATIC ${SOURCES}) + +target_include_directories(ge_with_env + PUBLIC + include + ) + +target_include_directories(ge_with_env + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ) + +target_compile_definitions(ge_with_env PRIVATE + google=ascend_private + FMK_SUPPORT_DUMP + ) + +target_compile_options(ge_with_env PRIVATE + -g --coverage -fprofile-arcs -ftest-coverage + -Werror=format + ) + +target_link_libraries(ge_with_env PUBLIC + $<BUILD_INTERFACE:intf_pub> ge_running_env_inc graphengine -lrt -ldl -lpthread -lgcov + ) + +set_target_properties(ge_with_env PROPERTIES CXX_STANDARD 17) diff --git a/tests/framework/ge_running_env/src/engine/fake_engine.cc b/tests/framework/ge_running_env/src/engine/fake_engine.cc new file mode 100644 index 00000000..4b8fedbc --- /dev/null +++ b/tests/framework/ge_running_env/src/engine/fake_engine.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "ge_running_env/fake_engine.h" +#include "ge_running_env/fake_ops_kernel_builder.h" +#include "ge_running_env/fake_ops_kernel_info_store.h" +#include "opskernel_manager/ops_kernel_manager.h" + +FAKE_NS_BEGIN + +FakeEngine::FakeEngine(const std::string &engine_name) : engine_name_(engine_name) {} + +FakeEngine &FakeEngine::KernelInfoStore(const std::string &info_store) { + info_store_names_.insert(info_store); + return *this; +} + +FakeEngine &FakeEngine::KernelInfoStore(FakeOpsKernelInfoStorePtr ptr) { + info_store_names_.insert(ptr->GetLibName()); + custom_info_stores_.insert(std::make_pair(ptr->GetLibName(), ptr)); + return *this; +} + +FakeEngine &FakeEngine::KernelBuilder(FakeOpsKernelBuilderPtr builder) { + info_store_names_.insert(builder->GetLibName()); + custom_builders_.insert(std::make_pair(builder->GetLibName(), builder)); + return *this; +} + +namespace { +template <typename BasePtr, typename SubClass> +void InstallDefault(std::map<std::string, BasePtr> &maps, const std::string &info_store_name, + const std::string &engine_name) { + auto parent_obj = std::make_shared<SubClass>(info_store_name); + if (parent_obj == nullptr) { + return; + } + parent_obj->EngineName(engine_name); + maps.insert(std::make_pair(parent_obj->GetLibName(), parent_obj)); +} +} // namespace + +template <typename BasePtr, typename SubClass> +void FakeEngine::InstallFor(std::map<std::string, BasePtr> &maps, + const std::map<std::string, std::shared_ptr<SubClass>> &child_maps) const { + if (info_store_names_.empty()) { + InstallDefault<BasePtr, SubClass>(maps, engine_name_, engine_name_); + } else { + for (auto &info_store_name : info_store_names_) { + auto iter = child_maps.find(info_store_name); + if (iter == child_maps.end()) { + InstallDefault<BasePtr, SubClass>(maps, info_store_name, engine_name_); + } else { + maps.insert(std::make_pair(iter->second->GetLibName(), iter->second)); + } + } + } +} + +void FakeEngine::InstallTo(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_info_stores) const { + InstallFor(ops_kernel_info_stores, custom_info_stores_); +} + +void FakeEngine::InstallTo(std::map<std::string, OpsKernelBuilderPtr> &ops_kernel_builders) const { + InstallFor(ops_kernel_builders, custom_builders_); +} + +FAKE_NS_END diff --git a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_builder.cc b/tests/framework/ge_running_env/src/engine/fake_ops_kernel_builder.cc similarity index 73% rename from tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_builder.cc rename to tests/framework/ge_running_env/src/engine/fake_ops_kernel_builder.cc index 2de8691f..77472249 100644 --- a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_builder.cc +++ b/tests/framework/ge_running_env/src/engine/fake_ops_kernel_builder.cc @@ -14,40 +14,25 @@ * limitations under the License.
*/ -#include "stub_ops_kernel_builder.h" -#include +#include "ge_running_env/fake_ops_kernel_builder.h" +#include "graph/utils/node_utils.h" #include "common/ge_inner_error_codes.h" #include "ge/ge_api_types.h" #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include #include "framework/common/debug/ge_log.h" -#include "host_cpu_engine/common/constant/constant.h" -#include "register/ops_kernel_builder_registry.h" -#include "inc/st_types.h" +FAKE_NS_BEGIN -namespace ge { -namespace st { -REGISTER_OPS_KERNEL_BUILDER(kAicoreLibName, StubOpsKernelBuilder); -REGISTER_OPS_KERNEL_BUILDER(kVectorLibName, StubOpsKernelBuilder); -REGISTER_OPS_KERNEL_BUILDER(kAicpuLibName, StubOpsKernelBuilder); -REGISTER_OPS_KERNEL_BUILDER(kAicpuAscendLibName, StubOpsKernelBuilder); -REGISTER_OPS_KERNEL_BUILDER(kHcclLibName, StubOpsKernelBuilder); -REGISTER_OPS_KERNEL_BUILDER(kRTSLibName, StubOpsKernelBuilder); +FakeOpsKernelBuilder::FakeOpsKernelBuilder(const std::string &info_store_name) : InfoStoreHolder(info_store_name) {} +FakeOpsKernelBuilder::FakeOpsKernelBuilder() : InfoStoreHolder() {} -Status StubOpsKernelBuilder::Finalize() { - return SUCCESS; -} -Status StubOpsKernelBuilder::Initialize(const map<std::string, std::string> &options) { - return SUCCESS; -} +Status FakeOpsKernelBuilder::Finalize() { return SUCCESS; } +Status FakeOpsKernelBuilder::Initialize(const map<std::string, std::string> &options) { return SUCCESS; } -Status StubOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { +Status FakeOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { OpDescPtr op_desc = ge_node.GetOpDesc(); if (op_desc == nullptr) { - GELOGE(FAILED, "[Get][OpDesc]CalcOpRunningParam failed, as op desc is null"); - REPORT_INNER_ERROR("E19999", "GetOpDesc failed."); return FAILED; } @@ -86,9 +71,9 @@ Status StubOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); REPORT_CALL_ERROR( - "E19999", "CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", - name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::DataTypeToSerialString(data_type).c_str()); + "E19999", "CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", + name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); return FAILED; } GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.", name.c_str(), type.c_str(), i, @@ -111,9 +96,9 @@ Status StubOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { return SUCCESS; } -Status StubOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector<domi::TaskDef> &tasks) { +Status FakeOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector<domi::TaskDef> &tasks) { // no need to generate device task return SUCCESS; } -} // namespace st -} // namespace ge \ No newline at end of file + +FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/engine/fake_ops_kernel_info_store.cc b/tests/framework/ge_running_env/src/engine/fake_ops_kernel_info_store.cc new file mode 100644 index 00000000..348e1334 --- /dev/null +++ b/tests/framework/ge_running_env/src/engine/fake_ops_kernel_info_store.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/ge/ge_api_error_codes.h" +#include "ge_running_env/fake_ops_kernel_info_store.h" + +FAKE_NS_BEGIN + +FakeOpsKernelInfoStore::FakeOpsKernelInfoStore(const std::string &info_store_name) : InfoStoreHolder(info_store_name) {} + +FakeOpsKernelInfoStore::FakeOpsKernelInfoStore() : InfoStoreHolder() {} + +Status FakeOpsKernelInfoStore::Finalize() { + op_info_map_.clear(); + return SUCCESS; +} + +Status FakeOpsKernelInfoStore::Initialize(const std::map<std::string, std::string> &options) { return SUCCESS; } + +void FakeOpsKernelInfoStore::GetAllOpsKernelInfo(map<std::string, OpInfo> &infos) const { infos = op_info_map_; } + +bool FakeOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const { + if (op_desc == nullptr) { + return false; + } + return op_info_map_.count(op_desc->GetType()) > 0; +} + +FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/engine/info_store_holder.cc b/tests/framework/ge_running_env/src/engine/info_store_holder.cc new file mode 100644 index 00000000..231af93a --- /dev/null +++ b/tests/framework/ge_running_env/src/engine/info_store_holder.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "ge_running_env/info_store_holder.h" +FAKE_NS_BEGIN + +namespace { +std::string GenStoreName() { + static int store_id = 0; + return "store_" + std::to_string(store_id++); +} +} // namespace + +InfoStoreHolder::InfoStoreHolder(const std::string& kernel_lib_name) : kernel_lib_name_(kernel_lib_name) {} + +InfoStoreHolder::InfoStoreHolder() : kernel_lib_name_(GenStoreName()) {} + +void InfoStoreHolder::RegistOp(std::string op_type) { + OpInfo default_op_info = {.engine = engine_name_, + .opKernelLib = kernel_lib_name_, + .computeCost = 0, + .flagPartial = false, + .flagAsync = false, + .isAtomic = false}; + + auto iter = op_info_map_.find(op_type); + if (iter == op_info_map_.end()) { + op_info_map_.emplace(op_type, default_op_info); + } +} + +void InfoStoreHolder::EngineName(std::string engine_name) { engine_name_ = engine_name; } + +std::string InfoStoreHolder::GetLibName() { return kernel_lib_name_; } + +FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/env/ge_default_running_env.cc b/tests/framework/ge_running_env/src/env/ge_default_running_env.cc new file mode 100644 index 00000000..ab705f55 --- /dev/null +++ b/tests/framework/ge_running_env/src/env/ge_default_running_env.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
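`RegistOp` above is what makes an op type schedulable on a fake store: `FakeOpsKernelInfoStore::CheckSupported` answers purely from `op_info_map_`. A sketch of the round trip, with illustrative names; note the call goes through the base interface because the override is private in the fake:

```cpp
#include <memory>
#include <string>
#include "ge_running_env/fake_ops_kernel_info_store.h"
#include "graph/op_desc.h"

// Hypothetical check: an op type is supported exactly when it was registered.
void CheckSupportRoundTrip() {
  ge::FakeOpsKernelInfoStore store("RTSLib");
  store.EngineName("DNN_VM_RTS");
  store.RegistOp("Enter");

  auto enter = std::make_shared<ge::OpDesc>("enter_0", "Enter");
  std::string reason;
  ge::OpsKernelInfoStore &base = store;             // access via the public base API
  bool supported = base.CheckSupported(enter, reason);  // true for "Enter"
}
```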
diff --git a/tests/framework/ge_running_env/src/env/ge_default_running_env.cc b/tests/framework/ge_running_env/src/env/ge_default_running_env.cc
new file mode 100644
index 00000000..ab705f55
--- /dev/null
+++ b/tests/framework/ge_running_env/src/env/ge_default_running_env.cc
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ge_default_running_env.h"
+#include "ge_running_env/ge_running_env_faker.h"
+#include "ge_running_env/fake_op.h"
+
+FAKE_NS_BEGIN
+namespace {
+std::vector<FakeEngine> default_engines = {FakeEngine("AIcoreEngine").KernelInfoStore("AiCoreLib"),
+                                           FakeEngine("VectorEngine").KernelInfoStore("VectorLib"),
+                                           FakeEngine("DNN_VM_AICPU").KernelInfoStore("AicpuLib"),
+                                           FakeEngine("DNN_VM_AICPU_ASCEND").KernelInfoStore("AicpuAscendLib"),
+                                           FakeEngine("DNN_HCCL").KernelInfoStore("HcclLib"),
+                                           FakeEngine("DNN_VM_RTS").KernelInfoStore("RTSLib")};
+
+std::vector<FakeOp> fake_ops = {
+    FakeOp(ENTER).InfoStoreAndBuilder("RTSLib"),        FakeOp(MERGE).InfoStoreAndBuilder("RTSLib"),
+    FakeOp(SWITCH).InfoStoreAndBuilder("RTSLib"),       FakeOp(LOOPCOND).InfoStoreAndBuilder("RTSLib"),
+    FakeOp(STREAMMERGE).InfoStoreAndBuilder("RTSLib"),  FakeOp(STREAMSWITCH).InfoStoreAndBuilder("RTSLib"),
+    FakeOp(STREAMACTIVE).InfoStoreAndBuilder("RTSLib"), FakeOp(EXIT).InfoStoreAndBuilder("RTSLib"),
+
+    FakeOp(LESS).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(NEXTITERATION).InfoStoreAndBuilder("AiCoreLib"),
+    FakeOp(CAST).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(TRANSDATA).InfoStoreAndBuilder("AiCoreLib"),
+    FakeOp(NOOP).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(VARIABLE).InfoStoreAndBuilder("AiCoreLib"),
+    FakeOp(CONSTANT).InfoStoreAndBuilder("AiCoreLib"),  FakeOp(ASSIGN).InfoStoreAndBuilder("AiCoreLib"),
+    FakeOp(ADD).InfoStoreAndBuilder("AiCoreLib"),       FakeOp(MUL).InfoStoreAndBuilder("AiCoreLib"),
+    FakeOp(DATA).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(NETOUTPUT).InfoStoreAndBuilder("AiCoreLib"),
+
+};
+} // namespace
+
+void GeDefaultRunningEnv::InstallTo(GeRunningEnvFaker& ge_env) {
+  for (auto& fake_engine : default_engines) {
+    ge_env.Install(fake_engine);
+  }
+
+  for (auto& fake_op : fake_ops) {
+    ge_env.Install(fake_op);
+  }
+}
+
+FAKE_NS_END
\ No newline at end of file
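The default environment above (six fake engines plus the common control-flow and compute ops) is what GeRunningEnvFaker::InstallDefault() installs. Equivalent usage, sketched (not part of the patch):

```
#include "ge_running_env/ge_running_env_faker.h"

void DefaultEnvSketch() {
  ge::GeRunningEnvFaker ge_env;
  ge_env.InstallDefault();  // Reset() first, then GeDefaultRunningEnv::InstallTo(*this)
}
```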
diff --git a/tests/framework/utils/builder/tensor_builder_utils.cc b/tests/framework/ge_running_env/src/env/ge_default_running_env.h
similarity index 69%
rename from tests/framework/utils/builder/tensor_builder_utils.cc
rename to tests/framework/ge_running_env/src/env/ge_default_running_env.h
index f99b9107..b93c528a 100644
--- a/tests/framework/utils/builder/tensor_builder_utils.cc
+++ b/tests/framework/ge_running_env/src/env/ge_default_running_env.h
@@ -1,17 +1,32 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tensor_builder_utils.h"
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INC_5D044B8760CB41ABA108AE2E37E8EBDE
+#define INC_5D044B8760CB41ABA108AE2E37E8EBDE
+
+#include "ge_running_env/fake_ns.h"
+
+FAKE_NS_BEGIN
+
+struct GeRunningEnvFaker;
+
+struct GeDefaultRunningEnv {
+  static void InstallTo(GeRunningEnvFaker&);
+};
+
+FAKE_NS_END
+
+#endif
\ No newline at end of file
diff --git a/tests/framework/ge_running_env/src/env/ge_running_env_faker.cc b/tests/framework/ge_running_env/src/env/ge_running_env_faker.cc
new file mode 100644
index 00000000..2977f6b2
--- /dev/null
+++ b/tests/framework/ge_running_env/src/env/ge_running_env_faker.cc
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <map>
+#include <set>
+#include "external/ge/ge_api.h"
+#include "opskernel_manager/ops_kernel_builder_manager.h"
+#include "init/gelib.h"
+#include "utility"
+#include "ge_running_env/ge_running_env_faker.h"
+#include "ge_default_running_env.h"
+#include "ge_running_env/env_installer.h"
+#include "op/fake_op_repo.h"
+
+FAKE_NS_BEGIN
+
+namespace {
+OpsKernelManager& getKernelManger() {
+  std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance();
+  return instancePtr->OpsKernelManagerObj();
+}
+
+struct InitEnv {
+  static InitEnv& GetInstance() {
+    static InitEnv instance;
+    return instance;
+  }
+
+  void reset(std::map<string, OpsKernelInfoStorePtr>& ops_kernel_info_stores,
+             std::map<string, OpsKernelBuilderPtr>& builders) {
+    std::set<string> remove_info_names;
+    for (auto iter : ops_kernel_info_stores) {
+      if (kernel_info_names.find(iter.first) == kernel_info_names.end()) {
+        remove_info_names.insert(iter.first);
+      }
+    }
+    for (auto info_name : remove_info_names) {
+      ops_kernel_info_stores.erase(info_name);
+      builders.erase(info_name);
+    }
+  }
+
+ private:
+  InitEnv() {
+    for (auto iter : getKernelManger().GetAllOpsKernelInfoStores()) {
+      kernel_info_names.insert(iter.first);
+    }
+  }
+
+ private:
+  std::set<string> kernel_info_names;
+};
+} // namespace
+
+GeRunningEnvFaker::GeRunningEnvFaker()
+    : op_kernel_info_(const_cast<map<string, vector<OpInfo>>&>(getKernelManger().GetAllOpsKernelInfo())),
+      ops_kernel_info_stores_(
+          const_cast<map<string, OpsKernelInfoStorePtr>&>(getKernelManger().GetAllOpsKernelInfoStores())),
+      ops_kernel_optimizers_(
+          const_cast<map<string, GraphOptimizerPtr>&>(getKernelManger().GetAllGraphOptimizerObjs())),
+      ops_kernel_builders_(const_cast<map<string, OpsKernelBuilderPtr>&>(
+          OpsKernelBuilderManager::Instance().GetAllOpsKernelBuilders())) {
+  Reset();
+}
+
+GeRunningEnvFaker& GeRunningEnvFaker::Reset() {
+  InitEnv& init_env = InitEnv::GetInstance();
+  FakeOpRepo::Reset();
+  init_env.reset(ops_kernel_info_stores_, ops_kernel_builders_);
+  flush();
+  return *this;
+}
+
+void GeRunningEnvFaker::BackupEnv() { InitEnv::GetInstance(); }
+
+GeRunningEnvFaker& GeRunningEnvFaker::Install(const EnvInstaller& installer) {
+  installer.Install();
+  installer.InstallTo(ops_kernel_info_stores_);
+  installer.InstallTo(ops_kernel_optimizers_);
+  installer.InstallTo(ops_kernel_builders_);
+  flush();
+  return *this;
+}
+
+void GeRunningEnvFaker::flush() {
+  op_kernel_info_.clear();
+  getKernelManger().GetOpsKernelInfo("");
+}
+
+GeRunningEnvFaker& GeRunningEnvFaker::InstallDefault() {
+  Reset();
+  GeDefaultRunningEnv::InstallTo(*this);
+  return *this;
+}
+
+FAKE_NS_END
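GeRunningEnvFaker's constructor snapshots the process-wide kernel registries (via the InitEnv singleton), and Reset() later erases anything installed after that snapshot. Typical test usage, sketched (not part of the patch; the FakeEngine include path is an assumption):

```
#include "ge_running_env/fake_engine.h"  // assumed header for FakeEngine
#include "ge_running_env/ge_running_env_faker.h"

void FakerSketch() {
  ge::GeRunningEnvFaker ge_env;               // constructor already calls Reset()
  ge_env.Install(ge::FakeEngine("DNN_HCCL"))  // Install() is chainable and re-flushes
        .Install(ge::FakeEngine("DNN_VM_RTS"));
  ge_env.Reset();                             // back to the BackupEnv() snapshot
}
```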
diff --git a/tests/framework/ge_running_env/src/op/fake_op.cc b/tests/framework/ge_running_env/src/op/fake_op.cc
new file mode 100644
index 00000000..52bbee8d
--- /dev/null
+++ b/tests/framework/ge_running_env/src/op/fake_op.cc
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ge_running_env/fake_op.h"
+#include "fake_op_repo.h"
+#include "ge_running_env/info_store_holder.h"
+#include "graph/operator_factory.h"
+
+FAKE_NS_BEGIN
+
+FakeOp::FakeOp(const std::string& op_type) : op_type_(op_type) {}
+
+FakeOp& FakeOp::Inputs(const std::vector<std::string>& inputs) {
+  inputs_ = inputs;
+  return *this;
+}
+
+FakeOp& FakeOp::Outputs(const std::vector<std::string>& outputs) {
+  outputs_ = outputs;
+  return *this;
+}
+
+FakeOp& FakeOp::InferShape(InferShapeFunc infer_fun) {
+  info_fun_ = infer_fun;
+  return *this;
+}
+
+FakeOp& FakeOp::InfoStoreAndBuilder(const std::string& name) {
+  info_store_names_.insert(name);
+  return *this;
+}
+
+namespace {
+
+void RegistOpToInfoStore(OpsKernelInfoStorePtr& info_store, const std::string& op_type) {
+  if (info_store == nullptr) {
+    return;
+  }
+  auto holder = dynamic_cast<InfoStoreHolder*>(info_store.get());
+  holder->RegistOp(op_type);
+}
+
+struct FakeOperator : Operator {
+  FakeOperator(const std::string& op_type) : Operator(op_type) {}
+
+  FakeOperator& RegistInputs(const std::vector<std::string>& inputs) {
+    for (auto& input : inputs) {
+      Operator::InputRegister(input);
+    }
+    return *this;
+  }
+
+  FakeOperator& RegistOutputs(const std::vector<std::string>& outputs) {
+    for (auto& output : outputs) {
+      Operator::OutputRegister(output);
+    }
+    return *this;
+  }
+};
+} // namespace
+
+void FakeOp::InstallTo(std::map<string, OpsKernelInfoStorePtr>& info_stores) const {
+  std::for_each(info_store_names_.begin(), info_store_names_.end(), [=, &info_stores](auto& info_store_name) {
+    auto iter = info_stores.find(info_store_name);
+    if (iter != info_stores.end()) {
+      RegistOpToInfoStore(iter->second, op_type_);
+    }
+  });
+}
+
+void FakeOp::Install() const {
+  FakeOpRepo::Regist(
+      op_type_,
+      [op_type = this->op_type_, inputs = this->inputs_, outputs = this->outputs_](const std::string&) -> Operator {
+        return FakeOperator(op_type).RegistInputs(inputs).RegistOutputs(outputs);
+      });
+  if (info_fun_) {
+    FakeOpRepo::Regist(op_type_, info_fun_);
+  }
+}
+
+FAKE_NS_END
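FakeOp is a fluent builder: every setter returns *this, so a complete fake op can be declared in one expression before GeRunningEnvFaker::Install() wires it into the factory and the chosen info stores. Sketch (not part of the patch):

```
#include "ge_running_env/fake_op.h"
#include "framework/common/types.h"  // for the ADD op-type constant

void FakeOpSketch() {
  ge::FakeOp add_op(ge::ADD);
  add_op.Inputs({"x1", "x2"})
      .Outputs({"y"})
      .InferShape([](ge::Operator&) -> ge::graphStatus { return ge::GRAPH_SUCCESS; })
      .InfoStoreAndBuilder("AiCoreLib");
}
```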
diff --git a/tests/framework/ge_running_env/src/op/fake_op_repo.cc b/tests/framework/ge_running_env/src/op/fake_op_repo.cc
new file mode 100644
index 00000000..7d571b8b
--- /dev/null
+++ b/tests/framework/ge_running_env/src/op/fake_op_repo.cc
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "graph/operator_factory_impl.h"
+#include "ge_running_env/fake_op.h"
+#include "fake_op_repo.h"
+
+FAKE_NS_BEGIN
+
+void FakeOpRepo::Reset() {
+  if (OperatorFactoryImpl::operator_creators_) {
+    OperatorFactoryImpl::operator_creators_->clear();
+  }
+  if (OperatorFactoryImpl::operator_infershape_funcs_) {
+    OperatorFactoryImpl::operator_infershape_funcs_->clear();
+  }
+}
+
+void FakeOpRepo::Regist(const std::string &operator_type, const OpCreator creator) {
+  OperatorFactoryImpl::RegisterOperatorCreator(operator_type, creator);
+}
+void FakeOpRepo::Regist(const std::string &operator_type, const InferShapeFunc infer_fun) {
+  OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_fun);
+}
+
+FAKE_NS_END
\ No newline at end of file
diff --git a/tests/framework/ge_running_env/src/op/fake_op_repo.h b/tests/framework/ge_running_env/src/op/fake_op_repo.h
new file mode 100644
index 00000000..345515e4
--- /dev/null
+++ b/tests/framework/ge_running_env/src/op/fake_op_repo.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DBF6CE7CD4AC4A83BA4ED4B372FC66E4
+#define DBF6CE7CD4AC4A83BA4ED4B372FC66E4
+
+#include "ge_running_env/fake_ns.h"
+#include "graph/operator_factory.h"
+
+FAKE_NS_BEGIN
+
+struct FakeOpRepo {
+  static void Reset();
+  static void Regist(const std::string &operator_type, const OpCreator);
+  static void Regist(const std::string &operator_type, const InferShapeFunc);
+};
+
+FAKE_NS_END
+#endif
\ No newline at end of file
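FakeOpRepo forwards straight into OperatorFactoryImpl, so a fake creator becomes visible through the public OperatorFactory. Sketch (not part of the patch; the exact CreateOperator overload may differ by metadef version):

```
#include "graph/operator_factory.h"
#include "fake_op_repo.h"

void RepoSketch() {
  ge::FakeOpRepo::Reset();  // clears both the creator and infershape registries
  ge::FakeOpRepo::Regist("Add", [](const std::string &name) { return ge::Operator(name); });
  auto op = ge::OperatorFactory::CreateOperator("add1", "Add");  // served by the fake creator
}
```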
diff --git a/tests/framework/ge_running_env/tests/CMakeLists.txt b/tests/framework/ge_running_env/tests/CMakeLists.txt
new file mode 100644
index 00000000..67a9bd70
--- /dev/null
+++ b/tests/framework/ge_running_env/tests/CMakeLists.txt
@@ -0,0 +1,33 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP")
+
+add_executable(ge_running_env_test ${SOURCES})
+
+target_include_directories(ge_running_env_test
+    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
+)
+
+target_compile_options(ge_running_env_test PRIVATE
+    -g
+)
+set_target_properties(ge_running_env_test PROPERTIES CXX_STANDARD 17)
+
+target_link_libraries(ge_running_env_test PUBLIC gtest ge_with_env)
+
+include(CTest)
+enable_testing()
+add_test(NAME test COMMAND ge_running_env_test)
\ No newline at end of file
diff --git a/tests/framework/ge_running_env/tests/test_ge_running_env_faker.cc b/tests/framework/ge_running_env/tests/test_ge_running_env_faker.cc
new file mode 100644
index 00000000..4429f4a7
--- /dev/null
+++ b/tests/framework/ge_running_env/tests/test_ge_running_env_faker.cc
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include "graph/operator_factory_impl.h"
+#include "init/gelib.h"
+#include "external/ge/ge_api.h"
+#include "opskernel_manager/ops_kernel_builder_manager.h"
+#include "ge_running_env/fake_ops_kernel_builder.h"
+#include "ge_running_env/fake_ns.h"
+#include "ge_running_env/ge_running_env_faker.h"
+#include "ge_running_env/fake_op.h"
+FAKE_NS_BEGIN
+
+#define ASSERT_OPS_LIST_SIZE(list_size)       \
+  std::vector<std::string> ops_list;          \
+  OperatorFactory::GetOpsTypeList(ops_list);  \
+  ASSERT_EQ(ops_list.size(), list_size);
+
+class GeRunningEvnFakerTest : public testing::Test {
+ protected:
+  void SetUp() {}
+  OpsKernelManager &kernel_manager = ge::GELib::GetInstance()->OpsKernelManagerObj();
+  OpsKernelBuilderManager &builder_manager = OpsKernelBuilderManager::Instance();
+};
+
+TEST_F(GeRunningEvnFakerTest, test_reset_running_env_is_success) {
+  GeRunningEnvFaker ge_env;
+  ge_env.Reset();
+  ASSERT_OPS_LIST_SIZE(0);
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 1);
+  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 1);
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
+  ASSERT_EQ(kernel_manager.GetOpsKernelInfo(SWITCH).size(), 1);
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_fake_op_success) {
+  GeRunningEnvFaker ge_env;
+  ge_env.Install(FakeOp(DATA)).Install(FakeOp(SWITCH));
+  ASSERT_OPS_LIST_SIZE(2);
+  ASSERT_TRUE(OperatorFactory::IsExistOp(DATA));
+  ASSERT_TRUE(OperatorFactory::IsExistOp(SWITCH));
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_fake_op_with_inputs_and_outputs_success) {
+  GeRunningEnvFaker ge_env;
+  ge_env.Install(FakeOp(ADD).Inputs({"x1", "x2"}).Outputs({"y"}));
+
+  auto add1 = OperatorFactory::CreateOperator("add1", ADD);
+
+  ASSERT_EQ(add1.GetInputsSize(), 2);
+  ASSERT_EQ(add1.GetOutputsSize(), 1);
+  ASSERT_OPS_LIST_SIZE(1);
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_fake_op_with_infer_shape_success) {
+  GeRunningEnvFaker ge_env;
+  auto infer_fun = [](Operator &op) -> graphStatus {
+    TensorDesc input_desc = op.GetInputDescByName("data");
+    return GRAPH_SUCCESS;
+  };
+  ASSERT_TRUE(OperatorFactoryImpl::GetInferShapeFunc(DATA) == nullptr);
+
+  ge_env.Install(FakeOp(DATA).Inputs({"data"}).InferShape(infer_fun));
+
+  ASSERT_TRUE(OperatorFactoryImpl::GetInferShapeFunc(DATA) != nullptr);
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_engine_with_default_info_store) {
+  GeRunningEnvFaker ge_env;
+  ge_env.Install(FakeEngine("DNN_HCCL"));
+
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
+  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
+  ASSERT_EQ(kernel_manager.GetOpsKernelInfo(SWITCH).size(), 1);
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_engine_with_info_store_name) {
+  GeRunningEnvFaker ge_env;
+  ge_env.Install(FakeEngine("DNN_HCCL").KernelInfoStore("AiCoreLib2"))
+      .Install(FakeOp(SWITCH).InfoStoreAndBuilder("AiCoreLib2"));
+
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
+  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
+  ASSERT_EQ(kernel_manager.GetOpsKernelInfo(SWITCH).size(), 2);
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_custom_kernel_builder_success) {
+  struct FakeKernelBuilder : FakeOpsKernelBuilder {
+    Status CalcOpRunningParam(Node &node) override {
+      OpDescPtr op_desc = node.GetOpDesc();
+      if (op_desc == nullptr) {
+        return FAILED;
+      }
+      return SUCCESS;
+    }
+  };
+
+  GeRunningEnvFaker ge_env;
+  auto ai_core_kernel = FakeEngine("DNN_HCCL").KernelBuilder(std::make_shared<FakeKernelBuilder>());
+  ge_env.Reset().Install(ai_core_kernel);
+
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
+  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_custom_kernel_info_store_success) {
+  struct FakeKernelBuilder : FakeOpsKernelInfoStore {
+    FakeKernelBuilder(const std::string &kernel_lib_name) : FakeOpsKernelInfoStore(kernel_lib_name) {}
+
+    bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override { return FAILED; }
+  };
+
+  GeRunningEnvFaker ge_env;
+  auto ai_core_kernel = FakeEngine("DNN_HCCL").KernelInfoStore(std::make_shared<FakeKernelBuilder>("AiCoreLib2"));
+  ge_env.Reset().Install(ai_core_kernel);
+
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
+  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
+}
+
+TEST_F(GeRunningEvnFakerTest, test_install_default_fake_engine_success) {
+  GeRunningEnvFaker ge_env;
+  ge_env.InstallDefault();
+
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 7);
+  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 7);
+  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 66);
+}
+
+FAKE_NS_END
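A new case against this fixture only needs a faker instance and an installed environment; everything else is ordinary gtest. Skeleton only (not part of the patch):

```
TEST_F(GeRunningEvnFakerTest, test_my_new_case) {
  GeRunningEnvFaker ge_env;
  ge_env.InstallDefault();  // or chain .Install(...) calls for a narrower environment
  // ... drive GE against the faked engines here ...
}
```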
diff --git a/tests/framework/ge_running_env/tests/test_main.cc b/tests/framework/ge_running_env/tests/test_main.cc
new file mode 100644
index 00000000..ede79c75
--- /dev/null
+++ b/tests/framework/ge_running_env/tests/test_main.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "common/debug/log.h"
+#include "external/ge/ge_api.h"
+#include "ge_running_env/ge_running_env_faker.h"
+
+using namespace std;
+using namespace ge;
+
+int main(int argc, char **argv) {
+  map<string, string> options;
+  ge::GEInitialize(options);
+  GeRunningEnvFaker::BackupEnv();
+  testing::InitGoogleTest(&argc, argv);
+  int ret = RUN_ALL_TESTS();
+
+  return ret;
+}
diff --git a/tests/framework/stub_engine/CMakeLists.txt b/tests/framework/stub_engine/CMakeLists.txt
deleted file mode 100644
index c86313c7..00000000
--- a/tests/framework/stub_engine/CMakeLists.txt
+++ /dev/null
@@ -1,58 +0,0 @@
-list(APPEND INCLUDE_DIRECTORIES
-        "${CMAKE_CURRENT_SOURCE_DIR}"
-        "${GE_CODE_DIR}"
-        "${GE_CODE_DIR}/inc"
-        "${GE_CODE_DIR}/metadef/inc"
-        "${GE_CODE_DIR}/ge"
-        "${GE_CODE_DIR}/ge/inc"
-        "${GE_CODE_DIR}/ge/ir_build"
-        "${GE_CODE_DIR}/metadef"
-        "${GE_CODE_DIR}/metadef/graph"
-        "${GE_CODE_DIR}/inc/external"
-        "${GE_CODE_DIR}/inc/framework/common"
-        "${GE_CODE_DIR}/metadef/inc/external"
-        "${GE_CODE_DIR}/metadef/inc/external/graph"
-        "${GE_CODE_DIR}/metadef/inc/graph"
-        "${GE_CODE_DIR}/inc/framework"
-        "${GE_CODE_DIR}/metadef/inc/common"
-        "${GE_CODE_DIR}/metadef/third_party"
-        "${GE_CODE_DIR}/metadef/third_party/transformer/inc"
-        "${GE_CODE_DIR}/parser"
-        "${GE_CODE_DIR}/parser/parser"
-        "${GE_CODE_DIR}/third_party/fwkacllib/inc"
-        "${GE_CODE_DIR}/third_party/fwkacllib/inc/cce"
-        "${GE_CODE_DIR}/third_party/fwkacllib/inc/ops"
-        "${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain"
-        "${GE_CODE_DIR}/tests/ut/ge"
-        "${GE_CODE_DIR}/tests/ut/common"
-        "${CMAKE_BINARY_DIR}"
-        "${CMAKE_BINARY_DIR}/proto/ge"
-        "${CMAKE_BINARY_DIR}/proto/ge/proto"
-        )
-
-file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")
-
-# ---- Target : stub Host engine ----
-add_library(fe SHARED ${SOURCES})
-
-target_include_directories(fe
-        PUBLIC
-        ${INCLUDE_DIRECTORIES}
-        ${CMAKE_CURRENT_SOURCE_DIR}
-        )
-
-target_compile_definitions(fe PRIVATE
-        google=ascend_private
-        FMK_SUPPORT_DUMP
-        )
-
-target_compile_options(fe PRIVATE
-        -g --coverage -fprofile-arcs -ftest-coverage
-        -Werror=format
-        )
-
-target_link_libraries(fe PUBLIC
-        $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} metadef_graph -lmmpa -L${GE_CODE_DIR}/third_party/prebuild/x86_64 -lrt -ldl -lpthread -lgcov
-        )
-
-set_target_properties(fe PROPERTIES CXX_STANDARD 11)
diff --git a/tests/framework/stub_engine/engine/stub_engine.cc b/tests/framework/stub_engine/engine/stub_engine.cc
deleted file mode 100644
index 622e8c4e..00000000
--- a/tests/framework/stub_engine/engine/stub_engine.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "stub_engine.h"
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-#include "framework/common/debug/ge_log.h"
-#include "common/ge/ge_util.h"
-#include "inc/st_types.h"
-
-namespace ge {
-namespace st {
-StubEngine &StubEngine::Instance() {
-  static StubEngine instance;
-  return instance;
-}
-
-Status StubEngine::Initialize(const std::map<std::string, std::string> &options) {
-  for (const auto engine_2_lib : kStubEngine2KernelLib) {
-    auto ops_kernel_store = MakeShared<StubOpsKernelInfoStore>(engine_2_lib.second);
-    if (ops_kernel_store == nullptr) {
-      return FAILED;
-    }
-    ops_kernel_store_map_.insert(make_pair(engine_2_lib.second, ops_kernel_store));
-  }
-  return SUCCESS;
-}
-
-void StubEngine::GetOpsKernelInfoStores(std::map<string, OpsKernelInfoStorePtr> &ops_kernel_map) {
-  for (const auto name_2_ops_kernel_store : ops_kernel_store_map_) {
-    ops_kernel_map[name_2_ops_kernel_store.first] = name_2_ops_kernel_store.second;
-  }
-}
-
-void StubEngine::GetGraphOptimizerObjs(std::map<string, GraphOptimizerPtr> &) {
-  // no optimizer for host cpu engine
-}
-
-Status StubEngine::Finalize() {
-  return SUCCESS;
-}
-} // namespace st
-} // namespace ge
-
-ge::Status Initialize(const std::map<string, string> &options) {
-  return ge::st::StubEngine::Instance().Initialize(options);
-}
-
-void GetOpsKernelInfoStores(std::map<string, OpsKernelInfoStorePtr> &ops_kernel_map) {
-  ge::st::StubEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map);
-}
-
-void GetGraphOptimizerObjs(std::map<string, GraphOptimizerPtr> &graph_optimizers) {
-  ge::st::StubEngine::Instance().GetGraphOptimizerObjs(graph_optimizers);
-}
-
-ge::Status Finalize() {
-  return ge::st::StubEngine::Instance().Finalize();
-}
diff --git a/tests/framework/stub_engine/engine/stub_engine.h b/tests/framework/stub_engine/engine/stub_engine.h
deleted file mode 100644
index d3909115..00000000
--- a/tests/framework/stub_engine/engine/stub_engine.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef GRAPH_ENGINE_LLT_STUB_ENGINE_H_
-#define GRAPH_ENGINE_LLT_STUB_ENGINE_H_
-
-#if defined(_MSC_VER)
-#ifdef FUNC_VISIBILITY
-#define GE_FUNC_VISIBILITY _declspec(dllexport)
-#else
-#define GE_FUNC_VISIBILITY
-#endif
-#else
-#ifdef FUNC_VISIBILITY
-#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
-#else
-#define GE_FUNC_VISIBILITY
-#endif
-#endif
-
-#include <map>
-#include <memory>
-#include <string>
-#include "inc/st_types.h"
-#include "common/opskernel/ops_kernel_info_store.h"
-#include "common/optimizer/graph_optimizer.h"
-#include "stub_engine/ops_kernel_store/stub_ops_kernel_store.h"
-
-using OpsKernelInfoStorePtr = std::shared_ptr<ge::OpsKernelInfoStore>;
-using StubOpsKernelInfoStorePtr = std::shared_ptr<ge::st::StubOpsKernelInfoStore>;
-using GraphOptimizerPtr = std::shared_ptr<ge::GraphOptimizer>;
-
-namespace ge {
-namespace st {
-/**
- * host cpu engine.
- * Used for the ops which executes on host.
- */
-class GE_FUNC_VISIBILITY StubEngine {
- public:
-  /**
-   * get StubEngine instance.
-   * @return StubEngine instance.
-   */
-  static StubEngine &Instance();
-
-  virtual ~StubEngine() = default;
-
-  /**
-   * When Ge start, GE will invoke this interface
-   * @return The status whether initialize successfully
-   */
-  Status Initialize(const std::map<std::string, std::string> &options);
-
-  /**
-   * After the initialize, GE will invoke this interface
-   * to get the Ops kernel Store.
-   * @param ops_kernel_map The host cpu's ops kernel info
-   */
-  void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map);
-
-  /**
-   * After the initialize, GE will invoke this interface
-   * to get the Graph Optimizer.
-   * @param graph_optimizers The host cpu's Graph Optimizer objs
-   */
-  void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers);
-
-  /**
-   * When the graph finished, GE will invoke this interface
-   * @return The status whether initialize successfully
-   */
-  Status Finalize();
-
-  StubEngine(const StubEngine &StubEngine) = delete;
-  StubEngine(const StubEngine &&StubEngine) = delete;
-  StubEngine &operator=(const StubEngine &StubEngine) = delete;
-  StubEngine &operator=(StubEngine &&StubEngine) = delete;
-
- private:
-  StubEngine() = default;
-  map<string, OpsKernelInfoStorePtr> ops_kernel_store_map_;
-};
-} // namespace st
-} // namespace ge
-
-extern "C" {
-
-/**
- * When Ge start, GE will invoke this interface
- * @return The status whether initialize successfully
- */
-GE_FUNC_VISIBILITY ge::Status Initialize(const map<string, string> &options);
-
-/**
- * After the initialize, GE will invoke this interface to get the Ops kernel Store
- * @param ops_kernel_map The host cpu's ops kernel info
- */
-GE_FUNC_VISIBILITY void GetOpsKernelInfoStores(std::map<string, OpsKernelInfoStorePtr> &ops_kernel_map);
-
-/**
- * After the initialize, GE will invoke this interface to get the Graph Optimizer
- * @param graph_optimizers The host cpu's Graph Optimizer objs
- */
-GE_FUNC_VISIBILITY void GetGraphOptimizerObjs(std::map<string, GraphOptimizerPtr> &graph_optimizers);
-
-/**
- * When the graph finished, GE will invoke this interface
- * @return The status whether initialize successfully
- */
-GE_FUNC_VISIBILITY ge::Status Finalize();
-}
-
-#endif  // GRAPH_ENGINE_LLT_STUB_ENGINE_H_
diff --git a/tests/framework/stub_engine/inc/st_types.h b/tests/framework/stub_engine/inc/st_types.h
deleted file mode 100644
index 92aa00d9..00000000
--- a/tests/framework/stub_engine/inc/st_types.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef GRAPHENGINE_ST_TYPES_H
-#define GRAPHENGINE_ST_TYPES_H
-#include <map>
-namespace ge {
-namespace st {
-const std::string kAicoreLibName = "AiCoreLib";
-const std::string kVectorLibName = "VectorLib";
-const std::string kAicpuLibName = "AicpuLib";
-const std::string kAicpuAscendLibName = "AicpuAscendLib";
-const std::string kHcclLibName = "HcclLib";
-const std::string kRTSLibName = "RTSLib";
-const std::map<std::string, std::string> kStubEngine2KernelLib = {
-    {"AIcoreEngine", "AiCoreLib"}, {"VectorEngine", "VectorLib"},
-    {"DNN_VM_AICPU", "AicpuLib"},  {"DNN_VM_AICPU_ASCEND", "AicpuAscendLib"},
-    {"DNN_HCCL", "HcclLib"},       {"DNN_VM_RTS", "RTSLib"}};
-} // namespace st
-} // namespace ge
-#endif  // GRAPHENGINE_ST_TYPES_H
diff --git a/tests/framework/stub_engine/ops_kernel_store/op/host_op.cc b/tests/framework/stub_engine/ops_kernel_store/op/host_op.cc
deleted file mode 100644
index 42678148..00000000
--- a/tests/framework/stub_engine/ops_kernel_store/op/host_op.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "inc/st_types.h"
-#include "stub_engine/ops_kernel_store/op/host_op.h"
-#include "framework/common/util.h"
-#include "stub_engine/ops_kernel_store/op/stub_op_factory.h"
-
-namespace ge {
-namespace st {
-Status HostOp::Run() {
-  // no need to generate device task
-  return SUCCESS;
-}
-REGISTER_OP_CREATOR(Enter, RTSLib, HostOp);
-REGISTER_OP_CREATOR(Merge, RTSLib, HostOp);
-REGISTER_OP_CREATOR(Switch, RTSLib, HostOp);
-REGISTER_OP_CREATOR(Less, AiCoreLib, HostOp);
-REGISTER_OP_CREATOR(NextIteration, AiCoreLib, HostOp);
-REGISTER_OP_CREATOR(LoopCond, RTSLib, HostOp);
-REGISTER_OP_CREATOR(Exit, RTSLib, HostOp);
-REGISTER_OP_CREATOR(StreamMerge, RTSLib, HostOp);
-REGISTER_OP_CREATOR(StreamSwitch, RTSLib, HostOp);
-REGISTER_OP_CREATOR(StreamActive, RTSLib, HostOp);
-REGISTER_OP_CREATOR(Cast, AiCoreLib, HostOp);
-REGISTER_OP_CREATOR(Transdata, AiCoreLib, HostOp);
-} // namespace st
-} // namespace ge
diff --git a/tests/framework/stub_engine/ops_kernel_store/op/op.h b/tests/framework/stub_engine/ops_kernel_store/op/op.h
deleted file mode 100644
index 3741567a..00000000
--- a/tests/framework/stub_engine/ops_kernel_store/op/op.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_
-#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_
-
-#include <climits>
-#include <string>
-#include <vector>
-#include "common/ge_inner_error_codes.h"
-#include "common/opskernel/ops_kernel_info_types.h"
-#include "graph/node.h"
-
-namespace ge {
-namespace st {
-/**
- * The base class for all op.
- */
-class GE_FUNC_VISIBILITY Op {
- public:
-  Op(const Node &node, RunContext &run_context) : run_context_(run_context), node_(node) {}
-  virtual ~Op() = default;
-  virtual Status Run() = 0;
-
- protected:
-  const RunContext &run_context_;
-  const Node &node_;
-};
-} // namespace st
-} // namespace ge
-
-#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_
diff --git a/tests/framework/stub_engine/ops_kernel_store/op/stub_op_factory.cc b/tests/framework/stub_engine/ops_kernel_store/op/stub_op_factory.cc
deleted file mode 100644
index 601bca4d..00000000
--- a/tests/framework/stub_engine/ops_kernel_store/op/stub_op_factory.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "stub_op_factory.h"
-#include "framework/common/debug/ge_log.h"
-#include "common/ge_inner_error_codes.h"
-#include "graph/op_desc.h"
-
-namespace ge {
-namespace st {
-OpFactory &OpFactory::Instance() {
-  static OpFactory instance;
-  return instance;
-}
-
-std::shared_ptr<Op> OpFactory::CreateOp(const Node &node, RunContext &run_context) {
-  auto iter = op_creator_map_.find(node.GetType());
-  if (iter != op_creator_map_.end()) {
-    return iter->second(node, run_context);
-  }
-  GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str());
-  return nullptr;
-}
-
-void OpFactory::RegisterCreator(const std::string &type, const std::string &kernel_lib, const OP_CREATOR_FUNC &func) {
-  if (func == nullptr) {
-    GELOGW("Func is NULL.");
-    return;
-  }
-
-  if (all_store_ops_.find(kernel_lib) != all_store_ops_.end()) {
-    all_store_ops_[kernel_lib].emplace_back(type);
-  } else {
-    all_store_ops_[kernel_lib] = {type};
-  }
-}
-} // namespace st
-} // namespace ge
diff --git a/tests/framework/stub_engine/ops_kernel_store/op/stub_op_factory.h b/tests/framework/stub_engine/ops_kernel_store/op/stub_op_factory.h
deleted file mode 100644
index f41fd07e..00000000
--- a/tests/framework/stub_engine/ops_kernel_store/op/stub_op_factory.h
+++ /dev/null
@@ -1,109 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_
-#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_
-
-#include <functional>
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-#include "common/ge/ge_util.h"
-#include "stub_engine/ops_kernel_store/op/op.h"
-#include "inc/st_types.h"
-
-namespace ge {
-namespace st {
-using OP_CREATOR_FUNC = std::function<std::shared_ptr<Op>(const Node &, RunContext &)>;
-
-/**
- * manage all the op, support create op.
- */
-class GE_FUNC_VISIBILITY OpFactory {
- public:
-  static OpFactory &Instance();
-
-  /**
-   * @brief create Op.
-   * @param [in] node share ptr of node
-   * @param [in] run_context run context
-   * @return not nullptr success
-   * @return nullptr fail
-   */
-  std::shared_ptr<Op> CreateOp(const Node &node, RunContext &run_context);
-
-  /**
-   * @brief Register Op create function.
-   * @param [in] type Op type
-   * @param [in] func Op create func
-   */
-  void RegisterCreator(const std::string &type, const std::string &lib_name, const OP_CREATOR_FUNC &func);
-
-  const std::vector<std::string> &GetAllOps() const {
-    return all_ops_;
-  }
-
-  const std::vector<std::string> &GetAllOps(std::string lib_name) const {
-    auto iter = all_store_ops_.find(lib_name);
-    if (iter == all_store_ops_.end()) {
-      return all_ops_;
-    }
-    return iter->second;
-  }
-
-  bool CheckSupported(const std::string &type) {
-    return op_creator_map_.find(type) != op_creator_map_.end();
-  }
-
-  OpFactory(const OpFactory &) = delete;
-  OpFactory &operator=(const OpFactory &) = delete;
-  OpFactory(OpFactory &&) = delete;
-  OpFactory &operator=(OpFactory &&) = delete;
-
- private:
-  OpFactory() = default;
-  ~OpFactory() = default;
-
-  // the op creator function map
-  std::map<std::string, OP_CREATOR_FUNC> op_creator_map_;
-  std::map<std::string, std::map<std::string, OP_CREATOR_FUNC>> lib_op_creator_map_;
-  std::vector<std::string> all_ops_;
-  std::map<std::string, std::vector<std::string>> all_store_ops_;
-};
-
-class GE_FUNC_VISIBILITY OpRegistrar {
- public:
-  OpRegistrar(const std::string &type, const std::string &kernel_lib, const OP_CREATOR_FUNC &func) {
-    OpFactory::Instance().RegisterCreator(type, kernel_lib, func);
-  }
-  ~OpRegistrar() = default;
-
-  OpRegistrar(const OpRegistrar &) = delete;
-  OpRegistrar &operator=(const OpRegistrar &) = delete;
-  OpRegistrar(OpRegistrar &&) = delete;
-  OpRegistrar &operator=(OpRegistrar &&) = delete;
-};
-
-#define REGISTER_OP_CREATOR(type, lib_name, clazz)                                          \
-  std::shared_ptr<Op> Creator_##type##Op(const Node &node, RunContext &run_context) {       \
-    return MakeShared<clazz>(node, run_context);                                            \
-  }                                                                                         \
-  OpRegistrar g_##type##Op_creator(#type, #lib_name, Creator_##type##Op)
-} // namespace st
-} // namespace ge
-
-#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_
diff --git a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_store.cc b/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_store.cc
deleted file mode 100644
index d43fee88..00000000
--- a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_store.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "stub_ops_kernel_store.h"
-#include <memory>
-#include "ge/ge_api_types.h"
-#include "framework/common/debug/ge_log.h"
-#include "graph/utils/node_utils.h"
-#include "graph/utils/tensor_utils.h"
-#include "graph/utils/type_utils.h"
-#include "op/stub_op_factory.h"
-
-namespace ge {
-namespace st {
-using domi::TaskDef;
-using std::map;
-using std::string;
-using std::vector;
-
-Status StubOpsKernelInfoStore::Initialize(const map<string, string> &options) {
-  GELOGI("StubOpsKernelInfoStore init start.");
-  string engine_name;
-  for (const auto &engine_2_lib : kStubEngine2KernelLib) {
-    if (engine_2_lib.second == store_name_) {
-      engine_name = engine_2_lib.first;
-    }
-  }
-  if (engine_name.empty()) {
-    return FAILED;
-  }
-
-  OpInfo default_op_info = {.engine = engine_name,
-                            .opKernelLib = store_name_,
-                            .computeCost = 0,
-                            .flagPartial = false,
-                            .flagAsync = false,
-                            .isAtomic = false};
-  // Init op_info_map_
-  auto all_ops_in_store = OpFactory::Instance().GetAllOps(store_name_);
-  for (auto &op : all_ops_in_store) {
-    op_info_map_[op] = default_op_info;
-  }
-
-  GELOGI("StubOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size());
-  return SUCCESS;
-}
-
-Status StubOpsKernelInfoStore::Finalize() {
-  op_info_map_.clear();
-  return SUCCESS;
-}
-
-void StubOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const {
-  infos = op_info_map_;
-}
-
-bool StubOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const {
-  if (op_desc == nullptr) {
-    return false;
-  }
-  return op_info_map_.count(op_desc->GetType()) > 0;
-}
-} // namespace st
-} // namespace ge
diff --git a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_store.h b/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_store.h
deleted file mode 100644
index ea7f712b..00000000
--- a/tests/framework/stub_engine/ops_kernel_store/stub_ops_kernel_store.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_
-#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_
-
-#if defined(_MSC_VER)
-#ifdef FUNC_VISIBILITY
-#define GE_FUNC_VISIBILITY _declspec(dllexport)
-#else
-#define GE_FUNC_VISIBILITY
-#endif
-#else
-#ifdef FUNC_VISIBILITY
-#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
-#else
-#define GE_FUNC_VISIBILITY
-#endif
-#endif
-
-#include <map>
-#include <string>
-#include <vector>
-
-#include "common/opskernel/ops_kernel_info_store.h"
-
-namespace ge {
-namespace st {
-/*const vector<string> kStubOpKernelLibNameVec = {
-    "AiCoreLib",
-    "AicpuLib",
-    "HcclLib",
-    "RTSLib"
-};*/
-class GE_FUNC_VISIBILITY StubOpsKernelInfoStore : public OpsKernelInfoStore {
- public:
-  StubOpsKernelInfoStore(std::string store_name) : store_name_(store_name) {}
-  ~StubOpsKernelInfoStore() override = default;
-  Status Initialize(const std::map<std::string, std::string> &options) override;
-  Status Finalize() override;
-  bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override;
-  void GetAllOpsKernelInfo(std::map<std::string, OpInfo> &infos) const override;
-  std::string GetOpsKernelStoreName() const {
-    return store_name_;
-  }
-
-  StubOpsKernelInfoStore(const StubOpsKernelInfoStore &ops_kernel_store) = delete;
-  StubOpsKernelInfoStore(const StubOpsKernelInfoStore &&ops_kernel_store) = delete;
-  StubOpsKernelInfoStore &operator=(const StubOpsKernelInfoStore &ops_kernel_store) = delete;
-  StubOpsKernelInfoStore &operator=(StubOpsKernelInfoStore &&ops_kernel_store) = delete;
-
- private:
-  // store op name and OpInfo key-value pair
-  std::map<std::string, OpInfo> op_info_map_;
-  std::string store_name_;
-};
-} // namespace st
-} // namespace ge
-
-#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_
diff --git a/tests/framework/utils/builder/graph_builder_utils.cc b/tests/framework/utils/builder/graph_builder_utils.cc
deleted file mode 100644
index c5555235..00000000
--- a/tests/framework/utils/builder/graph_builder_utils.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "graph_builder_utils.h"
-#include "inc/external/graph/operator.h"
-#include "inc/external/graph/operator_factory.h"
-#include "graph/utils/graph_utils.h"
-
-namespace ge {
-namespace st {
-NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
-                                     Format format, DataType data_type, std::vector<int64_t> shape) {
-  auto tensor_desc = std::make_shared<GeTensorDesc>();
-  tensor_desc->SetShape(GeShape(std::move(shape)));
-  tensor_desc->SetFormat(format);
-  tensor_desc->SetDataType(data_type);
-
-  auto op_desc = std::make_shared<OpDesc>(name, type);
-  for (int i = 0; i < in_cnt; ++i) {
-    op_desc->AddInputDesc(tensor_desc->Clone());
-  }
-  for (int i = 0; i < out_cnt; ++i) {
-    op_desc->AddOutputDesc(tensor_desc->Clone());
-  }
-
-  return graph_->AddNode(op_desc);
-}
-void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) {
-  GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
-}
-void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
-  GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
-}
-} // namespace st
-} // namespace ge
diff --git a/tests/framework/utils/builder/graph_builder_utils.h b/tests/framework/utils/builder/graph_builder_utils.h
deleted file mode 100644
index 4627f082..00000000
--- a/tests/framework/utils/builder/graph_builder_utils.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
-#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
-
-#include <string>
-#include <vector>
-
-#include "graph/compute_graph.h"
-#include "graph/utils/graph_utils.h"
-#include "graph/graph.h"
-#include "graph/node.h"
-
-namespace ge {
-namespace st {
-class ComputeGraphBuilder {
- public:
-  explicit ComputeGraphBuilder(const std::string &name) {
-    graph_ = std::make_shared<ComputeGraph>(name);
-  }
-  NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
-                  Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
-                  std::vector<int64_t> shape = {1, 1, 224, 224});
-  void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);
-  void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);
-  ComputeGraphPtr GetComputeGraph() {
-    graph_->TopologicalSorting();
-    return graph_;
-  }
-  Graph GetGraph() {
-    graph_->TopologicalSorting();
-    return GraphUtils::CreateGraphFromComputeGraph(graph_);
-  }
-
- private:
-  ComputeGraphPtr graph_;
-};
-} // namespace st
-} // namespace ge
-
-#endif  // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
diff --git a/tests/st/testcase/CMakeLists.txt b/tests/st/testcase/CMakeLists.txt
index 9d1d5a0e..56b3f41b 100644
--- a/tests/st/testcase/CMakeLists.txt
+++ b/tests/st/testcase/CMakeLists.txt
@@ -8,7 +8,7 @@ target_include_directories(graph_engine_test
 
 set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17)
 
-target_link_libraries(graph_engine_test PRIVATE gtest gtest_main framework)
+target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env)
 
 include(CTest)
 enable_testing()
diff --git a/tests/st/testcase/test_framework_dummy.cc b/tests/st/testcase/test_framework_dummy.cc
index 951e6b2b..8f13bb78 100644
--- a/tests/st/testcase/test_framework_dummy.cc
+++ b/tests/st/testcase/test_framework_dummy.cc
@@ -15,19 +15,12 @@
  */
 
 #include <gtest/gtest.h>
-#include <map>
 #include "external/ge/ge_api.h"
 #include "graph/debug/ge_attr_define.h"
 #include "framework/common/types.h"
-#include "builder/graph_builder_utils.h"
-#include "graph/operator_reg.h"
-#include "graph/operator.h"
-#define protected public
-#define private public
-#include "graph/utils/op_desc_utils.h"
+#include "ge_running_env/ge_running_env_faker.h"
 #include "ge_graph_dsl/graph_dsl.h"
-#undef protected
-#undef private
+#include "ge_graph_dsl/assert/graph_assert.h"
 
 using namespace std;
 using namespace ge;
@@ -53,79 +46,57 @@ namespace {
  *  **/
 Graph BuildV1ControlFlowGraph() {
-  // build graph
-  st::ComputeGraphBuilder graphBuilder("g1");
-  auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1);
-  auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1);
-  ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1");
-  auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1);
-  auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1);
-  auto less = graphBuilder.AddNode("less", LESS, 2, 1);
-  auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL);
-  auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2);
-  auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1);
-  auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1);
-  auto add = graphBuilder.AddNode("add", ADD, 2, 1);
-  auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1);
-
-  auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1);
-  auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1);
-  ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1");
-  auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1);
-  auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2);
-  auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1);
-  auto mul = graphBuilder.AddNode("mul", MUL, 2, 1);
-  auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1);
-  auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1);
-  auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2);
-  // i = i+1
-  graphBuilder.AddDataEdge(data_i, 0, enter_i, 0);
-  graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0);
-  graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1);
-  graphBuilder.AddDataEdge(merge_i, 0, less, 0);
-  graphBuilder.AddDataEdge(const_5, 0, less, 1);
-  graphBuilder.AddDataEdge(less, 0, loopcond, 0);
-  graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1);
-  graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0);
-  graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0);
-  graphBuilder.AddDataEdge(switch_i, 1, add, 0);
-  graphBuilder.AddDataEdge(const_1, 0, add, 1);
-  graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0);
-  graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1);
-  // a=a*2
-  graphBuilder.AddDataEdge(data_a, 0, enter_a, 0);
-  graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0);
-  graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1);
-  graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1);
-  graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0);
-  graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0);
-  graphBuilder.AddDataEdge(switch_a, 1, mul, 0);
-  graphBuilder.AddDataEdge(const_2, 0, mul, 1);
-  graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0);
-  graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0);
-  // set const weight
   int64_t dims_size = 1;
   vector<int64_t> data_vec = {5};
   for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; });
   vector<int32_t> data_value_vec(dims_size, 1);
   GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32);
-  GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *) data_value_vec.data(),
+  GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(),
                                                   data_value_vec.size() * sizeof(int32_t));
-  OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor);
-  OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor);
-  OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor);
-  return graphBuilder.GetGraph();
+  auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1");
+  auto const_op = OP_CFG(CONSTANT).Weight(data_tensor);
+
+  DEF_GRAPH(g1) {
+    CHAIN(NODE("data_i", DATA)
+              ->NODE("enter_i", enter)
+              ->EDGE(0, 0)
+              ->NODE("merge_i", MERGE)
+              ->NODE("less", LESS)
+              ->NODE("loopcond", LOOPCOND));
+    CHAIN(NODE("const_1", const_op)
+              ->EDGE(0, 1)
+              ->NODE("add", ADD)
+              ->NODE("iteration_i", NEXTITERATION)
+              ->EDGE(0, 1)
+              ->NODE("merge_i"));
+    CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less"));
+    CHAIN(NODE("loopcond")
+              ->EDGE(0, 1)
+              ->NODE("switch_i", SWITCH)
+              ->EDGE(0, 0)
+              ->NODE("exit_i", EXIT)
+              ->EDGE(0, 1)
+              ->NODE("netoutput", NETOUTPUT));
+    CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add"));
+    CHAIN(NODE("data_a", DATA)
+              ->NODE("enter_a", enter)
+              ->NODE("merge_a", MERGE)
+              ->NODE("switch_a", SWITCH)
+              ->NODE("exit_a", EXIT)
+              ->EDGE(0, 0)
+              ->NODE("netoutput"));
+    CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a"));
+    CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL));
+    CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a"));
+  };
+  return ToGeGraph(g1);
 }
 } // namespace
 
 class FrameworkTest : public testing::Test {
  protected:
-  void SetUp() {
-    // ge initialize
-    map<string, string> options;
-    auto ret = ge::GEInitialize(options);
-    EXPECT_EQ(ret, SUCCESS);
-  }
+  GeRunningEnvFaker ge_env;
+  void SetUp() { ge_env.InstallDefault(); }
   void TearDown() {}
 };
 
@@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) {
   DEF_GRAPH(g1) {
     CHAIN(NODE("data1", DATA)->NODE("add", ADD));
     CHAIN(NODE("data2", DATA)->NODE("add"));
-  });
+  };
 
-  auto graph = ToGeGraph(g1);
-
-  // new session & add graph
   map<string, string> options;
   Session session(options);
-  auto ret = session.AddGraph(1, graph, options);
-  EXPECT_EQ(ret, SUCCESS);
-  // build input tensor
+  session.AddGraph(1, ToGeGraph(g1), options);
   std::vector<InputTensorInfo> inputs;
-  // build_graph through session
-  ret = session.BuildGraph(1, inputs);
+  auto ret = session.BuildGraph(1, inputs);
+  EXPECT_EQ(ret, SUCCESS);
+  CHECK_GRAPH(PreRunAfterBuild) {
+    ASSERT_EQ(graph->GetName(), "g1_1");
+    ASSERT_EQ(graph->GetAllNodesSize(), 4);
+  };
 }
 
 /** data a = 2;
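The DSL used above comes from ge_graph_dsl: DEF_GRAPH collects CHAIN expressions, NODE declares or references a node by name, and EDGE pins the source/destination anchor indices of the next link. The smallest useful definition, sketched (not part of the patch):

```
// Inside a test body, assuming the ge_graph_dsl headers included above:
DEF_GRAPH(g2) {
  CHAIN(NODE("data1", DATA)->EDGE(0, 0)->NODE("netoutput", NETOUTPUT));
};
Graph graph = ToGeGraph(g2);  // materializes the DSL definition as a ge::Graph
```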
diff --git a/tests/st/testcase/test_ge_opt_info.cc b/tests/st/testcase/test_ge_opt_info.cc
new file mode 100644
index 00000000..2e8e5382
--- /dev/null
+++ b/tests/st/testcase/test_ge_opt_info.cc
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include "external/ge/ge_api.h"
+#include "easy_graph/builder/graph_dsl.h"
+#include "graph/compute_graph.h"
+#include "framework/common/types.h"
+#include "graph/ge_local_context.h"
+#include "ge_graph_dsl/graph_dsl.h"
+
+namespace ge {
+class STEST_opt_info : public testing::Test {
+ protected:
+  void SetUp() {}
+  void TearDown() {}
+};
+
+TEST_F(STEST_opt_info, get_opt_info_all) {
+  std::map<std::string, std::string> options = {{ge::SOC_VERSION, "Ascend310"}};
+  GetThreadLocalContext().SetGlobalOption(options);
+
+  /// data1  data2
+  ///     \  /
+  ///      add
+  // build graph
+  DEF_GRAPH(g1) {
+    CHAIN(NODE("data1", DATA)->NODE("add", ADD));
+    CHAIN(NODE("data2", DATA)->NODE("add"));
+  };
+
+  auto graph = ToGeGraph(g1);
+
+  // new session & add graph
+  Session session(options);
+  auto ret = session.AddGraph(1, graph, options);
+  EXPECT_EQ(ret, SUCCESS);
+  // build input tensor
+  std::vector<InputTensorInfo> inputs;
+  // build_graph through session
+  ret = session.BuildGraph(1, inputs);
+  EXPECT_EQ(ret, SUCCESS);
+
+  std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
+  auto itr = graph_options.find("opt_module.fe");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+  itr = graph_options.find("opt_module.pass");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+  itr = graph_options.find("opt_module.op_tune");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+  itr = graph_options.find("opt_module.rl_tune");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+  itr = graph_options.find("opt_module.aoe");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+}
+
+TEST_F(STEST_opt_info, get_opt_info_success) {
+  std::map<std::string, std::string> options = {{ge::SOC_VERSION, "Ascend910"}};
+  GetThreadLocalContext().SetGlobalOption(options);
+
+  /// data1  data2
+  ///     \  /
+  ///      add
+  // build graph
+  DEF_GRAPH(g1) {
+    CHAIN(NODE("data1", DATA)->NODE("add", ADD));
+    CHAIN(NODE("data2", DATA)->NODE("add"));
+  };
+
+  auto graph = ToGeGraph(g1);
+
+  // new session & add graph
+  Session session(options);
+  auto ret = session.AddGraph(1, graph, options);
+  EXPECT_EQ(ret, SUCCESS);
+  // build input tensor
+  std::vector<InputTensorInfo> inputs;
+  // build_graph through session
+  ret = session.BuildGraph(1, inputs);
+  EXPECT_EQ(ret, SUCCESS);
+
+  std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
+  auto itr = graph_options.find("opt_module.fe");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+  itr = graph_options.find("opt_module.pass");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+  itr = graph_options.find("opt_module.op_tune");
+  EXPECT_NE(itr, graph_options.end());
+  EXPECT_EQ(itr->second, "all");
+}
+} // namespace ge
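The repeated find/EXPECT pairs in the two tests above could be folded into a small helper; a sketch under the same types (the ExpectOpt name is hypothetical, not part of this patch):

#include <gtest/gtest.h>
#include <map>
#include <string>

// Asserts that one graph option key is present and carries the expected value.
static void ExpectOpt(const std::map<std::string, std::string> &opts,
                      const std::string &key, const std::string &expected) {
  auto it = opts.find(key);
  ASSERT_NE(it, opts.end());
  EXPECT_EQ(it->second, expected);
}

// e.g. ExpectOpt(graph_options, "opt_module.fe", "all");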
diff --git a/tests/st/testcase/test_main.cc b/tests/st/testcase/test_main.cc
new file mode 100644
index 00000000..a7a71954
--- /dev/null
+++ b/tests/st/testcase/test_main.cc
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include "external/ge/ge_api.h"
+#include "ge_graph_dsl/assert/check_utils.h"
+#include "ge_running_env/include/ge_running_env/ge_running_env_faker.h"
+
+using namespace std;
+using namespace ge;
+
+int main(int argc, char **argv) {
+  // init the logging
+  map<string, string> options;
+  auto init_status = ge::GEInitialize(options);
+  if (init_status != SUCCESS) {
+    std::cout << "ge init failed, ret code:" << init_status << endl;
+  }
+  GeRunningEnvFaker::BackupEnv();
+  CheckUtils::init();
+  testing::InitGoogleTest(&argc, argv);
+  int ret = RUN_ALL_TESTS();
+  return ret;
+}
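Suites run under this main are expected to install the faked running environment themselves; a minimal fixture sketch following the FrameworkTest change earlier in this patch (the suite name is illustrative; the header path matches the include in test_main.cc):

#include <gtest/gtest.h>
#include "ge_running_env/include/ge_running_env/ge_running_env_faker.h"

// InstallDefault() installs the default fake engines and ops-kernel info
// stores (on top of the environment backed up once in main) before each case.
class MyFeatureTest : public testing::Test {
 protected:
  GeRunningEnvFaker ge_env;
  void SetUp() override { ge_env.InstallDefault(); }
  void TearDown() override {}
};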
diff --git a/tests/ut/common/graph/CMakeLists.txt b/tests/ut/common/graph/CMakeLists.txt
index 73780967..8da69c14 100644
--- a/tests/ut/common/graph/CMakeLists.txt
+++ b/tests/ut/common/graph/CMakeLists.txt
@@ -61,52 +61,21 @@ set(UT_FILES
     "testcase/ge_graph/ge_model_unittest.cc"
 )
 
-set(SRC_FILES
-    "${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc"
-    "${GE_CODE_DIR}/metadef/graph/option/ge_context.cc"
-    "${GE_CODE_DIR}/metadef/graph/anchor.cc"
-    "${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc"
-    "${GE_CODE_DIR}/metadef/graph/attr_value.cc"
-    "${GE_CODE_DIR}/metadef/graph/buffer.cc"
-    "${GE_CODE_DIR}/metadef/graph/aligned_ptr.cc"
-    "${GE_CODE_DIR}/metadef/graph/compute_graph.cc"
-    "${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc"
-    "${GE_CODE_DIR}/metadef/graph/graph.cc"
-    "${GE_CODE_DIR}/metadef/graph/gnode.cc"
-    "${GE_CODE_DIR}/metadef/graph/ascend_string.cc"
-    "${GE_CODE_DIR}/metadef/graph/model.cc"
-    "${GE_CODE_DIR}/metadef/graph/model_serialize.cc"
-    "${GE_CODE_DIR}/metadef/graph/node.cc"
-    "${GE_CODE_DIR}/metadef/graph/op_desc.cc"
-    "${GE_CODE_DIR}/metadef/graph/operator.cc"
-    "${GE_CODE_DIR}/metadef/graph/operator_factory.cc"
-    "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc"
-    "${GE_CODE_DIR}/metadef/graph/tensor.cc"
-    "${GE_CODE_DIR}/metadef/graph/types.cc"
-    "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc"
-    "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc"
-    "${GE_CODE_DIR}/metadef/graph/format_refiner.cc"
-    "${GE_CODE_DIR}/metadef/graph/inference_context.cc"
-    "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc"
-    "${GE_CODE_DIR}/metadef/ops/op_imp.cpp"
-    "${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc"
-    "${GE_CODE_DIR}/metadef/graph/runtime_inference_context.cc"
-    "${GE_CODE_DIR}/metadef/graph/ref_relation.cc"
-    "${GE_CODE_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cc"
-    "${GE_CODE_DIR}/metadef/third_party/transformer/src/axis_util.cc"
-    "${GE_CODE_DIR}/metadef/third_party/transformer/src/expand_dimension.cc"
-    "${GE_CODE_DIR}/metadef/graph/utils/transformer_utils.cc"
-)
+FILE(GLOB_RECURSE GRAPH_SRC_FILES_DEPTH0 ${GE_CODE_DIR}/metadef/graph/*.cc)
+FILE(GLOB_RECURSE GRAPH_SRC_FILES_DEPTH1 ${GE_CODE_DIR}/metadef/graph/*/*.cc)
+FILE(GLOB_RECURSE GRAPH_SRC_FILES_DEPTH2 ${GE_CODE_DIR}/metadef/graph/*/*/*.cc)
+
+AUX_SOURCE_DIRECTORY(${GE_CODE_DIR}/metadef/ops GRAPH_OPS_SRC_FILES)
+AUX_SOURCE_DIRECTORY(${GE_CODE_DIR}/metadef/third_party/transformer/src TRANSFORMER_SRC_FILES)
 
-#add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS})
-add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS})
+add_executable(ut_libgraph ${UT_FILES}
+    ${GRAPH_SRC_FILES_DEPTH0}
+    ${GRAPH_SRC_FILES_DEPTH1}
+    ${GRAPH_SRC_FILES_DEPTH2}
+    ${GRAPH_OPS_SRC_FILES}
+    ${TRANSFORMER_SRC_FILES}
+    ${PROTO_SRCS} ${PROTO_HDRS}
+)
 
 target_compile_options(ut_libgraph PRIVATE
     -g --coverage -fprofile-arcs -ftest-coverage
diff --git a/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc b/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc
index 5cf7569b..85328b27 100644
--- a/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc
+++ b/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc
@@ -272,115 +272,3 @@ TEST_F(UtestGeAnchor, graph_utils_test) {
   EXPECT_EQ(GraphUtils::RemoveEdge(conv_node->GetOutDataAnchor(0), bn_node->GetInControlAnchor()), GRAPH_SUCCESS);
   EXPECT_EQ(GraphUtils::RemoveEdge(conv_node->GetOutDataAnchor(0), bn_node->GetInControlAnchor()), GRAPH_FAILED);
 }
-
-TEST_F(UtestGeAnchor, data_anchor_replace_peer) {
-  ComputeGraphPtr graph_ptr = std::make_shared<ComputeGraph>("graph");
-  OpDescPtr in_op_ptr = std::make_shared<OpDesc>("in_op_1", "float");
-  in_op_ptr->AddInputDesc("x1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
-  in_op_ptr->AddInputDesc("x2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
-  in_op_ptr->AddInputDesc("x3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
-  in_op_ptr->AddOutputDesc("y1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
-  in_op_ptr->AddOutputDesc("y2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
-  in_op_ptr->AddOutputDesc("y3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
-  NodePtr node1 = graph_ptr->AddNode(in_op_ptr);
-  NodePtr node2 = graph_ptr->AddNode(in_op_ptr);
-  NodePtr node3 = graph_ptr->AddNode(in_op_ptr);
-
-  OutDataAnchorPtr out_data_anchor = node1->GetOutDataAnchor(1);
-  InDataAnchorPtr in_data_anchor = node2->GetInDataAnchor(1);
-  EXPECT_EQ(out_data_anchor != nullptr, true);
-  EXPECT_EQ(in_data_anchor != nullptr, true);
-  EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(0)), GRAPH_SUCCESS);
-  EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(1)), GRAPH_SUCCESS);
-  EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(2)), GRAPH_SUCCESS);
-
-  size_t out_idx = 0;
-  for (; out_idx < out_data_anchor->peer_anchors_.size(); out_idx++) {
-    if (out_data_anchor->peer_anchors_[out_idx].lock() == in_data_anchor) {
-      break;
-    }
-  }
-  EXPECT_EQ(out_idx, 1);
-
-  size_t in_idx = 0;
-  for (; in_idx < in_data_anchor->peer_anchors_.size(); in_idx++) {
-    if (in_data_anchor->peer_anchors_[in_idx].lock() == out_data_anchor) {
-      break;
-    }
-  }
-  EXPECT_EQ(in_idx, 0);
-
-  out_data_anchor->ReplacePeer(in_data_anchor, node3->GetInDataAnchor(1), node3->GetOutDataAnchor(1));
-
-  size_t out_idx1 = 0;
-  for (; out_idx1 < out_data_anchor->peer_anchors_.size(); out_idx1++) {
-    if (out_data_anchor->peer_anchors_[out_idx1].lock() == node3->GetInDataAnchor(1)) {
-      break;
-    }
-  }
-  EXPECT_EQ(out_idx1, out_idx);
-
-  size_t in_idx1 = 0;
-  for (; in_idx1 < in_data_anchor->peer_anchors_.size(); in_idx1++) {
-    if (in_data_anchor->peer_anchors_[in_idx1].lock() == node3->GetOutDataAnchor(1)) {
-      break;
-    }
-  }
-  EXPECT_EQ(in_idx1, in_idx);
-}
-
-TEST_F(UtestGeAnchor, graph_utils_insert_node) {
-  ComputeGraphPtr graph_ptr = std::make_shared<ComputeGraph>("graph");
-  OpDescPtr in_op_ptr = std::make_shared<OpDesc>("in_op_1", "float");
"float"); - in_op_ptr->AddInputDesc("x1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddInputDesc("x2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddInputDesc("x3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - NodePtr node1 = graph_ptr->AddNode(in_op_ptr); - NodePtr node2 = graph_ptr->AddNode(in_op_ptr); - NodePtr node3 = graph_ptr->AddNode(in_op_ptr); - - OutDataAnchorPtr out_data_anchor = node1->GetOutDataAnchor(1); - InDataAnchorPtr in_data_anchor = node2->GetInDataAnchor(1); - EXPECT_EQ(out_data_anchor != nullptr, true); - EXPECT_EQ(in_data_anchor != nullptr, true); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(0)), GRAPH_SUCCESS); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(1)), GRAPH_SUCCESS); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(2)), GRAPH_SUCCESS); - - size_t out_idx = 0; - for (; out_idx < out_data_anchor->peer_anchors_.size(); out_idx++) { - if (out_data_anchor->peer_anchors_[out_idx].lock() == in_data_anchor) { - break; - } - } - EXPECT_EQ(out_idx, 1); - - size_t in_idx = 0; - for (; in_idx < in_data_anchor->peer_anchors_.size(); in_idx++) { - if (in_data_anchor->peer_anchors_[in_idx].lock() == out_data_anchor) { - break; - } - } - EXPECT_EQ(in_idx, 0); - - GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, in_data_anchor, node3); - - size_t out_idx1 = 0; - for (; out_idx1 < out_data_anchor->peer_anchors_.size(); out_idx1++) { - if (out_data_anchor->peer_anchors_[out_idx1].lock() == node3->GetInDataAnchor(0)) { - break; - } - } - EXPECT_EQ(out_idx1, out_idx); - - size_t in_idx1 = 0; - for (; in_idx1 < in_data_anchor->peer_anchors_.size(); in_idx1++) { - if (in_data_anchor->peer_anchors_[in_idx1].lock() == node3->GetOutDataAnchor(0)) { - break; - } - } - EXPECT_EQ(in_idx1, in_idx); -} diff --git a/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc b/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc index 0366446c..c91f68df 100644 --- a/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc +++ b/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc @@ -30,6 +30,7 @@ #include "graph/model_serialize.h" #include "graph/detail/model_serialize_imp.h" +#include "graph/node_impl.h" #include "graph/ge_attr_value.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" @@ -1062,7 +1063,7 @@ TEST(UtestGeModelSerialize, test_model_serialize_imp_invalid_param) { auto graph = std::make_shared("test_graph"); auto node = graph->AddNode(std::make_shared()); - node->op_ = nullptr; + node->impl_->op_ = nullptr; ge::proto::ModelDef model_def; Model model; model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); diff --git a/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc b/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc index aa43ac99..838df735 100644 --- a/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc +++ b/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc @@ -25,6 +25,7 @@ #include "graph/ge_attr_value.h" #include "graph/tensor.h" #include "graph/utils/tensor_utils.h" +#include "graph/ge_tensor_impl.h" #undef private #undef protected @@ -196,23 
+197,6 @@ TEST_F(UtestGeTensor, test_shape_copy_move) { EXPECT_EQ(shape4.GetDimNum(), 3); } -TEST_F(UtestGeTensor, test_tensor_desc_invalid_null) { - GeTensorDesc tensor_desc(nullptr, nullptr); - EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); - EXPECT_EQ(tensor_desc.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensor_desc.MutableShape().shape_def_.GetProtoMsg(), nullptr); - - GeTensorDesc tensor_desc2; - EXPECT_EQ(tensor_desc2.GetDataType(), DT_FLOAT); - EXPECT_EQ(tensor_desc2.GetFormat(), FORMAT_ND); - - tensor_desc2.SetDataType(DT_DUAL_SUB_INT8); - EXPECT_EQ(tensor_desc2.GetDataType(), DT_DUAL_SUB_INT8); - - TensorUtils::SetWeightSize(tensor_desc, 100); - EXPECT_EQ(TensorUtils::GetWeightSize(tensor_desc), 0); -} - TEST_F(UtestGeTensor, test_tensor_invalid_null) { ProtoMsgOwner msg_owner; GeTensor tensor(msg_owner, nullptr); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 63579109..a7afee3f 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -20,6 +20,7 @@ set(CMAKE_CXX_STANDARD 11) set(PROTO_LIST "${GE_CODE_DIR}/metadef/proto/om.proto" "${GE_CODE_DIR}/metadef/proto/ge_ir.proto" + "${GE_CODE_DIR}/metadef/proto/task.proto" "${GE_CODE_DIR}/metadef/proto/ge_api.proto" "${GE_CODE_DIR}/metadef/proto/insert_op.proto" "${GE_CODE_DIR}/metadef/proto/dump_task.proto" @@ -62,66 +63,23 @@ include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain) +include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info) include_directories(${GE_CODE_DIR}/tests/ut/ge) include_directories(${GE_CODE_DIR}/tests/ut/common) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) -set(GRAPH_SRC_FILES - "${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc" - "${GE_CODE_DIR}/metadef/graph/option/ge_context.cc" - "${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" - "${GE_CODE_DIR}/metadef/graph/anchor.cc" - "${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" - "${GE_CODE_DIR}/metadef/graph/attr_value.cc" - "${GE_CODE_DIR}/metadef/graph/buffer.cc" - "${GE_CODE_DIR}/metadef/graph/aligned_ptr.cc" - "${GE_CODE_DIR}/metadef/graph/compute_graph.cc" - "${GE_CODE_DIR}/metadef/graph/graph.cc" - "${GE_CODE_DIR}/metadef/graph/gnode.cc" - "${GE_CODE_DIR}/metadef/graph/ascend_string.cc" - "${GE_CODE_DIR}/metadef/graph/inference_context.cc" - "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" - "${GE_CODE_DIR}/metadef/graph/model.cc" - "${GE_CODE_DIR}/metadef/graph/model_serialize.cc" - "${GE_CODE_DIR}/metadef/graph/node.cc" - "${GE_CODE_DIR}/metadef/graph/runtime_inference_context.cc" - "${GE_CODE_DIR}/metadef/graph/op_desc.cc" - "${GE_CODE_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cc" - "${GE_CODE_DIR}/metadef/third_party/transformer/src/axis_util.cc" - "${GE_CODE_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" - "${GE_CODE_DIR}/metadef/graph/operator.cc" - "${GE_CODE_DIR}/metadef/graph/operator_factory.cc" - "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" - "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" - "${GE_CODE_DIR}/metadef/graph/ref_relation.cc" - "${GE_CODE_DIR}/metadef/graph/tensor.cc" - "${GE_CODE_DIR}/metadef/graph/types.cc" - "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" - 
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" - "${GE_CODE_DIR}/metadef/graph/utils/transformer_utils.cc" - "${GE_CODE_DIR}/metadef/graph/debug/graph_debug.cc" - "${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" - "${GE_CODE_DIR}/metadef/ops/op_imp.cpp" - "${GE_CODE_DIR}/metadef/register/register.cpp" - "${GE_CODE_DIR}/metadef/register/register_pass.cpp" - "${GE_CODE_DIR}/metadef/register/op_kernel_registry.cpp" - "${GE_CODE_DIR}/metadef/register/auto_mapping_util.cpp" - "${GE_CODE_DIR}/metadef/register/tensor_assign.cpp" - "${GE_CODE_DIR}/metadef/register/register_format_transfer.cc" - "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" - "${GE_CODE_DIR}/metadef/register/ops_kernel_builder_registry.cc" - "${GE_CODE_DIR}/metadef/register/op_tiling.cpp" - "${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc" - "${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp" -) + +#### GRAPH_SRC_FILES #### +FILE(GLOB_RECURSE GRAPH_SRC_FILES_DEPTH0 ${GE_CODE_DIR}/metadef/graph/*.cc) +FILE(GLOB_RECURSE GRAPH_SRC_FILES_DEPTH1 ${GE_CODE_DIR}/metadef/graph/*/*.cc) +FILE(GLOB_RECURSE GRAPH_SRC_FILES_DEPTH2 ${GE_CODE_DIR}/metadef/graph/*/*/*.cc) + +AUX_SOURCE_DIRECTORY(${GE_CODE_DIR}/metadef/ops GRAPH_OPS_SRC_FILES) +AUX_SOURCE_DIRECTORY(${GE_CODE_DIR}/metadef/register GRAPH_REGISTER_SRC_FILES) +AUX_SOURCE_DIRECTORY(${GE_CODE_DIR}/metadef/third_party/transformer/src TRANSFORMER_SRC_FILES) + set(PARSER_SRC_FILES "${GE_CODE_DIR}/parser/parser/common/op_map.cc" @@ -131,6 +89,7 @@ set(PARSER_SRC_FILES "${GE_CODE_DIR}/parser/parser/common/model_saver.cc" "${GE_CODE_DIR}/parser/parser/common/parser_types.cc" "${GE_CODE_DIR}/parser/parser/common/parser_inner_ctx.cc" + "${GE_CODE_DIR}/parser/parser/tensorflow/iterator_fusion_pass.cc" ) set(COMMON_SRC_FILES @@ -145,26 +104,19 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_manager.cc" "${GE_CODE_DIR}/ge/generator/ge_generator.cc" "${GE_CODE_DIR}/ge/generator/generator_api.cc" - "${GE_CODE_DIR}/ge/graph/common/omg_util.cc" - "${GE_CODE_DIR}/ge/graph/common/bcast.cc" + "${GE_CODE_DIR}/ge/common/omg_util.cc" + "${GE_CODE_DIR}/ge/common/bcast.cc" "${GE_CODE_DIR}/ge/common/util.cc" "${GE_CODE_DIR}/ge/common/ge/op_tiling_manager.cc" "${GE_CODE_DIR}/ge/init/gelib.cc" "${GE_CODE_DIR}/ge/engine_manager/dnnengine_manager.cc" "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_manager.cc" - "${GE_CODE_DIR}/ge/session/session_manager.cc" "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_builder_manager.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/model_manager.cc" "${GE_CODE_DIR}/ge/common/profiling/profiling_manager.cc" + "${GE_CODE_DIR}/ge/common/profiling/ge_profiling.cc" "${GE_CODE_DIR}/ge/graph/manager/host_mem_manager.cc" - "${GE_CODE_DIR}/ge/graph/manager/memory_api.cc" - "${GE_CODE_DIR}/ge/session/inner_session.cc" + "${GE_CODE_DIR}/ge/graph/manager/memory_api.cc" "${GE_CODE_DIR}/ge/graph/manager/util/rt_context_util.cc" - "${GE_CODE_DIR}/ge/graph/execute/graph_execute.cc" - "${GE_CODE_DIR}/ge/graph/preprocess/graph_preprocess.cc" - "${GE_CODE_DIR}/ge/hybrid/hybrid_davinci_model_stub.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/davinci_model.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/data_inputer.cc" 
"${GE_CODE_DIR}/ge/common/dump/dump_properties.cc" "${GE_CODE_DIR}/ge/common/helper/model_helper.cc" "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" @@ -172,133 +124,22 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/common/dump/opdebug_register.cc" "${GE_CODE_DIR}/ge/common/dump/dump_op.cc" "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" - "${GE_CODE_DIR}/ge/model/ge_root_model.cc" + "${GE_CODE_DIR}/ge/common/model/ge_root_model.cc" "${GE_CODE_DIR}/ge/common/model_parser/model_parser.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/data_dumper.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_manager.cc" "${GE_CODE_DIR}/ge/common/dump/dump_server.cc" - "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc" "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_copy_graph.cc" "${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" - "${GE_CODE_DIR}/ge/graph/passes/pass_manager.cc" - "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/pass_utils.cc" - "${GE_CODE_DIR}/ge/graph/passes/base_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/bitcast_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/constant_folding_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/aicpu_constant_folding_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/reshape_recovery_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/variable_prepare_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/variable_ref_delete_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/subgraph_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/data_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/input_output_connection_identify_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/atomic_addr_clean_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/mark_same_addr_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/mark_graph_unknown_status_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/mark_agnostic_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/dimension_compute_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" - 
"${GE_CODE_DIR}/ge/graph/passes/stop_gradient_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/prevent_gradient_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/identity_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/ref_identity_delete_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/guarantee_const_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/parallel_concat_start_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/prune_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/mark_force_unknown_for_cond_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/subexpression_migration_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/subgraph_const_migration_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/unused_args_clean_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/next_iteration_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/control_trigger_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/cond_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/cond_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/for_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/enter_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/assign_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/inplace_support_check_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc" - "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" - "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/hccl_continuous_memcpy_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/global_step_insert_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/link_gen_mask_nodes_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/replace_with_empty_const_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/hccl_tailing_optimization_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/mark_node_unknown_shape_pass.cc" - "${GE_CODE_DIR}/ge/model/ge_model.cc" + "${GE_CODE_DIR}/ge/common/model/ge_model.cc" "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" - 
"${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/zero_copy_offset.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/zero_copy_task.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/cpu_queue_schedule.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/aipp_utils.cc" - "${GE_CODE_DIR}/ge/graph/load/model_manager/tbe_handle_store.cc" "${GE_CODE_DIR}/ge/common/kernel_store.cc" "${GE_CODE_DIR}/ge/common/tbe_kernel_store.cc" "${GE_CODE_DIR}/ge/common/auth/file_saver.cc" "${GE_CODE_DIR}/ge/graph/manager/util/debug.cc" "${GE_CODE_DIR}/ge/common/debug/memory_dumper.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_context.cc" "${GE_CODE_DIR}/ge/graph/load/graph_loader.cc" "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" "${GE_CODE_DIR}/ge/graph/build/graph_builder.cc" "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" - "${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" "${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc" "${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" @@ -308,13 +149,10 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" "${GE_CODE_DIR}/ge/ir_build/option_utils.cc" - "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" - "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" "${GE_CODE_DIR}/ge/graph/build/model_builder.cc" "${GE_CODE_DIR}/ge/graph/build/run_context.cc" "${GE_CODE_DIR}/ge/graph/build/stream_graph_optimizer.cc" "${GE_CODE_DIR}/ge/graph/build/task_generator.cc" - "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" "${GE_CODE_DIR}/ge/graph/partition/engine_place.cc" "${GE_CODE_DIR}/ge/graph/build/stream_allocator.cc" "${GE_CODE_DIR}/ge/graph/build/memory/memory_assigner.cc" @@ -330,10 +168,10 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/manager/graph_var_manager.cc" "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" "${GE_CODE_DIR}/ge/common/thread_pool.cc" - "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" + "${GE_CODE_DIR}/ge/common/transop_util.cc" "${GE_CODE_DIR}/ge/graph/manager/graph_manager_utils.cc" "${GE_CODE_DIR}/ge/graph/manager/trans_var_data_utils.cc" - "${GE_CODE_DIR}/ge/graph/common/local_context.cc" + "${GE_CODE_DIR}/ge/common/local_context.cc" "${GE_CODE_DIR}/ge/graph/manager/graph_caching_allocator.cc" "${GE_CODE_DIR}/ge/graph/manager/session_scope_mem_allocator.cc" "${GE_CODE_DIR}/ge/graph/manager/rdma_pool_allocator.cc" @@ -341,10 +179,11 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/manager/graph_mem_manager.cc" "${GE_CODE_DIR}/ge/common/dump/dump_op.cc" "${GE_CODE_DIR}/ge/common/model_saver.cc" - "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" "${GE_CODE_DIR}/ge/session/omg.cc" + "${GE_CODE_DIR}/ge/common/thread_pool.cc" + "${GE_CODE_DIR}/ge/ge_opt_info/ge_opt_info.cc" ) set(COMMON_FORMAT_SRC_FILES @@ -367,57 +206,25 @@ set(COMMON_FORMAT_SRC_FILES "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" "${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc" - "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" ) -set(GRAPH_OPTIMIZE_COMMON_SRC_FILES - "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" - "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" - "${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" 
-) - - set(GRAPH_PREPARE_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/preprocess/graph_preprocess.cc" "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc" "${GE_CODE_DIR}/ge/graph/preprocess/insert_op/ge_aipp_op.cc" - #"${GE_CODE_DIR}/ge/graph/preprocess/insert_op/base_insert_op.cc" -) - -set(GRAPH_PARTITION_COMMON_SRC_FILES - "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" - "${GE_CODE_DIR}/ge/plugin/engine/dnnengines.cc" - "${GE_CODE_DIR}/ge/graph/partition/engine_place.cc" -) - -set(GRAPH_LOAD_COMMON_SRC_FILES - "${GE_CODE_DIR}/ge/graph/load/graph_loader.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_manager_utils.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_mem_allocator.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_var_manager.cc" - "${GE_CODE_DIR}/ge/graph/manager/trans_var_data_utils.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_caching_allocator.cc" - "${GE_CODE_DIR}/ge/graph/manager/session_scope_mem_allocator.cc" - "${GE_CODE_DIR}/ge/graph/manager/rdma_pool_allocator.cc" - "${GE_CODE_DIR}/ge/graph/manager/host_mem_allocator.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_mem_manager.cc" - "${GE_CODE_DIR}/ge/common/thread_pool.cc" + "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_options.cc" ) -set(DISTINCT_GRAPH_LOAD_SRC_FILES - "${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc" - "${GE_CODE_DIR}/ge/graph/manager/util/debug.cc" - "${GE_CODE_DIR}/ge/common/properties_manager.cc" - "${GE_CODE_DIR}/ge/common/profiling/profiling_manager.cc" - "${GE_CODE_DIR}/ge/common/model_parser/model_parser.cc" - "${GE_CODE_DIR}/ge/common/tbe_kernel_store.cc" - "${GE_CODE_DIR}/ge/common/util.cc" +set(GRAPH_DAVINCI_MODEL_SRC_FILES + "${GE_CODE_DIR}/ge/graph/load/model_manager/aipp_utils.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/cpu_queue_schedule.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/data_dumper.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/data_inputer.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/davinci_model.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/model_manager.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" + "${GE_CODE_DIR}/ge/graph/load/model_manager/zero_copy_offset.cc" + "${GE_CODE_DIR}/ge/graph/load/model_manager/zero_copy_task.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/tbe_handle_store.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/event_record_task_info.cc" @@ -436,45 +243,26 @@ set(DISTINCT_GRAPH_LOAD_SRC_FILES "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/stream_active_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/end_graph_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/model_exit_task_info.cc" + "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/ffts_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" - "${GE_CODE_DIR}/ge/model/ge_model.cc" - "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" - "${GE_CODE_DIR}/ge/common/debug/memory_dumper.cc" - "${GE_CODE_DIR}/ge/executor/ge_executor.cc" - "${GE_CODE_DIR}/ge/common/auth/file_saver.cc" - "${GE_CODE_DIR}/ge/graph/manager/model_manager/event_manager.cc" + "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" ) set(GRAPH_EXECUTE_COMMON_SRC_FILES - "${GE_CODE_DIR}/ge/graph/execute/graph_execute.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_manager.cc" - 
"${GE_CODE_DIR}/ge/graph/manager/graph_context.cc" - "${GE_CODE_DIR}/ge/graph/manager/util/rt_context_util.cc" - "${GE_CODE_DIR}/ge/graph/manager/graph_context.h" + "${GE_CODE_DIR}/ge/hybrid/hybrid_davinci_model_stub.cc" ) set(GRAPH_BUILD_COMMON_SRC_FILES - "${GE_CODE_DIR}/ge/graph/build/graph_builder.cc" - "${GE_CODE_DIR}/ge/graph/build/task_generator.cc" + "${GE_CODE_DIR}/ge/graph/manager/graph_manager.cc" "${GE_CODE_DIR}/ge/client/ge_api.cc" "${GE_CODE_DIR}/ge/session/inner_session.cc" "${GE_CODE_DIR}/ge/session/session_manager.cc" - "${GE_CODE_DIR}/ge/engine_manager/dnnengine_manager.cc" + "${GE_CODE_DIR}/ge/graph/execute/model_executor.cc" + "${GE_CODE_DIR}/ge/graph/execute/graph_execute.cc" + "${GE_CODE_DIR}/ge/plugin/engine/dnnengines.cc" "${GE_CODE_DIR}/ge/plugin/engine/engine_manage.cc" - "${GE_CODE_DIR}/ge/graph/build/logical_stream_allocator.cc" - "${GE_CODE_DIR}/ge/graph/build/stream_allocator.cc" - "${GE_CODE_DIR}/ge/graph/build/memory/block_mem_assigner.cc" - "${GE_CODE_DIR}/ge/graph/build/memory/binary_block_mem_assigner.cc" - "${GE_CODE_DIR}/ge/graph/build/memory/hybrid_mem_assigner.cc" - "${GE_CODE_DIR}/ge/graph/build/memory/max_block_mem_assigner.cc" - "${GE_CODE_DIR}/ge/model/ge_model.cc" - "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" - "${GE_CODE_DIR}/ge/common/tbe_kernel_store.cc" - "${GE_CODE_DIR}/ge/common/thread_pool.cc" - "${GE_CODE_DIR}/ge/common/model_parser/model_parser.cc" - "${GE_CODE_DIR}/ge/graph/build/run_context.cc" - "${GE_CODE_DIR}/ge/graph/common/local_context.cc" + "${GE_CODE_DIR}/ge/graph/manager/graph_context.cc" ) set(GRAPH_PASS_COMMON_SRC_FILES @@ -484,7 +272,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/variable_ref_delete_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/atomic_addr_clean_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/constant_folding_pass.cc" - "${GE_CODE_DIR}/parser/parser/tensorflow/iterator_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" @@ -523,20 +310,116 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" - "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" - #"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" - "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" - "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/pass_utils.cc" + "${GE_CODE_DIR}/ge/graph/passes/base_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/bitcast_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/constant_folding_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/aicpu_constant_folding_pass.cc" + 
"${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/reshape_recovery_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/variable_prepare_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/variable_ref_delete_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/subgraph_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/data_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" - "${GE_CODE_DIR}/ge/graph/common/local_context.cc" + "${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/input_output_connection_identify_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/atomic_addr_clean_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_same_addr_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_graph_unknown_status_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_agnostic_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/dimension_compute_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/get_original_format_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/shape_operate_op_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/assert_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/dropout_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_base_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/infer_value_range_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/unused_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/permute_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/ctrl_edge_transfer_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/end_of_sequence_add_control_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/stop_gradient_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/prevent_gradient_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/identity_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/ref_identity_delete_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/placeholder_with_default_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/snapshot_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/guarantee_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/var_is_initialized_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/parallel_concat_start_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/cast_translate_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/prune_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/merge_to_stream_merge_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/merge_input_memcpy_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/switch_to_stream_switch_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_force_unknown_for_cond_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/attach_stream_label_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/multi_batch_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/multi_batch_clone_pass.cc" + 
"${GE_CODE_DIR}/ge/graph/passes/subexpression_migration_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/subgraph_const_migration_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/unused_args_clean_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/next_iteration_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/control_trigger_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/cond_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/cond_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/for_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/enter_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/assign_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/inplace_support_check_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/addn_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/common_subexpression_elimination_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/transop_symmetry_elimination_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/save_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/switch_dead_branch_elimination.cc" + "${GE_CODE_DIR}/ge/graph/passes/switch_logic_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/switch_data_edges_bypass.cc" + "${GE_CODE_DIR}/ge/graph/passes/merge_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/cast_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/hccl_continuous_memcpy_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/global_step_insert_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/link_gen_mask_nodes_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/replace_with_empty_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/hccl_tailing_optimization_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/buffer_pool_memory_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/mark_node_unknown_shape_pass.cc" ) set(KERNEL_SRC_FILES @@ -577,6 +460,7 @@ set(KERNEL_SRC_FILES ) set(SINGLE_OP_SRC_FILES + "${GE_CODE_DIR}/ge/executor/ge_executor.cc" "${GE_CODE_DIR}/ge/single_op/task/build_task_utils.cc" "${GE_CODE_DIR}/ge/single_op/task/op_task.cc" "${GE_CODE_DIR}/ge/single_op/task/tbe_task_builder.cc" @@ -610,7 +494,6 @@ set(SINGLE_OP_SRC_FILES "${GE_CODE_DIR}/ge/hybrid/node_executor/aicore/aicore_op_task.cc" "${GE_CODE_DIR}/ge/hybrid/node_executor/aicore/aicore_task_builder.cc" "${GE_CODE_DIR}/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc" - "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc" "${GE_CODE_DIR}/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc" "${GE_CODE_DIR}/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc" @@ -635,29 +518,26 @@ set(COMMON_TEST_FILES set(DISTINCT_GRAPH_LOAD_TEST_FILES "graph/load/data_dumper_unittest.cc" - #"graph/load/new_model_manager_data_inputer_unittest.cc" - #"graph/load/new_model_manager_davinci_model_unittest.cc" "graph/load/model_manager_unittest.cc" - #"graph/load/new_model_manager_task_build_unittest.cc" "graph/load/new_model_manager_model_manager_aicpu_unittest.cc" "graph/load/end_graph_task_unittest.cc" - "graph/load/new_model_manager_event_manager_unittest.cc" - #"graph/load/output_net_output_unittest.cc" 
"graph/load/davinci_model_unittest.cc" "graph/load/tbe_handle_store_unittest.cc" "graph/load/hccl_task_info_unittest.cc" "graph/load/kernel_ex_task_info_unittest.cc" "graph/load/kernel_task_info_unittest.cc" + "graph/load/ffts_task_info_unittest.cc" "graph/load/memcpy_addr_async_task_info_unittest.cc" "graph/load/memcpy_async_task_info_unittest.cc" "graph/load/cpu_queue_schedule_unittest.cc" - #"graph/graph_load_unittest.cc" "graph/ge_executor_unittest.cc" "graph/load/model_helper_unittest.cc" "graph/load/model_utils_unittest.cc" ) set(PASS_TEST_FILES + "graph/passes/infer_value_range_pass_unittest.cc" + "graph/passes/infer_base_pass_unittest.cc" "graph/passes/prune_pass_unittest.cc" "graph/passes/enter_pass_unittest.cc" "graph/passes/switch_op_pass_unittest.cc" @@ -716,7 +596,6 @@ set(PASS_TEST_FILES "graph/passes/memcpy_addr_async_unittest.cc" "graph/passes/hccl_continuous_pass_unittest.cc" "graph/passes/hccl_memcpy_pass_unittest.cc" - ) set(KERNEL_TEST_FILES @@ -761,8 +640,11 @@ set(MULTI_PARTS_TEST_FILES "graph_ir/ge_ir_build_unittest.cc" "graph/transop_util_unittest.cc" "common/datatype_transfer_unittest.cc" + "common/util_unittest.cc" + "common/fp16_unittest.cc" "common/dump_manager_unittest.cc" "common/dump_op_unittest.cc" + "common/dump_properties_unittest.cc" "common/dump_exception_unittest.cc" "common/opdebug_register_unittest.cc" "common/format_transfer_unittest.cc" @@ -788,18 +670,21 @@ set(MULTI_PARTS_TEST_FILES "graph/build/stream_allocator_unittest.cc" "graph/build/model_builder_unittest.cc" "graph/build/mem_assigner_unittest.cc" + "graph/build/graph_mem_assigner_unittest.cc" "graph/build/task_generator_unittest.cc" "graph/build/buffer_pool_mem_assigner_unittest.cc" "graph/execute/graph_execute_unittest.cc" + "graph/execute/model_executor_unittest.cc" "graph/preprocess/graph_preprocess_unittest.cc" "graph/manager/hcom_util_unittest.cc" "graph/manager/graph_caching_allocator_unittest.cc" "graph/manager/host_mem_allocator_unittest.cc" - "graph/manager/memory_api_unittest.cc" + "graph/manager/memory_api_unittest.cc" "graph/manager/session_scope_mem_allocator_unittest.cc" "graph/manager/run_graph_unittest.cc" "graph/partition/dynamic_shape_partition_unittest.cc" "graph/manager/graph_manager_unittest.cc" + "graph/manager/graph_var_manager_unittest.cc" "graph/optimize/mem_rw_conflict_optimize_unittest.cc" "graph/optimize/graph_optimize_unittest.cc" "session/omg_omg_unittest.cc" @@ -807,6 +692,11 @@ set(MULTI_PARTS_TEST_FILES "session/inner_session_unittest.cc" "session/session_manager_unittest.cc" "common/host_cpu_engine_unittest.cc" + "common/tbe_plugin_manager_unittest.cc" +) + +set(GE_OPT_INFO_TEST_FILES + "ge_opt_info/ge_opt_info_unittest.cc" ) set(GENERATOR_TEST_FILES @@ -836,10 +726,12 @@ set(HYBRID_TEST_FILES "hybrid/executor/subgraph_executor_unittest.cc" "hybrid/executor/worker/execution_engine_unittest.cc" "hybrid/model/hybrid_model_builder_unittest.cc" + "hybrid/node_executor/node_executor_unittest.cc" "hybrid/node_executor/rts/rts_node_task_unittest.cc" "hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc" "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" "hybrid/node_executor/hccl/hccl_node_executor_unittest.cc" + "hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc" "hybrid/executor/hybrid_model_async_executor_unittest.cc" "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" @@ -852,23 +744,30 @@ set(OTHERS_TEST_FILES list(APPEND COMMON_SHARED_LIBRARIES c_sec 
slog_stub - cce_ge_stub runtime_stub profiler_stub mmpa_stub hccl_stub error_manager_stub + opt_feature_stub ascend_protobuf json ) # build graph add_library(ge_ut_graph STATIC - ${GRAPH_SRC_FILES} ${PARSER_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS} + ${GRAPH_SRC_FILES_DEPTH0} + ${GRAPH_SRC_FILES_DEPTH1} + ${GRAPH_SRC_FILES_DEPTH2} + ${GRAPH_OPS_SRC_FILES} + ${GRAPH_REGISTER_SRC_FILES} + ${TRANSFORMER_SRC_FILES} + ${PARSER_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS} ) target_compile_definitions(ge_ut_graph PRIVATE google=ascend_private + FMK_SUPPORT_DUMP ) target_compile_options(ge_ut_graph PRIVATE @@ -943,57 +842,19 @@ target_link_libraries(ge_prepare_common PRIVATE json ) -# build graph optimize common -add_library(ge_optimize_common STATIC ${GRAPH_OPTIMIZE_COMMON_SRC_FILES} ${PROTO_HDRS}) - -target_compile_definitions(ge_optimize_common PRIVATE - google=ascend_private -) - -target_compile_options(ge_optimize_common PRIVATE - -g --coverage -fprofile-arcs -ftest-coverage - -Werror=format -) - -target_link_libraries(ge_optimize_common PRIVATE - $ - ascend_protobuf - c_sec - json -) - -# build graph partition common -add_library(ge_partition_common STATIC ${GRAPH_PARTITION_COMMON_SRC_FILES} ${PROTO_HDRS}) - -target_compile_definitions(ge_partition_common PRIVATE - google=ascend_private -) - -target_compile_options(ge_partition_common PRIVATE - -g --coverage -fprofile-arcs -ftest-coverage - -Werror=format -) - -target_link_libraries(ge_partition_common PRIVATE - $ - ascend_protobuf - c_sec - json -) - # build build graph load common -add_library(ge_load_common STATIC ${GRAPH_LOAD_COMMON_SRC_FILES} ${PROTO_HDRS}) +add_library(ge_davinci_model STATIC ${GRAPH_DAVINCI_MODEL_SRC_FILES} ${PROTO_HDRS}) -target_compile_definitions(ge_load_common PRIVATE +target_compile_definitions(ge_davinci_model PRIVATE google=ascend_private ) -target_compile_options(ge_load_common PRIVATE +target_compile_options(ge_davinci_model PRIVATE -g --coverage -fprofile-arcs -ftest-coverage -Werror=format ) -target_link_libraries(ge_load_common PRIVATE +target_link_libraries(ge_davinci_model PRIVATE $ c_sec ascend_protobuf @@ -1075,6 +936,7 @@ target_link_libraries(ge_single_op PRIVATE ascend_protobuf json c_sec + runtime_stub ) # ut binary @@ -1096,8 +958,9 @@ target_compile_definitions(ut_libge_multiparts_utest PRIVATE target_link_libraries(ut_libge_multiparts_utest $ - ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common - ge_single_op ge_ut_common_format ge_ut_common + -Wl,--whole-archive + ge_davinci_model ge_build_common ge_prepare_common ge_execute_common ge_pass_common ge_ut_common_format ge_ut_common + -Wl,--no-whole-archive gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov ) @@ -1107,6 +970,7 @@ add_executable(ut_libge_others_utest ${PASS_TEST_FILES} ${EXECUTE_TEST_FILES} ${OTHERS_TEST_FILES} + ${GE_OPT_INFO_TEST_FILES} ) target_compile_options(ut_libge_others_utest PRIVATE @@ -1116,7 +980,9 @@ target_compile_options(ut_libge_others_utest PRIVATE target_link_libraries(ut_libge_others_utest $ - ge_load_common ge_execute_common ge_ut_common ge_ut_common_format + -Wl,--whole-archive + ge_davinci_model ge_build_common ge_prepare_common ge_pass_common ge_execute_common ge_ut_common ge_ut_common_format + -Wl,--no-whole-archive gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov ) @@ -1134,7 +1000,9 @@ target_compile_options(ut_libge_kernel_utest PRIVATE target_link_libraries(ut_libge_kernel_utest $ - ge_load_common 
ge_ut_common ge_ut_common_format
+    -Wl,--whole-archive
+    ge_davinci_model ge_build_common ge_prepare_common ge_pass_common ge_execute_common ge_ut_common ge_ut_common_format
+    -Wl,--no-whole-archive
     gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lgcov
 )
@@ -1144,7 +1012,6 @@ add_executable(ut_libge_distinct_load_utest
     ${GENERATOR_TEST_FILES}
     ${EXECUTOR_TEST_FILES}
     ${DISTINCT_GRAPH_LOAD_TEST_FILES}
-    ${DISTINCT_GRAPH_LOAD_SRC_FILES}
     ${SINGLE_OP_TEST_FILES}
     ${PROFILING_MNG_TEST_FILES}
     ${HYBRID_TEST_FILES}
@@ -1163,9 +1030,7 @@ target_compile_definitions(ut_libge_distinct_load_utest PRIVATE
 target_link_libraries(ut_libge_distinct_load_utest
     $
     -Wl,--whole-archive
-    ge_single_op
+    ge_single_op ge_davinci_model ge_build_common ge_prepare_common ge_pass_common ge_ut_common ge_ut_common_format
     -Wl,--no-whole-archive
-    ge_execute_common ge_load_common
-    ge_prepare_common ge_optimize_common ge_build_common ge_partition_common ge_ut_common ge_ut_common_format
     gtest gtest_main gmock gmock_main ${COMMON_SHARED_LIBRARIES} -lrt -ldl -lpthread -lgcov
 )
diff --git a/tests/ut/ge/common/datatype_transfer_unittest.cc b/tests/ut/ge/common/datatype_transfer_unittest.cc
index c311a7cf..ea131b2c 100644
--- a/tests/ut/ge/common/datatype_transfer_unittest.cc
+++ b/tests/ut/ge/common/datatype_transfer_unittest.cc
@@ -47,7 +47,7 @@ TEST_F(UtestDataTypeTransfer, fp16_fp32) {
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
   bool is_equal = true;
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     if (abs((reinterpret_cast<float *>(result.data.get()))[i] - ret[i]) > 1.0e-6) {
       is_equal = false;
       break;
@@ -60,7 +60,7 @@ TEST_F(UtestDataTypeTransfer, fp16_fp32) {
   CastArgs args2{reinterpret_cast<uint8_t *>(ret), sizeof(ret) / sizeof(ret[0]), DT_FLOAT, DT_FLOAT16};
   EXPECT_EQ(transfer2.TransDataType(args2, result2), SUCCESS);
   EXPECT_EQ(result2.length, sizeof(data));
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     EXPECT_FLOAT_EQ((reinterpret_cast<fp16_t *>(result2.data.get()))[i].val, data[i].val);
   }
   EXPECT_EQ(TransDataType(args2, result2), SUCCESS);
@@ -81,7 +81,7 @@ TEST_F(UtestDataTypeTransfer, int32_fp16) {
   CastArgs args{reinterpret_cast<uint8_t *>(data), sizeof(ret) / sizeof(ret[0]), DT_INT32, DT_FLOAT16};
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     EXPECT_FLOAT_EQ((reinterpret_cast<fp16_t *>(result.data.get()))[i].val, ret[i].val);
   }
@@ -91,7 +91,7 @@ TEST_F(UtestDataTypeTransfer, int32_fp16) {
   EXPECT_EQ(transfer2.TransDataType(args2, result2), SUCCESS);
   EXPECT_EQ(result2.length, sizeof(data));
   bool is_equal = true;
-  for (int i = 0; i < sizeof(data) / sizeof(data[0]); ++i) {
+  for (size_t i = 0; i < sizeof(data) / sizeof(data[0]); ++i) {
     if (abs((reinterpret_cast<int32_t *>(result2.data.get()))[i] - data[i]) / abs(data[i]) > 0.05) {
       is_equal = false;
       break;
@@ -154,7 +154,7 @@ TEST_F(UtestDataTypeTransfer, fp32_fp16) {
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
   bool is_equal = true;
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     if (abs((reinterpret_cast<float *>(result.data.get()))[i] - ret[i]) > 1.0e-6) {
       is_equal = false;
       break;
@@ -167,7 +167,7 @@ TEST_F(UtestDataTypeTransfer, fp32_fp16) {
   CastArgs args2{reinterpret_cast<uint8_t *>(ret), sizeof(data) / sizeof(data[0]), DT_FLOAT, DT_FLOAT16};
   EXPECT_EQ(transfer2.TransDataType(args2, result2), SUCCESS);
   EXPECT_EQ(result2.length, sizeof(data));
-  for (int i = 0; i < sizeof(data) / sizeof(data[0]); ++i) {
+  for (size_t i = 0; i < sizeof(data) / sizeof(data[0]); ++i) {
     EXPECT_FLOAT_EQ((reinterpret_cast<fp16_t *>(result2.data.get()))[i].val, data[i].val);
   }
 }
@@ -238,7 +238,7 @@ TEST_F(UtestDataTypeTransfer, uint8_fp32) {
   DataTypeTransfer transfer;
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     EXPECT_EQ((reinterpret_cast<float *>(result.data.get()))[i], ret[i]);
   }
 }
@@ -259,7 +259,7 @@ TEST_F(UtestDataTypeTransfer, uint8_int32) {
   DataTypeTransfer transfer;
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     EXPECT_EQ((reinterpret_cast<int32_t *>(result.data.get()))[i], ret[i]);
   }
 }
@@ -282,7 +282,7 @@ TEST_F(UtestDataTypeTransfer, fp32_int32) {
   DataTypeTransfer transfer;
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     EXPECT_FLOAT_EQ((reinterpret_cast<int32_t *>(result.data.get()))[i], ret[i]);
   }
 }
@@ -304,7 +304,7 @@ TEST_F(UtestDataTypeTransfer, int32_fp32) {
   DataTypeTransfer transfer;
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     EXPECT_FLOAT_EQ((reinterpret_cast<float *>(result.data.get()))[i], ret[i]);
   }
 }
@@ -329,7 +329,7 @@ TEST_F(UtestDataTypeTransfer, int32_uint8) {
   DataTypeTransfer transfer;
   EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS);
   EXPECT_EQ(result.length, sizeof(ret));
-  for (int i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
+  for (size_t i = 0; i < sizeof(ret) / sizeof(ret[0]); ++i) {
     EXPECT_FLOAT_EQ((reinterpret_cast<uint8_t *>(result.data.get()))[i], ret[i]);
   }
 }
diff --git a/tests/ut/ge/common/dump_manager_unittest.cc b/tests/ut/ge/common/dump_manager_unittest.cc
index 50eabc4a..7a242997 100644
--- a/tests/ut/ge/common/dump_manager_unittest.cc
+++ b/tests/ut/ge/common/dump_manager_unittest.cc
@@ -16,6 +16,8 @@
 #include <gtest/gtest.h>
+#define protected public
+#define private public
 #include "common/dump/dump_manager.h"
 #include "common/debug/log.h"
 #include "common/ge_inner_error_codes.h"
@@ -102,4 +104,13 @@ TEST_F(UTEST_dump_manager, is_dump_single_op_close_success) {
   auto dump = DumpManager::GetInstance().GetDumpProperties(0);
   DumpManager::GetInstance().RemoveDumpProperties(0);
 }
+
+ TEST_F(UTEST_dump_manager, not_need_do_dump) {
+   DumpConfig dump_config;
+   dump_config.dump_status = "off";
+   dump_config.dump_debug = "off";
+   DumpProperties dump_properties;
+   bool ret = DumpManager::GetInstance().NeedDoDump(dump_config, dump_properties);
+   EXPECT_EQ(ret, false);
+ }
 } // namespace ge
\ No newline at end of file
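The dump-step strings exercised by the new dump-properties tests below follow a pipe-separated list of single steps and ascending ranges (e.g. "0|3-5|10"), with an upper bound on list length. A standalone sketch with the same accept/reject behavior as the tested cases (an illustration only, not GE's CheckDumpStep; the 100-token cap and 10-digit limit are assumptions):

#include <cctype>
#include <cstddef>
#include <initializer_list>
#include <sstream>
#include <string>

// Accepts "N" or "N-M" tokens joined by '|'; rejects empty or non-numeric
// tokens, descending ranges like "5-3", and lists longer than max_tokens.
static bool DumpStepLooksValid(const std::string &step, std::size_t max_tokens = 100) {
  std::stringstream ss(step);
  std::string tok;
  std::size_t count = 0;
  while (std::getline(ss, tok, '|')) {
    if (++count > max_tokens) return false;
    const std::size_t dash = tok.find('-');
    const std::string lo = tok.substr(0, dash);  // whole token when no dash
    const std::string hi = (dash == std::string::npos) ? lo : tok.substr(dash + 1);
    for (const std::string *part : {&lo, &hi}) {
      if (part->empty() || part->size() > 10) return false;  // avoid stoll overflow
      for (const char c : *part) {
        if (!std::isdigit(static_cast<unsigned char>(c))) return false;
      }
    }
    if (std::stoll(lo) > std::stoll(hi)) return false;  // reject descending ranges
  }
  return count > 0;
}

// e.g. DumpStepLooksValid("0|3-5|10") is true; "0|5-3|10" and "one" are false.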
Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public + +#include "common/dump/dump_properties.h" +#include "ge_local_context.h" +#include "ge/ge_api_types.h" +#include "common/debug/log.h" +#include "common/ge_inner_error_codes.h" + +namespace ge { +class UTEST_dump_properties : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UTEST_dump_properties, check_dump_step) { + DumpProperties dp; + std::string dump_step{"0|3-5|10"}; + std::string unsupport_input1{"0|5-3|10"}; + std::string unsupport_input2{"one"}; + std::string unsupport_input3; + for (int i = 0; i < 200; ++i) { + unsupport_input3 += std::to_string(i) + "|"; + } + unsupport_input3.pop_back(); + Status st = dp.CheckDumpStep(dump_step); + EXPECT_EQ(st, SUCCESS); + st = dp.CheckDumpStep(unsupport_input1); + EXPECT_NE(st, SUCCESS); + st = dp.CheckDumpStep(unsupport_input2); + EXPECT_NE(st, SUCCESS); + st = dp.CheckDumpStep(unsupport_input3); + EXPECT_NE(st, SUCCESS); +} + +TEST_F(UTEST_dump_properties, check_dump_mode) { + DumpProperties dp; + std::string dump_mode_1{"input"}; + std::string dump_mode_2{"output"}; + std::string dump_mode_3{"all"}; + std::string unsupport_input1{"mode1"}; + Status st = dp.CheckDumpMode(dump_mode_1); + EXPECT_EQ(st, SUCCESS); + st = dp.CheckDumpMode(dump_mode_2); + EXPECT_EQ(st, SUCCESS); + st = dp.CheckDumpMode(dump_mode_3); + EXPECT_EQ(st, SUCCESS); + st = dp.CheckDumpMode(unsupport_input1); + EXPECT_NE(st, SUCCESS); +} + +TEST_F(UTEST_dump_properties, check_dump_path) { + DumpProperties dp; + std::string dump_path{"/tmp/"}; + std::string unsupport_input1{" \\unsupported"}; + Status st = dp.CheckDumpPath(dump_path); + EXPECT_EQ(st, SUCCESS); + st = dp.CheckDumpPath(unsupport_input1); + EXPECT_NE(st, SUCCESS); +} + +TEST_F(UTEST_dump_properties, check_enable_dump) { + DumpProperties dp; + std::string enable_dump_t{"1"}; + std::string enable_dump_f{"0"}; + std::string unsupport_input1{"true"}; + std::string unsupport_input2{"false"}; + Status st = dp.CheckEnableDump(enable_dump_t); + EXPECT_EQ(st, SUCCESS); + st = dp.CheckEnableDump(enable_dump_f); + EXPECT_EQ(st, SUCCESS); + st = dp.CheckEnableDump(unsupport_input1); + EXPECT_NE(st, SUCCESS); + st = dp.CheckEnableDump(unsupport_input2); + EXPECT_NE(st, SUCCESS); +} + +TEST_F(UTEST_dump_properties, init_by_options_success_1) { + DumpProperties dp; + std::map options {{OPTION_EXEC_ENABLE_DUMP, "1"}, + {OPTION_EXEC_DUMP_PATH, "/tmp/"}, + {OPTION_EXEC_DUMP_STEP, "0|1-3|10"}, + {OPTION_EXEC_DUMP_MODE, "all"}}; + GetThreadLocalContext().SetGlobalOption(options); + Status st = dp.InitByOptions(); + EXPECT_EQ(st, SUCCESS); +} + +TEST_F(UTEST_dump_properties, init_by_options_success_2) { + DumpProperties dp; + std::map options {{OPTION_EXEC_ENABLE_DUMP_DEBUG, "1"}, + {OPTION_EXEC_DUMP_PATH, "/tmp/"}, + {OPTION_EXEC_DUMP_DEBUG_MODE, "aicore_overflow"}}; + GetThreadLocalContext().SetGlobalOption(options); + Status st = dp.InitByOptions(); + 
EXPECT_EQ(st, SUCCESS); +} + +TEST_F(UTEST_dump_properties, init_by_options_success_3) { + DumpProperties dp; + std::map options {{OPTION_EXEC_ENABLE_DUMP_DEBUG, "1"}, + {OPTION_EXEC_DUMP_PATH, "/tmp/"}}; + GetThreadLocalContext().SetGlobalOption(options); + Status st = dp.InitByOptions(); + EXPECT_EQ(st, SUCCESS); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/common/fp16_unittest.cc b/tests/ut/ge/common/fp16_unittest.cc new file mode 100644 index 00000000..a9590fe2 --- /dev/null +++ b/tests/ut/ge/common/fp16_unittest.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "common/fp16_t.h" + +namespace ge { +namespace formats { +class UtestFP16 : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestFP16, fp16_to_other) { + fp16_t test; + float num = test.ToFloat(); + EXPECT_EQ(num, 0.0); + + double num2 = test.ToDouble(); + EXPECT_EQ(num2, 0); + + int16_t num3 = test.ToInt16(); + EXPECT_EQ(num3, 0); + + int32_t num4 = test.ToInt32(); + EXPECT_EQ(num4, 0); + + int8_t num5 = test.ToInt8(); + EXPECT_EQ(num5, 0); + + uint16_t num6 = test.ToUInt16(); + EXPECT_EQ(num6, 0); + + uint32_t num7 = test.ToUInt16(); + EXPECT_EQ(num7, 0); + + uint8_t num8 = test.ToUInt8(); + EXPECT_EQ(num8, 0); +} +} // namespace formats +} // namespace ge diff --git a/tests/framework/stub_engine/ops_kernel_store/op/host_op.h b/tests/ut/ge/common/tbe_plugin_manager_unittest.cc similarity index 51% rename from tests/framework/stub_engine/ops_kernel_store/op/host_op.h rename to tests/ut/ge/common/tbe_plugin_manager_unittest.cc index 464df47a..16c1650b 100644 --- a/tests/framework/stub_engine/ops_kernel_store/op/host_op.h +++ b/tests/ut/ge/common/tbe_plugin_manager_unittest.cc @@ -1,36 +1,40 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ -#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ - -#include "stub_engine/ops_kernel_store/op/op.h" - -namespace ge { -namespace st { -class GE_FUNC_VISIBILITY HostOp : public Op { - public: - HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} - ~HostOp() override = default; - HostOp &operator=(const HostOp &op) = delete; - HostOp(const HostOp &op) = delete; - - Status Run() override; -}; -} // namespace st -} // namespace ge - -#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public +#include "common/ge/tbe_plugin_manager.h" +#undef private +#undef protected + +namespace ge { +class UtestTBEPluginManager: public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestTBEPluginManager, CheckFindParserSo) { + string path = ""; + vector file_list = {}; + string caffe_parser_path = ""; + TBEPluginManager::Instance().FindParserSo(path, file_list, caffe_parser_path); + path = "/lib64"; + TBEPluginManager::Instance().FindParserSo(path, file_list, caffe_parser_path); +} +} // namespace ge diff --git a/tests/ut/ge/common/util_unittest.cc b/tests/ut/ge/common/util_unittest.cc new file mode 100644 index 00000000..6df3db96 --- /dev/null +++ b/tests/ut/ge/common/util_unittest.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <gtest/gtest.h>
+
+#include "common/util.h"
+
+namespace ge {
+namespace formats {
+class UtestUtilTransfer : public testing::Test {
+ protected:
+  void SetUp() {}
+  void TearDown() {}
+};
+
+INT32 mmAccess2(const CHAR *pathName, INT32 mode)
+{
+  return -1;
+}
+
+TEST_F(UtestUtilTransfer, CheckOutputPathValid) {
+  EXPECT_EQ(CheckOutputPathValid("", ""), false);
+  EXPECT_EQ(CheckOutputPathValid("", "model"), false);
+
+  // Leave the last byte as '\0' so the overlong path stays a valid C string.
+  char max_file_path[14097] = {0};
+  memset(max_file_path, 1, 14096);
+  EXPECT_EQ(CheckOutputPathValid(max_file_path, "model"), false);
+
+  EXPECT_EQ(CheckOutputPathValid("$#%", ""), false);
+
+  // system("touch test_util");
+  // system("chmod 555 test_util");
+  // EXPECT_EQ(CheckOutputPathValid("./test_util", ""), false);
+  // system("rm -r test_util");
+}
+
+TEST_F(UtestUtilTransfer, CheckInputPathValid) {
+  EXPECT_EQ(CheckInputPathValid("", ""), false);
+  EXPECT_EQ(CheckInputPathValid("", "model"), false);
+
+  EXPECT_EQ(CheckInputPathValid("$#%", ""), false);
+
+  EXPECT_EQ(CheckInputPathValid("./test_util", ""), false);
+}
+
+} // namespace formats
+} // namespace ge
diff --git a/tests/ut/ge/ge_opt_info/ge_opt_info_unittest.cc b/tests/ut/ge/ge_opt_info/ge_opt_info_unittest.cc
new file mode 100644
index 00000000..20c123e9
--- /dev/null
+++ b/tests/ut/ge/ge_opt_info/ge_opt_info_unittest.cc
@@ -0,0 +1,82 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include +#include + +#define protected public +#define private public +#include "ge_opt_info/ge_opt_info.h" +#include "graph/ge_local_context.h" +#include "external/ge/ge_api_types.h" +#undef private +#undef protected + +namespace ge { +class UTEST_opt_info : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UTEST_opt_info, get_opt_info_success) { + std::map options = {{ge::SOC_VERSION, "Ascend910"}}; + GetThreadLocalContext().SetGlobalOption(options); + auto ret = GeOptInfo::SetOptInfo(); + EXPECT_EQ(ret, ge::SUCCESS); + std::map graph_options = GetThreadLocalContext().GetAllGraphOptions(); + auto itr = graph_options.find("opt_module.fe"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); + itr = graph_options.find("opt_module.pass"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); + itr = graph_options.find("opt_module.op_tune"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); +} + +TEST_F(UTEST_opt_info, get_opt_info_all) { + std::map global_options = {{ge::SOC_VERSION, "Ascend310"}}; + GetThreadLocalContext().SetGlobalOption(global_options); + auto ret = GeOptInfo::SetOptInfo(); + EXPECT_EQ(ret, ge::SUCCESS); + std::map graph_options = GetThreadLocalContext().GetAllGraphOptions(); + auto itr = graph_options.find("opt_module.fe"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); + itr = graph_options.find("opt_module.pass"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); + itr = graph_options.find("opt_module.op_tune"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); + itr = graph_options.find("opt_module.rl_tune"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); + itr = graph_options.find("opt_module.aoe"); + EXPECT_NE(itr, graph_options.end()); + EXPECT_EQ(itr->second, "all"); +} + +TEST_F(UTEST_opt_info, get_opt_info_failed) { + std::map options; + GetThreadLocalContext().SetGlobalOption(options); + auto ret = GeOptInfo::SetOptInfo(); + EXPECT_EQ(ret, ge::FAILED); +} + +} // namespace ge diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 1bb4430f..b3abb2f9 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -83,12 +83,16 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { graphStatus TestFunc(Operator &op) { return 0; } graphStatus TestFunc1(Operator &op) { return 1; } TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { + ComputeGraphPtr compute_graph = MakeShared("graph_name"); + auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); shared_ptr op_desc = make_shared("add", "add"); + compute_graph->AddNode(op_desc); GeGenerator generator; - EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc, graph), SUCCESS); shared_ptr op_desc1 = make_shared("Add", "Add"); - EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS); + compute_graph->AddNode(op_desc1); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1, graph), SUCCESS); OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); shared_ptr op_desc2 = make_shared("MatMulV2", "MatMulV2"); GeTensorDesc tensor_desc; @@ -99,7 +103,8 @@ TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { 
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); - EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED); + compute_graph->AddNode(op_desc2); + EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2, graph), FAILED); } TEST_F(UtestGeGenerator, test_build_single_op_online) { diff --git a/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc b/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc index 96283250..05141785 100644 --- a/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc +++ b/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc @@ -29,6 +29,7 @@ #include "graph/build/memory/buffer_pool_mem_assigner.h" #include "graph/build/memory/graph_mem_assigner.h" #include "graph/build/stream_allocator.h" +#include "graph/ge_local_context.h" #undef protected #undef private @@ -260,6 +261,10 @@ TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_serial_graph_assign_success) } TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) { + std::string build_mode; + std::map options_map; + options_map.insert({ge::OPTION_GRAPH_RUN_MODE, "1"}); + ge::GetThreadLocalContext().SetGraphOption(options_map); ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency"); ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency(); BufferPoolMemoryPass buffer_pool_mem_pass; diff --git a/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc new file mode 100644 index 00000000..703ac3b4 --- /dev/null +++ b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "omg/omg_inner_types.h" +#include "../passes/graph_builder_utils.h" + +#define protected public +#define private public +#include "graph/build/memory/binary_block_mem_assigner.h" +#include "graph/build/memory/graph_mem_assigner.h" +#include "graph/build/memory/hybrid_mem_assigner.h" +#include "graph/build/memory/max_block_mem_assigner.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/manager/graph_mem_manager.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; +using namespace ge; +using domi::GetContext; + +class UtestGraphMemAssigner : public testing::Test { + public: + ge::ComputeGraphPtr BuildGraphWithVar(int64_t session_id) { + // init + MemManager::Instance().Initialize(std::vector({RT_MEMORY_HBM})); + VarManager::Instance(session_id)->Init(0, 0, 0, 0); + ge::ut::GraphBuilder builder("graph"); + auto var_input = builder.AddNode("var", "Variable", 1, 1); + auto const_input = builder.AddNode("const", "Const", 1, 1); + auto assign = builder.AddNode("assgin", "Assign", 2, 1); + // add link + builder.AddDataEdge(var_input, 0, assign, 0); + builder.AddDataEdge(const_input, 0, assign, 1); + // set offset + var_input->GetOpDesc()->SetOutputOffset({10000}); + const_input->GetOpDesc()->SetOutputOffset({1000}); + assign->GetOpDesc()->SetInputOffset({10100, 1000}); + assign->GetOpDesc()->SetOutputOffset({10100}); + // set inner offset + int64_t inner_offset = 100; + ge::AttrUtils::SetInt(assign->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_INNER_OFFSET, inner_offset); + ge::AttrUtils::SetInt(assign->GetOpDesc()->MutableOutputDesc(0), ATTR_NAME_INNER_OFFSET, inner_offset); + // add var addr + VarManager::Instance(session_id)->var_resource_->var_offset_map_.emplace(10000, RT_MEMORY_HBM); + + return builder.GetGraph(); + } + +protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestGraphMemAssigner, graph_memory_assign_fail_case) { + ge::ComputeGraphPtr compute_graph = make_shared(""); + GraphMemoryAssigner graph_mem_assigner(compute_graph); + MemoryOffset mem_offset(2, 10000); + graph_mem_assigner.memory_offset_.insert({2, mem_offset}); + VarManager::Instance(0)->graph_mem_max_size_ = 0; + + map mem_type_to_offset = {}; + Status ret = graph_mem_assigner.ReAssignMemory(false, mem_type_to_offset); + EXPECT_EQ(ret, ACL_ERROR_GE_MEMORY_ALLOCATION); +} + diff --git a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc index 218bfd0d..352984fa 100644 --- a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc @@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { /// B --> C(AllReduce) --- D /// / /// stream id: 0 A - /// \ + /// \. /// E --> F(AllReduce) --- G /// stream id: 2 2 2 /// @@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { /// case of multi-output, then unuse stream /// sub1 -/// / | \ +/// / | \. 
/// sub2 sub3 sub4 TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { SubGraphInfoPtr data = CreateDataSubgraph(); @@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { /// if paralle id 1, then use stream /// sub1 -/// / | | \ +/// / | | \. /// sub2 sub3 sub4 sub5 TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { SubGraphInfoPtr data = CreateDataSubgraph(); @@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { /// if the param of engine independent is true, then set independent stream /// sub1 -/// / | | \ +/// / | | \. /// sub2 sub3 sub4 sub5 TEST_F(UtestLogicalStreamAllocator, test_independent) { SubGraphInfoPtr data = CreateDataSubgraph(); @@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { /// set stream based on stream label, and then based on independent /// sub1 -/// / | | \ +/// / | | \. /// sub2 sub3 sub4 sub5 TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { SubGraphInfoPtr data = CreateDataSubgraph(); diff --git a/tests/ut/ge/graph/build/model_builder_unittest.cc b/tests/ut/ge/graph/build/model_builder_unittest.cc index d544e1a3..4f061e27 100644 --- a/tests/ut/ge/graph/build/model_builder_unittest.cc +++ b/tests/ut/ge/graph/build/model_builder_unittest.cc @@ -17,7 +17,7 @@ #include #include -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/anchor.h" #include "graph/attr_value.h" #include "graph/debug/ge_attr_define.h" diff --git a/tests/ut/ge/graph/build/stream_allocator_unittest.cc b/tests/ut/ge/graph/build/stream_allocator_unittest.cc index 019e75d1..4ae871af 100644 --- a/tests/ut/ge/graph/build/stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/stream_allocator_unittest.cc @@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { /// /// A - /// / \ + /// / \. 
/// B C /// | | /// D 400 diff --git a/tests/ut/ge/graph/build/task_generator_unittest.cc b/tests/ut/ge/graph/build/task_generator_unittest.cc index f869f1e0..7be20fa1 100644 --- a/tests/ut/ge/graph/build/task_generator_unittest.cc +++ b/tests/ut/ge/graph/build/task_generator_unittest.cc @@ -29,6 +29,8 @@ #define protected public #define private public +#include "init/gelib.h" +#include "ge/opskernel_manager/ops_kernel_builder_manager.h" #include "graph/build/task_generator.h" #include "graph/manager/graph_mem_manager.h" #include "graph/manager/graph_var_manager.h" @@ -41,9 +43,46 @@ using namespace ge; namespace { const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR"; -} +const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl"; +} // namespace class UtestTaskGeneratorTest : public testing::Test { public: + struct FakeOpsKernelBuilder : OpsKernelBuilder { + FakeOpsKernelBuilder(){}; + + private: + Status Initialize(const map &options) override { + return SUCCESS; + }; + Status Finalize() override { + return SUCCESS; + }; + Status CalcOpRunningParam(Node &node) override { + return SUCCESS; + }; + Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks) override { + domi::TaskDef task_def; + tasks.push_back(task_def); + return SUCCESS; + }; + }; + + struct FakeOpsKernelInfoStore : OpsKernelInfoStore { + FakeOpsKernelInfoStore() = default; + + private: + Status Initialize(const std::map &options) override { + return SUCCESS; + }; + Status Finalize() override { + return SUCCESS; + }; + bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override { + return true; + }; + void GetAllOpsKernelInfo(std::map &infos) const override{}; + }; + ge::ComputeGraphPtr BuildGraphFpProfiling() { ge::ut::GraphBuilder builder("graph"); auto data = builder.AddNode("data", "phony", 1, 1); @@ -95,6 +134,14 @@ class UtestTaskGeneratorTest : public testing::Test { return builder.GetGraph(); } + ge::ComputeGraphPtr BuildHcclGraph() { + ge::ut::GraphBuilder builder("graph"); + auto hccl_node = builder.AddNode("hccl_phony_node", "HCCL_PHONY", 0, 0); + auto op_desc = hccl_node->GetOpDesc(); + op_desc->SetOpKernelLibName(kKernelInfoNameHccl); + op_desc->SetStreamId(0); + return builder.GetGraph(); + } protected: void SetUp() {} @@ -116,7 +163,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) { TaskGenerator task_generator(nullptr, 0); auto net_output = graph->FindNode("Node_Output"); // netoutput has no data input, return default value 0 - EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0); + uint32_t bp_index = 0; + EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0); + EXPECT_EQ(bp_index, 2); } TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { @@ -154,3 +203,31 @@ TEST_F(UtestTaskGeneratorTest, AutoFindBpOpIndex) { output_desc->SetName("hcom"); EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS); } + +TEST_F(UtestTaskGeneratorTest, GenerateTask) { + map options; + Status ret = ge::GELib::Initialize(options); + EXPECT_EQ(ret, SUCCESS); + + shared_ptr instance_ptr = ge::GELib::GetInstance(); + EXPECT_NE(instance_ptr, nullptr); + + OpsKernelInfoStorePtr ops_kernel_info_store_ptr = MakeShared(); + instance_ptr->opsManager_.ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr)); + + OpsKernelBuilderManager &builder_manager_instance_ptr = ge::OpsKernelBuilderManager::Instance(); + OpsKernelBuilderPtr 
fake_builder = MakeShared(); + builder_manager_instance_ptr.ops_kernel_builders_[kKernelInfoNameHccl] = fake_builder; + + auto graph = BuildHcclGraph(); + TaskGenerator task_generator(nullptr, 0); + RunContext run_context; + run_context.graphStreamList.push_back(static_cast(ops_kernel_info_store_ptr.get())); + vector all_reduce_nodes; + vector task_def_list; + map op_name_map; + + EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS); + EXPECT_EQ(task_def_list.size(), 1); + EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast(ops_kernel_info_store_ptr.get())); +} \ No newline at end of file diff --git a/tests/ut/ge/graph/execute/graph_execute_unittest.cc b/tests/ut/ge/graph/execute/graph_execute_unittest.cc index 6d982454..3e32405b 100644 --- a/tests/ut/ge/graph/execute/graph_execute_unittest.cc +++ b/tests/ut/ge/graph/execute/graph_execute_unittest.cc @@ -17,6 +17,8 @@ #include #include +#include "common/profiling/profiling_manager.h" + #define protected public #define private public #include "graph/execute/graph_execute.h" @@ -125,4 +127,46 @@ TEST_F(UtestGraphExecuteTest, test_set_callback) { auto status = executor.SetCallback(1, ge_root_model, callback); EXPECT_EQ(status, SUCCESS); } + +TEST_F(UtestGraphExecuteTest, test_without_subscribe) { + GraphExecutor executor; + auto ret = executor.ModelSubscribe(1); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestGraphExecuteTest, test_with_subscribe_failed1) { + GraphExecutor executor; + uint32_t graph_id = 1; + auto &profiling_manager = ProfilingManager::Instance(); + profiling_manager.SetSubscribeInfo(0, 1, true); + auto ret = executor.ModelSubscribe(graph_id); + profiling_manager.CleanSubscribeInfo(); + EXPECT_NE(ret, SUCCESS); +} + +TEST_F(UtestGraphExecuteTest, test_with_subscribe_failed2) { + GraphExecutor executor; + uint32_t graph_id = 1; + uint32_t model_id = 1; + auto &profiling_manager = ProfilingManager::Instance(); + profiling_manager.SetSubscribeInfo(0, 1, true); + profiling_manager.SetGraphIdToModelMap(2, model_id); + auto ret = executor.ModelSubscribe(graph_id); + profiling_manager.CleanSubscribeInfo(); + EXPECT_NE(ret, SUCCESS); +} + +TEST_F(UtestGraphExecuteTest, test_with_subscribe_success) { + GraphExecutor executor; + uint32_t graph_id = 1; + uint32_t model_id = 1; + GraphNodePtr graph_node = std::make_shared(graph_id); + DavinciModel model(model_id, nullptr); + auto &profiling_manager = ProfilingManager::Instance(); + profiling_manager.SetSubscribeInfo(0, 1, true); + profiling_manager.SetGraphIdToModelMap(graph_id, model_id); + auto ret = executor.ModelSubscribe(graph_id); + profiling_manager.CleanSubscribeInfo(); + EXPECT_EQ(ret, SUCCESS); +} } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/graph/execute/model_executor_unittest.cc b/tests/ut/ge/graph/execute/model_executor_unittest.cc new file mode 100644 index 00000000..cd907e99 --- /dev/null +++ b/tests/ut/ge/graph/execute/model_executor_unittest.cc @@ -0,0 +1,328 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public +#include "graph/execute/model_executor.h" +#include "graph/manager/graph_manager.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/load/model_manager/model_manager.h" +#include "graph/load/model_manager/davinci_model.h" + +using namespace std; + +namespace ge { +class UtestModelExecutorTest : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { + OpDescPtr op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; i++) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(index * 64 + i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; i++) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); + + op_desc->SetWorkspace({}); + op_desc->SetWorkspaceBytes({}); + op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); + + return graph.AddNode(op_desc); +} + +TEST_F(UtestModelExecutorTest, test_load_graph_sync) { + ModelExecutor model_executor; + EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS); + + auto compute_graph = MakeShared("test_graph"); + GeRootModelPtr ge_root_model = MakeShared(compute_graph); + + GeModelPtr ge_model = MakeShared(); + ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(compute_graph)); + ge_root_model->SetSubgraphInstanceNameToModel(compute_graph->GetName(), ge_model); + + GraphId graph_id = 1; + GraphNodePtr graph_node = MakeShared(graph_id); + graph_node->SetGeRootModel(ge_root_model); + graph_node->SetLoadFlag(true); + graph_node->SetAsync(false); + + EXPECT_EQ(model_executor.LoadGraph(ge_root_model, graph_node), SUCCESS); + EXPECT_EQ(model_executor.UnloadGraph(ge_root_model, graph_id), SUCCESS); + + EXPECT_EQ(model_executor.Finalize(), SUCCESS); +} + +TEST_F(UtestModelExecutorTest, test_load_graph_async) { + ModelExecutor model_executor; + EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS); + + Graph graph("test_graph"); + auto compute_graph = MakeShared("test_graph"); + GeRootModelPtr ge_root_model = MakeShared(compute_graph); + + GeModelPtr ge_model = MakeShared(); + ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(compute_graph)); + ge_root_model->SetSubgraphInstanceNameToModel(compute_graph->GetName(), ge_model); + + GraphId graph_id = 1; + GraphNodePtr graph_node = MakeShared(graph_id); + graph_node->SetGeRootModel(ge_root_model); + graph_node->SetLoadFlag(true); + graph_node->SetAsync(true); + + EXPECT_EQ(model_executor.LoadGraph(ge_root_model, graph_node), SUCCESS); + + EXPECT_EQ(model_executor.UnloadGraph(ge_root_model, graph_id), SUCCESS); + + EXPECT_EQ(model_executor.Finalize(), SUCCESS); +} + +TEST_F(UtestModelExecutorTest, test_load_graph_failed) { + ModelExecutor model_executor; + EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS); + + Graph graph("test_graph"); + auto compute_graph = MakeShared("test_graph"); + GeRootModelPtr ge_root_model = MakeShared(compute_graph); + + GraphId graph_id = 1; + GraphNodePtr graph_node = 
MakeShared<GraphNode>(graph_id);
+  graph_node->SetGeRootModel(ge_root_model);
+  graph_node->SetLoadFlag(true);
+  graph_node->SetAsync(true);
+
+  // GeModel is null, DavinciModel::Assign will return FAILED
+  setenv(kEnvGeuseStaticMemory, "1", true);
+  EXPECT_EQ(model_executor.LoadGraph(ge_root_model, graph_node), FAILED);
+  EXPECT_EQ(model_executor.UnloadGraph(ge_root_model, graph_id), SUCCESS);
+
+  EXPECT_EQ(model_executor.Finalize(), SUCCESS);
+  unsetenv(kEnvGeuseStaticMemory);
+}
+
+TEST_F(UtestModelExecutorTest, test_check_and_release_memory) {
+  {
+    auto listener = MakeShared<RunAsyncListener>();
+    shared_ptr<DavinciModel> davinci_model1 = MakeShared<DavinciModel>(1, listener);
+    davinci_model1->SetId(1);
+    ModelManager::GetInstance()->InsertModel(1, davinci_model1);
+    shared_ptr<DavinciModel> davinci_model2 = MakeShared<DavinciModel>(2, listener);
+    davinci_model2->SetId(2);
+    ModelManager::GetInstance()->InsertModel(2, davinci_model2);
+  }
+
+  ModelExecutor model_executor;
+  EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);
+
+  GeModelPtr ge_model = make_shared<GeModel>();
+  int64_t memory_size = 25 * 1024UL * 1024UL * 1024UL;
+  int64_t weight_size = 25 * 1024UL * 1024UL * 1024UL;
+  uint64_t session_id = 0;
+  EXPECT_TRUE(AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, memory_size));
+  EXPECT_TRUE(AttrUtils::SetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, weight_size));
+  EXPECT_TRUE(AttrUtils::SetInt(ge_model, MODEL_ATTR_SESSION_ID, session_id));
+
+  GraphId graph_id = 1;
+  GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id);
+  model_executor.AddGraphNode(graph_id, graph_node);
+
+  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
+  GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
+  ge_root_model->SetModelId(1);
+  ge_root_model->SetModelId(2);
+  graph_node->SetGeRootModel(ge_root_model);
+  graph_node->SetLoadFlag(true);
+
+  EXPECT_EQ(model_executor.CheckAndReleaseMemory(ge_model, graph_node), SUCCESS);
+  EXPECT_EQ(model_executor.Finalize(), SUCCESS);
+}
+
+TEST_F(UtestModelExecutorTest, parse_inputs_dims_data) {
+  ModelExecutor model_executor;
+  EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);
+
+  OmeContext context;
+  SetLocalOmeContext(context);
+  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
+  const auto data1 = CreateNode(*compute_graph, "data1", DATA, 1, 1);
+  const auto next1 = CreateNode(*compute_graph, "data1", GETNEXT, 1, 1);
+
+  Tensor tensor;
+  std::vector<Tensor> input_tensors;
+  input_tensors.emplace_back(tensor);
+  EXPECT_EQ(model_executor.ParseInputsDims(input_tensors), SUCCESS);  // dynamic_node_type is empty, just return
+
+  context.dynamic_node_type = DATA;
+  EXPECT_EQ(model_executor.ParseInputsDims(input_tensors), SUCCESS);  // ParseInputsDimsForData
+
+  context.getnext_nosink_nodes.emplace_back(next1);
+  EXPECT_EQ(model_executor.ParseInputsDims(input_tensors), SUCCESS);  // ParseInputsDimsForGetNexNosinkAndData
+
+  EXPECT_EQ(model_executor.Finalize(), SUCCESS);
+}
+
+TEST_F(UtestModelExecutorTest, parse_inputs_dims_getnext) {
+  ModelExecutor model_executor;
+  EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);
+
+  OmeContext context;
+  SetLocalOmeContext(context);
+  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
+  const auto data1 = CreateNode(*compute_graph, "data1", DATA, 1, 1);
+  const auto next1 = CreateNode(*compute_graph, "data1", GETNEXT, 1, 1);
+
+  Tensor tensor;
+  std::vector<Tensor> input_tensors;
+  input_tensors.emplace_back(tensor);
+
+  context.dynamic_node_type = GETNEXT;
+  EXPECT_EQ(model_executor.ParseInputsDims(input_tensors), SUCCESS);  // just getnext_sink
+
+  context.getnext_nosink_nodes.emplace_back(next1);
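+  // Editor's note (illustrative, not part of the original change): ParseInputsDims
+  // dispatches on the OmeContext fields set above -- an empty dynamic_node_type
+  // returns early, DATA/GETNEXT select the Data or GetNext parsing path, and a
+  // non-empty getnext_nosink_nodes list routes into the nosink variant checked next.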
+  EXPECT_EQ(model_executor.ParseInputsDims(input_tensors), SUCCESS);  // ParseInputsDimsForData
+
+  context.data_nodes.emplace_back(data1);
+  EXPECT_EQ(model_executor.ParseInputsDims(input_tensors), PARAM_INVALID);  // ParseInputsDimsForGetNexNosinkAndData
+  AttrUtils::SetInt(next1->GetOpDesc(), ATTR_NAME_INDEX, 0);
+  EXPECT_EQ(model_executor.ParseInputsDims(input_tensors), SUCCESS);  // ParseInputsDimsForGetNexNosinkAndData
+
+  EXPECT_EQ(model_executor.Finalize(), SUCCESS);
+}
+
+TEST_F(UtestModelExecutorTest, test_run_thread) {
+  ModelExecutor model_executor;
+  EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS);
+
+  GraphId graph_id = 1;
+  uint64_t session_id = 0;
+  error_message::Context error_context;
+  GEThreadLocalContext context;
+  const auto callback = [](Status status, std::vector<ge::Tensor> &outputs) { };
+
+  auto compute_graph = MakeShared<ComputeGraph>("test_graph");
+  GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
+
+  GeModelPtr ge_model = MakeShared<GeModel>();
+  ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(compute_graph));
+  ge_root_model->SetSubgraphInstanceNameToModel(compute_graph->GetName(), ge_model);
+
+  GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id);
+  graph_node->SetGeRootModel(ge_root_model);
+  graph_node->SetLoadFlag(false);
+  graph_node->SetAsync(true);
+  graph_node->IncreaseLoadCount();
+  graph_node->Lock();
+
+  Tensor tensor;
+  std::vector<Tensor> input_tensors;
+  input_tensors.emplace_back(tensor);
+
+  RunArgs run_args{graph_node, graph_id, session_id, error_context, input_tensors, ge_root_model, context, callback};
+  EXPECT_EQ(model_executor.PushGraph(run_args), SUCCESS);
+
+  while (model_executor.run_args_q_.Size() > 0) {
+    usleep(10);  // 0.01ms, wait for RunThread to drain the queue.
+  }
+  EXPECT_EQ(model_executor.Finalize(), SUCCESS);
+}
+
+static void test_run_graph(ModelExecutor &model_executor) {
+  auto compute_graph = MakeShared<ComputeGraph>("test_graph");
+  GeRootModelPtr ge_root_model = MakeShared<GeRootModel>(compute_graph);
+
+  GeModelPtr ge_model = MakeShared<GeModel>();
+  ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(compute_graph));
+  ge_root_model->SetSubgraphInstanceNameToModel(compute_graph->GetName(), ge_model);
+
+  GraphId graph_id = 1;
+  GraphNodePtr graph_node = MakeShared<GraphNode>(graph_id);
+  graph_node->SetGeRootModel(ge_root_model);
+  graph_node->SetLoadFlag(false);
+  graph_node->SetAsync(false);  // RunGraph is synchronous.
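+  // Editor's note (illustrative, not part of the original change): RunGraph is the
+  // synchronous entry point, so this helper clears the async flag before LoadGraph;
+  // the asynchronous path is exercised separately through PushGraph/RunThread above.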
+ EXPECT_EQ(model_executor.LoadGraph(ge_root_model, graph_node), SUCCESS); + + std::vector inputs; + std::vector outputs; + EXPECT_EQ(model_executor.RunGraph(graph_node, graph_id, inputs, outputs), SUCCESS); +} + +TEST_F(UtestModelExecutorTest, test_run_graph_train) { + GetThreadLocalContext().SetGlobalOption({{OPTION_GRAPH_RUN_MODE, "1"}}); + ModelExecutor model_executor; + EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS); + test_run_graph(model_executor); + EXPECT_EQ(model_executor.Finalize(), SUCCESS); +} + +TEST_F(UtestModelExecutorTest, test_run_graph_infer) { + GetThreadLocalContext().SetGlobalOption({}); + GetThreadLocalContext().SetSessionOption({}); + GetThreadLocalContext().SetGraphOption({}); + ModelExecutor model_executor; + EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS); + test_run_graph(model_executor); + EXPECT_EQ(model_executor.Finalize(), SUCCESS); +} + +TEST_F(UtestModelExecutorTest, test_run_graph_with_stream) { + ModelExecutor model_executor; + EXPECT_EQ(model_executor.Initialize({}, 0), SUCCESS); + + GraphId graph_id = 1; + auto compute_graph = MakeShared("test_graph"); + GeRootModelPtr ge_root_model = MakeShared(compute_graph); + + GeModelPtr ge_model = MakeShared(); + ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(compute_graph)); + ge_root_model->SetSubgraphInstanceNameToModel(compute_graph->GetName(), ge_model); + + GraphNodePtr graph_node = MakeShared(graph_id); + graph_node->SetGeRootModel(ge_root_model); + graph_node->SetLoadFlag(false); + graph_node->SetAsync(true); + + GeTensor tensor; + std::vector inputs{tensor}; + std::vector outputs; + + rtStream_t stream = nullptr; + rtStreamCreate(&stream, 0); + EXPECT_EQ(model_executor.RunGraphWithStream(graph_node, graph_id, stream, inputs, outputs), 145003); + + EXPECT_EQ(model_executor.Finalize(), SUCCESS); + rtStreamDestroy(stream); +} +} // namespace ge diff --git a/tests/ut/ge/graph/graph_load_unittest.cc b/tests/ut/ge/graph/graph_load_unittest.cc deleted file mode 100644 index cbcefd03..00000000 --- a/tests/ut/ge/graph/graph_load_unittest.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include - -#include "common/debug/log.h" -#include "common/helper/model_helper.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "graph/op_desc.h" -#include "graph/types.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" - -#define protected public -#define private public -#include "graph/load/graph_loader.h" - -#include "framework/common/ge_inner_error_codes.h" -#include "graph/load/model_manager/model_manager.h" -#include "graph/manager/graph_manager_utils.h" -#include "model/ge_model.h" -#undef private -#undef protected - -using namespace testing; -namespace ge { - -class UtestGraphGraphLoad : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestGraphGraphLoad, load_graph_param_invalid1) { - std::shared_ptr graph_run_listener = nullptr; - SubGraphInfo sub_graph1; - ge::SubGraphInfoPtr sub_graph_ptr1 = std::make_shared(sub_graph1); - ModelIdInfo model_Id_info; - model_Id_info.model_id = 1; - - GeModelPtr ge_model_ptr = std::make_shared(); - sub_graph_ptr1->SetGeModelPtr(ge_model_ptr); - - std::vector input_flag; - input_flag.push_back(false); - sub_graph_ptr1->SetInputFlag(input_flag); - - ge::GraphLoader graph_load; - EXPECT_EQ(GE_GRAPH_PARAM_NULLPTR, graph_load.LoadGraph(sub_graph_ptr1->ge_model_ptr_, graph_run_listener, model_Id_info)); - sub_graph_ptr1->SetModelIdInfo(model_Id_info); -} - -TEST_F(UtestGraphGraphLoad, load_graph_param_invalid2) { - std::mutex sync_run_mutex; - std::condition_variable condition; - std::shared_ptr listener = std::make_shared(sync_run_mutex, condition); - - SubGraphInfo sub_graph1; - ge::SubGraphInfoPtr sub_graph_ptr1 = std::make_shared(sub_graph1); - ModelIdInfo model_Id_info; - model_Id_info.model_id = 1; - - GeModelPtr ge_model_ptr = std::make_shared(); - sub_graph_ptr1->SetGeModelPtr(ge_model_ptr); - - std::vector input_flag; - input_flag.push_back(false); - sub_graph_ptr1->SetInputFlag(input_flag); - - ge::GraphLoader graph_load; - EXPECT_EQ(GE_GRAPH_PARAM_NULLPTR, graph_load.LoadGraph(sub_graph_ptr1->ge_model_ptr_, listener, model_Id_info)); - sub_graph_ptr1->SetModelIdInfo(model_Id_info); -} -} // namespace ge diff --git a/tests/ut/ge/graph/load/davinci_model_unittest.cc b/tests/ut/ge/graph/load/davinci_model_unittest.cc index 3f9cc850..62204f6c 100644 --- a/tests/ut/ge/graph/load/davinci_model_unittest.cc +++ b/tests/ut/ge/graph/load/davinci_model_unittest.cc @@ -1035,6 +1035,9 @@ TEST_F(UtestDavinciModel, NnExecute) { ProfilingManager::Instance().device_id_.emplace_back(0); model.task_list_.resize(1); EXPECT_EQ(model.NnExecute(stream, false, input_data, output_data), SUCCESS); + + input_data.blobs[0].length = 128; + EXPECT_NE(model.NnExecute(stream, false, input_data, output_data), SUCCESS); } TEST_F(UtestDavinciModel, update_io_addr_success) { DavinciModel model(0, nullptr); @@ -1059,4 +1062,144 @@ TEST_F(UtestDavinciModel, get_total_memsize_exclude_zero_copy) { EXPECT_EQ(model.GetTotalMemSizeExcludeZeroCopy(total_useful_size), SUCCESS); EXPECT_EQ(total_useful_size, 512); } + +// test InitTbeHandle +TEST_F(UtestDavinciModel, init_tbe_handle) { + DavinciModel model(0, nullptr); + OpDescPtr op_desc = CreateOpDesc("data", DATA); + model.ge_model_ = make_shared(); + // without kernel + EXPECT_EQ(model.InitTbeHandle(op_desc), INTERNAL_ERROR); + vector buffer; + string key = op_desc->GetName(); + TBEKernelPtr tbe_kernel_ptr = std::make_shared(key, std::move(buffer)); + 
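+  // Editor's note (illustrative, not part of the original change): InitTbeHandle
+  // reads both the TBE kernel ext-attr and a "<op_name>_kernelname" string attr,
+  // which is why the test attaches the kernel here and registers the name below.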
op_desc->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); + string attr_kernel_name = op_desc->GetName() + "_kernelname"; + string kernel_name = "kernel_name"; + AttrUtils::SetStr(op_desc, attr_kernel_name, kernel_name); + EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); + // rtQueryFunctionRegistered(bin_file_key) failed + EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); +} + +// test InitTbeHandleWithFfts +TEST_F(UtestDavinciModel, init_tbe_handle_with_ffts) { + DavinciModel model(0, nullptr); + OpDescPtr op_desc = CreateOpDesc("data", DATA); + model.ge_model_ = make_shared(); + // without tbe_kernel + EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), INTERNAL_ERROR); + + std::vector tbe_kernel; + vector buffer; + string key = op_desc->GetName(); + OpKernelBinPtr tbe_kernel_ptr0 = std::make_shared(key, std::move(buffer)); + OpKernelBinPtr tbe_kernel_ptr1 = std::make_shared(key, std::move(buffer)); + tbe_kernel.push_back(tbe_kernel_ptr0); + tbe_kernel.push_back(tbe_kernel_ptr1); + op_desc->SetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel); + // without _register_stub_func + EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), INTERNAL_ERROR); + + vector bin_file_keys; + bin_file_keys.emplace_back(op_desc->GetName() + "_0"); + bin_file_keys.emplace_back(op_desc->GetName() + "_1"); + AttrUtils::SetListStr(op_desc, "_register_stub_func", bin_file_keys); + + EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), SUCCESS); + // rtQueryFunctionRegistered(bin_file_key) failed + EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); +} + +// test InitBinaryMagic +TEST_F(UtestDavinciModel, init_binary_magic) { + DavinciModel model(0, nullptr); + rtDevBinary_t binary; + OpDescPtr op_desc = CreateOpDesc("data", DATA); + bool is_ffts = true; + vector json_list; + AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); + // without tvm_magic + EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), INTERNAL_ERROR); + json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AICPU"); + json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF"); + op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); + AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); + EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); + EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AICPU); + EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 1, binary), SUCCESS); + EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF); + + json_list.clear(); + json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AIVEC"); + json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AICUBE"); + op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); + AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); + EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); + EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AIVEC); + EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 1, binary), SUCCESS); + EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AICUBE); + + // with invalid json type + json_list.clear(); + json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_INVALID"); + json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_INVALID"); + op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); + AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); + EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), PARAM_INVALID); + + // test unffts + is_ffts = false; + string json_string = "RT_DEV_BINARY_MAGIC_ELF_AIVEC"; + AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string); + EXPECT_EQ(model.InitBinaryMagic(op_desc, 
is_ffts, 0, binary), SUCCESS); + EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AIVEC); +} + +// test InitMetaData +TEST_F(UtestDavinciModel, init_meta_data) { + DavinciModel model(0, nullptr); + void *bin_handle; + OpDescPtr op_desc = CreateOpDesc("data", DATA); + bool is_ffts = true; + vector meta_data_list; + // with empty meta_data + EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), INTERNAL_ERROR); + meta_data_list.emplace_back("meta_data_0"); + meta_data_list.emplace_back("meta_data_1"); + AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_METADATA, meta_data_list); + EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), SUCCESS); + + is_ffts = false; + string meta_data = "meta_data"; + AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data); + EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), SUCCESS); +} + +// test InitKernelName +TEST_F(UtestDavinciModel, init_kernel_name) { + DavinciModel model(0, nullptr); + string kernel_name; + OpDescPtr op_desc = CreateOpDesc("data", DATA); + bool is_ffts = true; + // failed when name is invalid + EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), INTERNAL_ERROR); + OpDescPtr op_desc1 = CreateOpDesc("sgt_graph_nodes/loss_scale", SCALE); + string attr_kernel_name = "loss_scale_thread_kernelname"; + vector kernel_name_list; + AttrUtils::SetListStr(op_desc, attr_kernel_name, kernel_name_list); + // failed without kernel_name + EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), INTERNAL_ERROR); + kernel_name_list.emplace_back("kernel_name_0"); + kernel_name_list.emplace_back("kernel_name_1"); + AttrUtils::SetListStr(op_desc1, attr_kernel_name, kernel_name_list); + EXPECT_EQ(model.InitKernelName(op_desc1, is_ffts, 0, kernel_name), SUCCESS); + + // without ffts + is_ffts = false; + attr_kernel_name = "data_kernelname"; + kernel_name = "kernel_name"; + AttrUtils::SetStr(op_desc, attr_kernel_name, kernel_name); + EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), SUCCESS); +} } // namespace ge diff --git a/tests/ut/ge/graph/load/ffts_task_info_unittest.cc b/tests/ut/ge/graph/load/ffts_task_info_unittest.cc new file mode 100644 index 00000000..25838f7e --- /dev/null +++ b/tests/ut/ge/graph/load/ffts_task_info_unittest.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define private public +#define protected public + +#include "graph/load/model_manager/task_info/ffts_task_info.h" +#include "cce/aicpu_engine_struct.h" +#include "common/ge/ge_util.h" +#include "common/properties_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/fmk_error_codes.h" +#include "graph/attr_value.h" +#include "graph/load/model_manager/davinci_model.h" +#include "graph/load/model_manager/model_manager.h" +#include "runtime/rt_ffts.h" + +namespace ge { +extern OpDescPtr CreateOpDesc(string name, string type); + +class UtestFftsTaskInfo : public testing::Test { +protected: + void SetUp() {} + + void TearDown() {} + +public: + void CreateFftsTaskInfo(DavinciModel &davinci_model, domi::TaskDef &task_def, FftsTaskInfo &ffts_task_info) { + rtStream_t stream = nullptr; + rtStreamCreate(&stream, 0); + davinci_model.stream_list_ = { stream }; + task_def.set_stream_id(0); + + domi::FftsTaskDef *ffts_task_def = task_def.mutable_ffts_task(); + davinci_model.op_list_[0] = CreateOpDesc("test", PARTITIONEDCALL); + ffts_task_def->set_op_index(0); + ffts_task_def->set_addr_size(2); + domi::FftsDescInfoDef *ffts_desc = ffts_task_def->mutable_ffts_desc(); + ffts_desc->set_tm(0); + rtFftsTaskInfo_t sub_task_info; + ffts_task_info.sub_task_info_ = sub_task_info; + ffts_task_def->set_ffts_type(RT_FFTS_TYPE_AUTO_THREAD); + } +}; + +// test FftsTaskInfo Init with no subtask and no ticket cache +TEST_F(UtestFftsTaskInfo, success_ffts_task_info_without_subtask) { + DavinciModel davinci_model(0, nullptr); + rtStream_t stream = nullptr; + rtStreamCreate(&stream, 0); + davinci_model.stream_list_ = { stream }; + domi::TaskDef task_def; + task_def.set_stream_id(0); + + domi::FftsTaskDef *ffts_task_def = task_def.mutable_ffts_task(); + FftsTaskInfo ffts_task_info; + // init failed when model without op_desc + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), PARAM_INVALID); + + davinci_model.op_list_[0] = CreateOpDesc("test", PARTITIONEDCALL); + ffts_task_def->set_op_index(0); + ffts_task_def->set_addr_size(2); + domi::FftsDescInfoDef *ffts_desc = ffts_task_def->mutable_ffts_desc(); + ffts_desc->set_tm(0); + rtFftsTaskInfo_t sub_task_info; + ffts_task_info.sub_task_info_ = sub_task_info; + ffts_task_def->set_ffts_type(RT_FFTS_TYPE_AUTO_THREAD); + ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); +} + +// test FftsTaskInfo Init with subtask and no ticket cache: AutoThreadAicAivDef +TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_auto_thread_subgraph) { + DavinciModel davinci_model(0, nullptr); + domi::TaskDef task_def; + FftsTaskInfo ffts_task_info; + CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); + domi::FftsSubTaskDef *ffts_sub_task_def = task_def.mutable_ffts_task()->add_sub_task(); + ffts_sub_task_def->set_thread_dim(static_cast(1)); + //sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv() + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), FAILED); + + domi::AutoThreadAicAivDef *auto_thread_aic_aiv_def = ffts_sub_task_def->mutable_auto_thread_aic_aiv(); + domi::AutoThreadPrefetchDef *src_prefetch = auto_thread_aic_aiv_def->add_src_prefetch(); + // without InitIoAddrs + ffts_task_info.thread_dim_ = 0; + RuntimeParam runtime_param; + ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); +} + +// test FftsTaskInfo Init with 
subtask and no ticket cache: ManualThreadAicAivDef +TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_manual_thread_subgraph) { + DavinciModel davinci_model(0, nullptr); + domi::TaskDef task_def; + FftsTaskInfo ffts_task_info; + CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); + domi::FftsSubTaskDef *ffts_sub_task_def = task_def.mutable_ffts_task()->add_sub_task(); + ffts_sub_task_def->set_thread_dim(static_cast(1)); + //sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv() + + domi::ManualThreadAicAivDef *manual_thread_aic_aiv_def = ffts_sub_task_def->mutable_manual_thread_aic_aiv(); + manual_thread_aic_aiv_def->add_thread_prefetch_dmu_idx(static_cast(0)); + manual_thread_aic_aiv_def->add_thread_blk_dim(static_cast(0)); + manual_thread_aic_aiv_def->add_thread_task_func_stub("ffts"); + domi::ManualThreadDmuDef *prefetch_list = manual_thread_aic_aiv_def->add_prefetch_list(); + prefetch_list->set_data_addr(static_cast(0)); + // without InitIoAddrs + ffts_task_info.thread_dim_ = 0; + RuntimeParam runtime_param; + ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); +} + +// test FftsTaskInfo Init with subtask and no ticket cache: ManualThreadNopDef +TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_manual_thread_nop_subgraph) { + DavinciModel davinci_model(0, nullptr); + domi::TaskDef task_def; + FftsTaskInfo ffts_task_info; + CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); + + domi::FftsSubTaskDef *ffts_sub_task_def = task_def.mutable_ffts_task()->add_sub_task(); + ffts_sub_task_def->set_thread_dim(static_cast(1)); + domi::AutoThreadAicAivDef *auto_thread_aic_aiv_def = ffts_sub_task_def->mutable_auto_thread_aic_aiv(); + domi::ManualThreadNopDef *manual_thread_nop = ffts_sub_task_def->mutable_manual_thread_nop(); + domi::ManualThreadDependencyDef *src_dep_tbl = manual_thread_nop->add_src_dep_tbl(); + src_dep_tbl->add_dependency(static_cast(0)); + + // without InitIoAddrs + ffts_task_info.thread_dim_ = 0; + RuntimeParam runtime_param; + ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); +} + +// test FftsTaskInfo Init with no subtask and ticket cache:AutoThreadCacheDef +TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_auto_thread_ticket_cache) { + DavinciModel davinci_model(0, nullptr); + domi::TaskDef task_def; + FftsTaskInfo ffts_task_info; + CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); + + domi::TicketCacheDef *ticket_cache_def = task_def.mutable_ffts_task()->add_ticket_cache(); + //ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache() + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), FAILED); + domi::AutoThreadCacheDef *auto_thread_cache = ticket_cache_def->mutable_auto_thread_cache(); + + ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; + EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); +} + +// test FftsTaskInfo Init with no subtask and ticket cache:ManualThreadCacheDef +TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_manual_thread_ticket_cache) { + DavinciModel davinci_model(0, nullptr); + domi::TaskDef task_def; + FftsTaskInfo ffts_task_info; + CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); + + domi::TicketCacheDef *ticket_cache_def = task_def.mutable_ffts_task()->add_ticket_cache(); + domi::ManualThreadCacheDef *manual_thread_cache = 
+// test FftsTaskInfo UpdateArgs
+TEST_F(UtestFftsTaskInfo, success_ffts_task_info_update_args) {
+  DavinciModel davinci_model(0, nullptr);
+  FftsTaskInfo ffts_task_info;
+  ffts_task_info.davinci_model_ = &davinci_model;
+  ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 };
+  EXPECT_EQ(ffts_task_info.UpdateArgs(), SUCCESS);
+}
+
+// test FftsTaskInfo CalculateArgs
+TEST_F(UtestFftsTaskInfo, success_ffts_task_info_calculate_args) {
+  DavinciModel davinci_model(0, nullptr);
+  domi::TaskDef task_def;
+  FftsTaskInfo ffts_task_info;
+  EXPECT_EQ(ffts_task_info.CalculateArgs(task_def, &davinci_model), SUCCESS);
+}
+
+// test FftsTaskInfo Distribute
+TEST_F(UtestFftsTaskInfo, success_ffts_task_info_distribute) {
+  DavinciModel davinci_model(0, nullptr);
+  FftsTaskInfo ffts_task_info;
+  rtFftsTaskInfo_t sub_task_info;
+  ffts_task_info.sub_task_info_ = sub_task_info;
+  rtStream_t stream = nullptr;
+  rtStreamCreate(&stream, 0);
+  ffts_task_info.stream_ = stream;
+  EXPECT_EQ(ffts_task_info.Distribute(), SUCCESS);
+}
+} // namespace ge
\ No newline at end of file
diff --git a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc
index 327dd248..86569789 100644
--- a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc
+++ b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc
@@ -23,15 +23,20 @@
 #include "graph/load/model_manager/task_info/kernel_ex_task_info.h"
 #include "cce/aicpu_engine_struct.h"
+#include "tests/depends/runtime/src/runtime_stub.h"
 
 namespace ge {
 extern OpDescPtr CreateOpDesc(string name, string type);
 
 class UtestKernelExTaskInfo : public testing::Test {
 protected:
-  void SetUp() {}
+  void SetUp() {
+    RTS_STUB_SETUP();
+  }
 
-  void TearDown() {}
+  void TearDown() {
+    RTS_STUB_TEARDOWN();
+  }
 };
 
 // test kernel_ex_task_Release
@@ -209,4 +214,136 @@ TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_2) {
   KernelExTaskInfo kernel_ex_task_info;
   EXPECT_NE(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS);
 }
+
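+// The blocking-op cases below build the aicpu ext-info blob by hand: an
+// AicpuExtInfo header of type FWK_ADPT_EXT_ASYNCWAIT followed immediately by
+// its AsyncWaitInfo payload, which is the layout InitTaskExtInfo parses.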
+TEST_F(UtestKernelExTaskInfo, blocking_aicpu_op) {
+  int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo);
+  vector<char> aicpu_ext_info(len, 0);
+  char *buf = aicpu_ext_info.data();
+  int offset = 0;
+  hybrid::AicpuExtInfo *ext_info = reinterpret_cast<hybrid::AicpuExtInfo *>(buf + offset);
+  ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT;
+  ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo);
+  offset += sizeof(hybrid::AicpuExtInfo);
+  hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast<hybrid::AsyncWaitInfo *>(buf + offset);
+  async_wait_info->waitType = 0;
+  async_wait_info->waitId = 0;
+  async_wait_info->timeOut = 0;
+  async_wait_info->reserved = 0;
+
+  domi::TaskDef task_def;
+  domi::KernelExDef kernel_ex_def;
+  kernel_ex_def.set_kernel_ext_info(buf, len);
+  kernel_ex_def.set_kernel_ext_info_size(len);
+  domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex();
+  *kernel_ex_def_tmp = kernel_ex_def;
+
+  const OpDescPtr op_desc = CreateOpDesc("deque", "Deque");
+  ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true);
+
+  KernelExTaskInfo kernel_ex_task_info;
+  kernel_ex_task_info.op_desc_ = op_desc;
+  DavinciModel davinci_model(0, nullptr);
+  kernel_ex_task_info.davinci_model_ = &davinci_model;
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS);
+  EXPECT_EQ(kernel_ex_task_info.Distribute(), SUCCESS);
+  kernel_ex_task_info.op_desc_ = op_desc;
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS);
+  EXPECT_EQ(kernel_ex_task_info.Distribute(), SUCCESS);
+}
+
+TEST_F(UtestKernelExTaskInfo, blocking_aicpu_op_fail_01) {
+  int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo);
+  vector<char> aicpu_ext_info(len, 0);
+  char *buf = aicpu_ext_info.data();
+  int offset = 0;
+  hybrid::AicpuExtInfo *ext_info = reinterpret_cast<hybrid::AicpuExtInfo *>(buf + offset);
+  ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT;
+  ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo);
+  offset += sizeof(hybrid::AicpuExtInfo);
+  hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast<hybrid::AsyncWaitInfo *>(buf + offset);
+  async_wait_info->waitType = 0;
+  async_wait_info->waitId = 0;
+  async_wait_info->timeOut = 0;
+  async_wait_info->reserved = 0;
+
+  domi::TaskDef task_def;
+  domi::KernelExDef kernel_ex_def;
+  kernel_ex_def.set_kernel_ext_info(buf, len);
+  kernel_ex_def.set_kernel_ext_info_size(len);
+  domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex();
+  *kernel_ex_def_tmp = kernel_ex_def;
+
+  const OpDescPtr op_desc = CreateOpDesc("deque", "Deque");
+
+  KernelExTaskInfo kernel_ex_task_info;
+  kernel_ex_task_info.op_desc_ = op_desc;
+  DavinciModel davinci_model(0, nullptr);
+  kernel_ex_task_info.davinci_model_ = &davinci_model;
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS);
+
+  kernel_ex_task_info.is_blocking_aicpu_op_ = true;
+  EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED);
+}
+
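+// RTS_STUB_RETURN_VALUE forces the named stubbed runtime API to return the
+// given code on its next call (RTS_STUB_OUTBOUND_VALUE fakes its output
+// parameter), so each device/capability check and each event step inside
+// Distribute can be driven into its failure branch.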
+TEST_F(UtestKernelExTaskInfo, blocking_aicpu_op_fail_02) {
+  int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo);
+  vector<char> aicpu_ext_info(len, 0);
+  char *buf = aicpu_ext_info.data();
+  int offset = 0;
+  hybrid::AicpuExtInfo *ext_info = reinterpret_cast<hybrid::AicpuExtInfo *>(buf + offset);
+  ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT;
+  ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo);
+  offset += sizeof(hybrid::AicpuExtInfo);
+  hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast<hybrid::AsyncWaitInfo *>(buf + offset);
+  async_wait_info->waitType = 0;
+  async_wait_info->waitId = 0;
+  async_wait_info->timeOut = 0;
+  async_wait_info->reserved = 0;
+
+  domi::TaskDef task_def;
+  domi::KernelExDef kernel_ex_def;
+  kernel_ex_def.set_kernel_ext_info(buf, len);
+  kernel_ex_def.set_kernel_ext_info_size(len);
+  domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex();
+  *kernel_ex_def_tmp = kernel_ex_def;
+
+  const OpDescPtr op_desc = CreateOpDesc("deque", "Deque");
+  ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true);
+  KernelExTaskInfo kernel_ex_task_info;
+  kernel_ex_task_info.op_desc_ = op_desc;
+  DavinciModel davinci_model(0, nullptr);
+  kernel_ex_task_info.davinci_model_ = &davinci_model;
+
+  RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE);
+  RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1);
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED);
+
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS);
+  RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED);
+
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS);
+  RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE);
+  RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT);
+  EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS);
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE);
+  RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT);
+  EXPECT_EQ(kernel_ex_task_info.Distribute(), SUCCESS);
+}
+
 } // namespace ge
diff --git a/tests/ut/ge/graph/load/kernel_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_task_info_unittest.cc
index 0c8da4b5..45ae7853 100644
--- a/tests/ut/ge/graph/load/kernel_task_info_unittest.cc
+++ b/tests/ut/ge/graph/load/kernel_task_info_unittest.cc
@@ -22,15 +22,20 @@
 #include "graph/load/model_manager/davinci_model.h"
 #include "graph/load/model_manager/task_info/kernel_task_info.h"
 #include "graph/load/model_manager/task_info/hccl_task_info.h"
+#include "tests/depends/runtime/src/runtime_stub.h"
 
 namespace ge {
 extern OpDescPtr CreateOpDesc(string name, string type);
 
 class UtestKernelTaskInfo : public testing::Test {
 protected:
-  void SetUp() {}
+  void SetUp() {
+    RTS_STUB_SETUP();
+  }
 
-  void TearDown() {}
+  void TearDown() {
+    RTS_STUB_TEARDOWN();
+  }
 };
 
 // test KernelTaskInfo Init.
@@ -1240,4 +1245,135 @@ TEST_F(UtestKernelTaskInfo, kernel_task_info_super_kernel_info) {
   EXPECT_EQ(kernel_task_info.SKTFinalize(), SUCCESS);
 }
+
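+// Mirror of the KernelExTaskInfo blocking-op coverage above, this time going
+// through KernelTaskInfo::InitAicpuTaskExtInfo and its Distribute path.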
+TEST_F(UtestKernelTaskInfo, blocking_aicpu_op) {
+  int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo);
+  vector<char> aicpu_ext_info(len, 0);
+  char *buf = aicpu_ext_info.data();
+  int offset = 0;
+  hybrid::AicpuExtInfo *ext_info = reinterpret_cast<hybrid::AicpuExtInfo *>(buf + offset);
+  ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT;
+  ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo);
+  offset += sizeof(hybrid::AicpuExtInfo);
+  hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast<hybrid::AsyncWaitInfo *>(buf + offset);
+  async_wait_info->waitType = 0;
+  async_wait_info->waitId = 0;
+  async_wait_info->timeOut = 0;
+  async_wait_info->reserved = 0;
+
+  domi::TaskDef task_def;
+  domi::KernelDef kernel_def;
+  kernel_def.set_kernel_ext_info(buf, len);
+  kernel_def.set_kernel_ext_info_size(len);
+
+  const OpDescPtr op_desc = CreateOpDesc("deque", "Deque");
+  op_desc->SetId(0);
+  ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true);
+  DavinciModel davinci_model(0, nullptr);
+  davinci_model.op_list_.emplace(0, op_desc);
+
+  KernelTaskInfo kernel_task_info;
+  kernel_task_info.op_desc_ = op_desc;
+  kernel_task_info.davinci_model_ = &davinci_model;
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS);
+  EXPECT_EQ(kernel_task_info.Distribute(), SUCCESS);
+  kernel_task_info.op_desc_ = op_desc;
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS);
+  EXPECT_EQ(kernel_task_info.Distribute(), SUCCESS);
+}
+
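+// Forcing is_blocking_aicpu_op_ while the op lacks ATTR_NAME_IS_BLOCKING_OP
+// means no wait event was prepared during init, so Distribute is expected to
+// fail.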
+TEST_F(UtestKernelTaskInfo, blocking_aicpu_op_fail_01) {
+  int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo);
+  vector<char> aicpu_ext_info(len, 0);
+  char *buf = aicpu_ext_info.data();
+  int offset = 0;
+  hybrid::AicpuExtInfo *ext_info = reinterpret_cast<hybrid::AicpuExtInfo *>(buf + offset);
+  ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT;
+  ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo);
+  offset += sizeof(hybrid::AicpuExtInfo);
+  hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast<hybrid::AsyncWaitInfo *>(buf + offset);
+  async_wait_info->waitType = 0;
+  async_wait_info->waitId = 0;
+  async_wait_info->timeOut = 0;
+  async_wait_info->reserved = 0;
+
+  domi::KernelDef kernel_def;
+  kernel_def.set_kernel_ext_info(buf, len);
+  kernel_def.set_kernel_ext_info_size(len);
+
+  const OpDescPtr op_desc = CreateOpDesc("deque", "Deque");
+  op_desc->SetId(0);
+  DavinciModel davinci_model(0, nullptr);
+  davinci_model.op_list_.emplace(0, op_desc);
+
+  KernelTaskInfo kernel_task_info;
+  kernel_task_info.davinci_model_ = &davinci_model;
+  kernel_task_info.op_desc_ = op_desc;
+
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS);
+
+  kernel_task_info.is_blocking_aicpu_op_ = true;
+  EXPECT_EQ(kernel_task_info.Distribute(), FAILED);
+}
+
+TEST_F(UtestKernelTaskInfo, blocking_aicpu_op_fail_02) {
+  int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo);
+  vector<char> aicpu_ext_info(len, 0);
+  char *buf = aicpu_ext_info.data();
+  int offset = 0;
+  hybrid::AicpuExtInfo *ext_info = reinterpret_cast<hybrid::AicpuExtInfo *>(buf + offset);
+  ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT;
+  ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo);
+  offset += sizeof(hybrid::AicpuExtInfo);
+  hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast<hybrid::AsyncWaitInfo *>(buf + offset);
+  async_wait_info->waitType = 0;
+  async_wait_info->waitId = 0;
+  async_wait_info->timeOut = 0;
+  async_wait_info->reserved = 0;
+
+  domi::KernelDef kernel_def;
+  kernel_def.set_kernel_ext_info(buf, len);
+  kernel_def.set_kernel_ext_info_size(len);
+
+  const OpDescPtr op_desc = CreateOpDesc("deque", "Deque");
+  ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true);
+  op_desc->SetId(0);
+  DavinciModel davinci_model(0, nullptr);
+  davinci_model.op_list_.emplace(0, op_desc);
+
+  KernelTaskInfo kernel_task_info;
+  kernel_task_info.davinci_model_ = &davinci_model;
+  kernel_task_info.op_desc_ = op_desc;
+
+  RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE);
+  RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1);
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_task_info.Distribute(), FAILED);
+
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS);
+  RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_task_info.Distribute(), FAILED);
+
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS);
+  RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001);
+  EXPECT_EQ(kernel_task_info.Distribute(), FAILED);
+
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE);
+  RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT);
+  EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS);
+  RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE);
+  RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT);
+  EXPECT_EQ(kernel_task_info.Distribute(), SUCCESS);
+}
+
 } // namespace ge
diff --git a/tests/ut/ge/graph/load/model_helper_unittest.cc b/tests/ut/ge/graph/load/model_helper_unittest.cc
index 8fd8f014..8af329ed 100644
--- a/tests/ut/ge/graph/load/model_helper_unittest.cc
+++ b/tests/ut/ge/graph/load/model_helper_unittest.cc
@@ -20,7 +20,7 @@
 #include "framework/common/helper/model_helper.h"
 #include "framework/omg/model_tool.h"
 #include "framework/omg/ge_init.h"
-#include "ge/model/ge_model.h"
+#include "ge/common/model/ge_model.h"
 
 #undef private
 #undef protected
diff --git a/tests/ut/ge/graph/load/model_manager_unittest.cc b/tests/ut/ge/graph/load/model_manager_unittest.cc
index a3545b33..65b70a24 100644
--- a/tests/ut/ge/graph/load/model_manager_unittest.cc
+++ b/tests/ut/ge/graph/load/model_manager_unittest.cc
@@ -26,6 +26,7 @@
 #include "graph/load/graph_loader.h"
 #include "graph/load/model_manager/davinci_model.h"
 #include "graph/ops_stub.h"
+#include "common/profiling/profiling_manager.h"
 
 using namespace std;
 using namespace testing;
@@ -54,31 +55,13 @@ class UtestModelManagerModelManager : public testing::Test {
   }
 
   void SetUp() {}
-  void TearDown() {}
 
-  void CreateGraph(Graph &graph) {
-    TensorDesc desc(ge::Shape({1, 3, 224, 224}));
-    uint32_t size = desc.GetShape().GetShapeSize();
-    desc.SetSize(size);
-    auto data = op::Data("Data").set_attr_index(0);
op::Data("Data").set_attr_index(0); - data.update_input_desc_data(desc); - data.update_output_desc_out(desc); - - auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out()); - - std::vector inputs{data}; - std::vector outputs{flatten}; - std::vector targets{flatten}; - // Graph graph("test_graph"); - graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets); - } - void GenUnencryptModelData(ModelData &data) { const int model_len = 10; data.model_len = sizeof(ModelFileHeader) + model_len; data.model_data = new uint8_t[data.model_len]; - memset((uint8_t *)data.model_data + sizeof(ModelFileHeader), 10, model_len); + memset(data.model_data, 0, data.model_len); ModelFileHeader *header = (ModelFileHeader *)data.model_data; header->magic = MODEL_FILE_MAGIC_NUM; @@ -88,19 +71,6 @@ class UtestModelManagerModelManager : public testing::Test { header->is_checksum = ModelCheckType::CHECK; } - void GenEncryptModelData(ModelData &data) { - const int model_len = 10; - data.key = ENC_KEY; - data.model_data = new uint8_t[data.model_len]; - uint8_t data_ori[model_len]; - memset(data_ori, 10, model_len); - ModelFileHeader *header = (ModelFileHeader *)data.model_data; - header->magic = MODEL_FILE_MAGIC_NUM; - header->version = MODEL_VERSION; - header->is_encrypt = ModelEncryptType::ENCRYPTED; - header->length = 10; // encrypt_len; - } - void LoadStandardModelData(ModelData &data) { data.model_len = 512; data.model_data = new uint8_t[data.model_len]; @@ -166,7 +136,8 @@ class UtestModelManagerModelManager : public testing::Test { class DModelListener : public ModelListener { public: DModelListener(){}; - uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } + uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, + uint32_t resultCode, std::vector &outputs) { return 0; } }; TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) { @@ -224,6 +195,7 @@ TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) { ModelFileHeader *header = (ModelFileHeader *)data.model_data; header->is_encrypt = 255; uint32_t model_id = 1; + // Error for: LoadModelPartitionTable: Invalid partition_table->num:0 EXPECT_EQ(mm.LoadModelOffline(model_id, data, nullptr, nullptr), ACL_ERROR_GE_PARAM_INVALID); delete[](uint8_t *) data.model_data; } @@ -438,4 +410,48 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) { auto ret = mm.DataInputTensor(model_id,inputs); EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. } + +TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) { + ModelManager mm; + + // cust_aicpu_so_ is empty. + EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS); + + // deleteCustOp after Launch will deleted. 
+TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) {
+  ModelManager mm;
+
+  // cust_aicpu_so_ is empty.
+  EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS);
+
+  // deleteCustOp will be deleted after Launch.
+  uintptr_t resource_id = 1;  // for rtCtxGetCurrent stub
+  std::vector<char> kernel_bin(256);
+  auto &cust_resource_001 = mm.cust_aicpu_so_[resource_id];
+  auto tbe_kernel = std::shared_ptr<OpKernelBin>(new OpKernelBin("deleteCustOp", std::move(kernel_bin)));
+  auto &cust_opkernel_001 = cust_resource_001["deleteCustOp"] = tbe_kernel;
+
+  EXPECT_FALSE(mm.cust_aicpu_so_.empty());
+  EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS);
+  EXPECT_TRUE(mm.cust_aicpu_so_.empty());
+}
+
+shared_ptr<ModelListener> listener(new DModelListener());
+TEST_F(UtestModelManagerModelManager, test_load_model_online) {
+  ModelManager mm;
+  uint32_t model_id = 1;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
+  auto &profiling_manager = ge::ProfilingManager::Instance();
+  profiling_manager.SetSubscribeInfo(0, model_id, true);
+  Status ret = mm.LoadModelOnline(model_id, ge_root_model, listener);
+  profiling_manager.CleanSubscribeInfo();
+}
+
+TEST_F(UtestModelManagerModelManager, command_profiling) {
+  ModelManager manager;
+  uint32_t model_id = 1;
+  Command cmd;
+  auto model = std::make_shared<DavinciModel>(1, listener);
+  model->SetId(model_id);
+  cmd.cmd_params.push_back("modelId");
+  cmd.cmd_params.push_back(to_string(model_id));
+  auto &profiling_manager = ge::ProfilingManager::Instance();
+  profiling_manager.SetSubscribeInfo(0, model_id, true);
+  Status ret = manager.HandleProfModelUnsubscribeCommand(cmd);
+  profiling_manager.CleanSubscribeInfo();
+}
 } // namespace ge
diff --git a/tests/ut/ge/graph/load/new_model_manager_data_inputer_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_data_inputer_unittest.cc
deleted file mode 100644
index 43c2ad15..00000000
--- a/tests/ut/ge/graph/load/new_model_manager_data_inputer_unittest.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-#include <gtest/gtest.h>
-
-#include "graph/load/model_manager/data_inputer.h"
-
-#include "common/debug/log.h"
-#include "common/debug/memory_dumper.h"
-#include "common/types.h"
-#include "new_op_test_utils.h"
-
-using namespace std;
-using namespace testing;
-
-namespace ge {
-
-class UtestModelManagerDataInputer : public testing::Test {
- protected:
-  void SetUp() {}
-
-  void TearDown() {}
-};
-
-/// InputDataWrapper
-/// constructor
-/// GetInput
-TEST_F(UtestModelManagerDataInputer, inputdatawrapper_construct) {
-  InputDataWrapper *input_data_wrapper = new InputDataWrapper();
-
-  input_data_wrapper->GetInput();
-
-  delete input_data_wrapper;
-}
-
-/// InputDataWrapper
-/// Init func with correct input
-TEST_F(UtestModelManagerDataInputer, success_inputdatawrapper_init) {
-  InputDataWrapper *input_data_wrapper = new InputDataWrapper();
-  ge::InputData input_data;
-  ge::OutputData output_data;
-  Status ret = input_data_wrapper->Init(input_data, output_data);
-
-  EXPECT_EQ(ret, SUCCESS);
-
-  delete input_data_wrapper;
-  input_data_wrapper = NULL;
-}
-
-}  // namespace ge
diff --git a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc
deleted file mode 100644
index 38a250ad..00000000
--- a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc
+++ /dev/null
@@ -1,1433 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <gtest/gtest.h>
-#include "common/debug/log.h"
-#include "common/debug/memory_dumper.h"
-#include "common/types.h"
-
-#define private public
-#define protected public
-#include "graph/compute_graph.h"
-#include "graph/utils/graph_utils.h"
-#include "graph/model_serialize.h"
-#include "graph/load/model_manager/davinci_model.h"
-#include "common/properties_manager.h"
-#include "common/op/ge_op_utils.h"
-#include
-#include "runtime/dev.h"
-#include "runtime/kernel.h"
-#include "cce/fwk_adpt_struct.h"
-#include "graph/load/model_manager/task_info/task_info_factory.h"
-#include "graph/load/model_manager/task_info/task_info.h"
-#include "graph/load/model_manager/task_info/stream_active_task_info.h"
-#include "graph/load/model_manager/task_info/stream_switch_task_info.h"
-#include "graph/load/model_manager/task_info/profiler_trace_task_info.h"
-#include "graph/load/model_manager/task_info/memcpy_async_task_info.h"
-#include "graph/load/model_manager/task_info/label_set_task_info.h"
-#include "graph/load/model_manager/task_info/kernel_ex_task_info.h"
-#include "graph/load/model_manager/task_info/kernel_task_info.h"
-#include "graph/load/model_manager/task_info/hccl_task_info.h"
-#include "graph/load/model_manager/task_info/fusion_start_task_info.h"
-#include "graph/load/model_manager/task_info/fusion_stop_task_info.h"
-#include "graph/load/model_manager/task_info/event_record_task_info.h"
-#include "graph/load/model_manager/task_info/event_wait_task_info.h"
-#include "graph/manager/graph_var_manager.h"
-#include "graph/load/model_manager/model_manager.h"
-#undef private
-#undef protected
-
-#include "new_op_test_utils.h"
-#include "graph/debug/ge_attr_define.h"
-using namespace std;
-using namespace testing;
-using domi::EventExDef;
-using domi::KernelContext;
-using domi::KernelDef;
-using domi::LogTimeStampDef;
-using domi::ModelTaskDef;
-using domi::StreamActiveDef;
-using domi::TaskDef;
-
-namespace ge {
-class UtestModelManagerDavinciModel : public testing::Test {
- protected:
-  void SetUp() {}
-
-  void TearDown() {}
-};
-
-class DModelListener : public ge::ModelListener {
- public:
-  DModelListener(){};
-  uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) {
-    GELOGI("In Call back. OnComputeDone");
OnComputeDone"); - return 0; - } -}; - -shared_ptr g_label_call_back(new DModelListener()); - -static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { - auto op_desc = std::make_shared(name, type); - op_desc->SetStreamId(0); - op_desc->SetId(0); - - ge::AttrUtils::SetFloat(op_desc, ge::ATTR_NAME_ALPHA, 0); - ge::AttrUtils::SetFloat(op_desc, ge::ATTR_NAME_BETA, 0); - - op_desc->SetWorkspace({}); - ; - op_desc->SetWorkspaceBytes({}); - op_desc->SetInputOffset({}); - op_desc->SetOutputOffset({}); - - ge::AttrUtils::SetListStr(op_desc, ge::ATTR_NAME_WEIGHT_NAME, {}); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_PAD_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_DATA_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_CEIL_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_NAN_OPT, 0); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_WINDOW, {}); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_PAD, {}); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_STRIDE, {}); - ge::AttrUtils::SetListInt(op_desc, ge::ATTR_NAME_ACTIVE_STREAM_LIST, {1, 1}); - ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_STREAM_SWITCH_COND, 0); - ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, FMK_TYPE_T); - return op_desc; -} - -// tset failed_rt_free_host -TEST_F(UtestModelManagerDavinciModel, failed_rt_free_host) { - DavinciModel model(0, g_label_call_back); - - OutputData output_data; - - auto op_desc = CreateOpDesc("Pooling", "Pooling"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(in_desc, 16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(out_desc, 16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_PAD_MODE, cce::CC_PADDING_DIRECTASSIGN); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_PAD, vector({1, 1, 1, 1})); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_WINDOW, vector({1, 1})); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_STRIDE, vector({1, 1})); - - auto compute_graph = make_shared("g"); - auto node = compute_graph->AddNode(op_desc); - - OmeTestOpUtils::InitModel(model); - - model.data_op_list_.push_back(op_desc); - - EXPECT_EQ(ge::INTERNAL_ERROR, model.ReturnResult(1, false, false, &output_data)); -} - -// test modeldef_fail -TEST_F(UtestModelManagerDavinciModel, contruct_modeldef_createfail) { - DavinciModel model(0, g_label_call_back); - - OmeTestOpUtils::InitModel(model); - - auto op_desc = CreateOpDesc("Pooling", "Pooling"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(in_desc, 16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(out_desc, 16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - ge::AttrUtils::SetInt(op_desc, 
-  ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_PAD, vector<int64_t>({1, 1, 1, 1}));
-  ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_WINDOW, vector<int64_t>({1, 1}));
-  ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_STRIDE, vector<int64_t>({1, 1}));
-
-  model.GetEventList();
-}
-
-// test CopyInputDataToModel
-TEST_F(UtestModelManagerDavinciModel, copy_input_data_to_model_fail) {
-  DavinciModel model(0, g_label_call_back);
-
-  ge::InputData input_data;
-  ge::DataBuffer data_buffer;
-  data_buffer.data = new char[16];
-  data_buffer.length = 16;
-  input_data.index = 0;
-  input_data.model_id = 1;
-  input_data.blobs.push_back(data_buffer);
-
-  model.op_list_.clear();
-
-  delete[](char *) data_buffer.data;
-}
-
-// test StreamNum
-TEST_F(UtestModelManagerDavinciModel, streamnum_success) {
-  DavinciModel *model = new DavinciModel(0, g_label_call_back);
-
-  OmeTestOpUtils::InitModel(*model);
-
-  EXPECT_EQ(0, model->StreamNum());
-  EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart());
-
-  EXPECT_EQ(ge::SUCCESS, model->ModelRunStop());
-
-  delete model;
-}
-
-// test EventNum
-TEST_F(UtestModelManagerDavinciModel, eventnum_success) {
-  DavinciModel *model = new DavinciModel(0, g_label_call_back);
-
-  OmeTestOpUtils::InitModel(*model);
-
-  EXPECT_EQ(0, model->EventNum());
-  EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart());
-
-  EXPECT_EQ(ge::SUCCESS, model->ModelRunStop());
-
-  delete model;
-}
-
-TEST_F(UtestModelManagerDavinciModel, handlelist_success) {
-  DavinciModel *model = new DavinciModel(0, g_label_call_back);
-
-  OmeTestOpUtils::InitModel(*model);
-
-  EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart());
-
-  EXPECT_EQ(ge::SUCCESS, model->ModelRunStop());
-
-  delete model;
-}
-
-// test GetEventList
-TEST_F(UtestModelManagerDavinciModel, eventlist_success) {
-  DavinciModel *model = new DavinciModel(0, g_label_call_back);
-
-  OmeTestOpUtils::InitModel(*model);
-
-  EXPECT_EQ(true, model->GetEventList().empty());
-  EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart());
-
-  EXPECT_EQ(ge::SUCCESS, model->ModelRunStop());
-
-  delete model;
-}
-
-// test Shrink
-TEST_F(UtestModelManagerDavinciModel, shrink_success) {
-  DavinciModel model(0, g_label_call_back);
-  OpDescPtr op_desc_ptr = make_shared<OpDesc>("Cast", "Cast");
-  void *addr = nullptr;
-  rtMalloc(&addr, 128, RT_MEMORY_HBM);
-  model.saved_task_addrs_.emplace(op_desc_ptr, addr);
-  model.Shrink();
-  EXPECT_EQ(model.saved_task_addrs_.empty(), true);
-}
-
-// test rtMalloc
-TEST_F(UtestModelManagerDavinciModel, failed_reset_device) {
-  DavinciModel model(0, g_label_call_back);
-  ge::OutputData output_data;
-  ge::DataBuffer buf_data;
-  rtMalloc(&buf_data.data, 128, RT_MEMORY_HBM);
-  buf_data.length = 128;
-  output_data.blobs.push_back(buf_data);
-  EXPECT_EQ(ge::INTERNAL_ERROR, model.ReturnResult(1, true, false, &output_data));
-  rtFree(buf_data.data);
-}
-
-// test priority
-TEST_F(UtestModelManagerDavinciModel, init_not_support_priority) {
-  int32_t priority = 8;
-  DavinciModel model(priority, g_label_call_back);
-}
-
-// test GetInputOutputDescInfo
-TEST_F(UtestModelManagerDavinciModel, success_GetInputOutputDescInfo_without_netoutput) {
-  DavinciModel model(0, g_label_call_back);
-
-  auto op_desc = CreateOpDesc("Data", "Data");
-  op_desc->SetOutputOffset({1});
-  op_desc->SetInputOffset({1});
-  op_desc->SetStreamId(0);
-
-  {
-    ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(in_desc, false);
-    ge::TensorUtils::SetInputTensor(in_desc, true);
-    op_desc->AddInputDesc(in_desc);
-  }
-
-  {
-    ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(out_desc, true);
-    ge::TensorUtils::SetInputTensor(out_desc, false);
-    op_desc->AddOutputDesc(out_desc);
-  }
-
-  op_desc->SetSrcName({"Pooling1", "Pooling0"});
-  op_desc->SetSrcIndex({0, 1});
-
-  auto compute_graph = make_shared<ComputeGraph>("g");
-  auto node = compute_graph->AddNode(op_desc);
-
-  model.data_op_list_.push_back(op_desc);
-  model.output_size_list_.push_back(32);
-
-  model.op_list_[0] = op_desc;
-
-  model.output_op_list_.push_back(op_desc);
-
-  vector<InputOutputDescInfo> input_shapes;
-  vector<InputOutputDescInfo> output_shapes;
-  EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfo(input_shapes, output_shapes));
-}
-
-TEST_F(UtestModelManagerDavinciModel, CopyTensorFromSrcVarNode_input_is_nullptr) {
-  NodePtr src_node = nullptr;
-  NodePtr dst_node = nullptr;
-  DavinciModel model(0, g_label_call_back);
-  Status ret = model.CopyTensorFromSrcVarNode(src_node, dst_node);
-  EXPECT_EQ(FAILED, ret);
-}
-
-TEST_F(UtestModelManagerDavinciModel, CopyTensorFromSrcVarNode_success) {
-  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("default");
-  OpDescPtr op_desc_ptr = make_shared<OpDesc>("Cast", "Cast");
-  GeTensorDesc dims_tensor_desc(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT16);
-  GeTensorDesc dims_tensor_desc_in(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT);
-  op_desc_ptr->AddInputDesc(dims_tensor_desc_in);
-  op_desc_ptr->AddOutputDesc(dims_tensor_desc);
-
-  NodePtr src_node = graph->AddNode(op_desc_ptr);
-  NodePtr dst_node = graph->AddNode(op_desc_ptr);
-  DavinciModel model(0, g_label_call_back);
-  Status ret = model.CopyTensorFromSrcVarNode(src_node, dst_node);
-}
-
-TEST_F(UtestModelManagerDavinciModel, CopyVarData_graph_is_nullptr) {
-  ge::ComputeGraphPtr graph = nullptr;
-  DavinciModel model(0, g_label_call_back);
-  Status ret = model.CopyVarData(graph);
-  EXPECT_EQ(FAILED, ret);
-}
-
-TEST_F(UtestModelManagerDavinciModel, copy_var_data_success) {
-  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("default");
-  OpDescPtr op_desc_ptr = make_shared<OpDesc>("Variable", "Variable");
-  GeTensorDesc dims_tensor_desc(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT16);
-  GeTensorDesc dims_tensor_desc_in(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT16);
-  op_desc_ptr->AddInputDesc(dims_tensor_desc_in);
-  op_desc_ptr->AddOutputDesc(dims_tensor_desc);
-
-  NodePtr src_node = graph->AddNode(op_desc_ptr);
-  (void)ge::AttrUtils::SetStr(src_node->GetOpDesc(), "_copy_from_var_node", "abc");
-  (void)ge::AttrUtils::SetBool(src_node->GetOpDesc(), "_copy_value", false);
-
-  DavinciModel model(0, g_label_call_back);
-  Status ret = model.CopyVarData(graph);
-}
-
-TEST_F(UtestModelManagerDavinciModel, get_input_output_desc_info_without_data_op_list) {
-  DavinciModel model(0, g_label_call_back);
-  vector<InputOutputDescInfo> input_list;
-  vector<InputOutputDescInfo> output_list;
-  Status ret = model.GetInputOutputDescInfo(input_list, output_list);
-  EXPECT_EQ(SUCCESS, ret);
-}
-
-// test GetInputOutputDescInfo
-TEST_F(UtestModelManagerDavinciModel, success_get_input_output_descInfo_with_net_output) {
-  DavinciModel model(0, g_label_call_back);
-
-  auto op_desc = CreateOpDesc("Data", "Data");
-  op_desc->SetOutputOffset({1});
-  op_desc->SetInputOffset({1});
-  op_desc->SetStreamId(0);
-
-  {
-    ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(in_desc, false);
-    ge::TensorUtils::SetInputTensor(in_desc, true);
-    op_desc->AddInputDesc(in_desc);
-  }
-
-  {
-    ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(out_desc, true);
-    ge::TensorUtils::SetOutputTensor(out_desc, true);
-    ge::TensorUtils::SetInputTensor(out_desc, false);
-    op_desc->AddOutputDesc(out_desc);
-  }
-  op_desc->SetSrcName({"Pooling1", "Pooling0"});
-  op_desc->SetSrcIndex({0, 1});
-
-  auto compute_graph = make_shared<ComputeGraph>("g");
-  auto data_node = compute_graph->AddNode(op_desc);
-
-  model.data_op_list_.push_back(op_desc);
-
-  op_desc->SetType("NetOutput");
-
-  auto no_node = compute_graph->AddNode(op_desc);
-
-  model.op_list_[0] = op_desc;
-
-  model.output_op_list_.push_back(op_desc);
-  model.output_size_list_.push_back(32);
-
-  vector<InputOutputDescInfo> input_shapes;
-  vector<InputOutputDescInfo> output_shapes;
-  EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfo(input_shapes, output_shapes));
-}
-
-TEST_F(UtestModelManagerDavinciModel, success_get_input_output_desc_info_for_zero_copy_with_net_output) {
-  DavinciModel model(0, g_label_call_back);
-
-  auto op_desc = CreateOpDesc("Data", "Data");
-  op_desc->SetOutputOffset({1});
-  op_desc->SetInputOffset({1});
-  op_desc->SetStreamId(0);
-
-  {
-    ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(in_desc, false);
-    ge::TensorUtils::SetInputTensor(in_desc, true);
-    op_desc->AddInputDesc(in_desc);
-  }
-
-  {
-    ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(out_desc, true);
-    ge::TensorUtils::SetOutputTensor(out_desc, true);
-    ge::TensorUtils::SetInputTensor(out_desc, false);
-    op_desc->AddOutputDesc(out_desc);
-  }
-
-  op_desc->SetSrcName({"Pooling1", "Pooling0"});
-  op_desc->SetSrcIndex({0, 1});
-
-  auto compute_graph = make_shared<ComputeGraph>("g");
-  auto data_node = compute_graph->AddNode(op_desc);
-
-  model.data_op_list_.push_back(op_desc);
-
-  op_desc->SetType("NetOutput");
-
-  auto net_out_node = compute_graph->AddNode(op_desc);
-  model.op_list_[0] = op_desc;
-
-  model.output_op_list_.push_back(op_desc);
-  model.output_size_list_.push_back(32);
-  model.output_memory_size_list_.push_back(64);
-
-  vector<InputOutputDescInfo> input_shapes;
-  vector<InputOutputDescInfo> output_shapes;
-  EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfoForZeroCopy(input_shapes, output_shapes));
-}
-
-TEST_F(UtestModelManagerDavinciModel, success_get_input_output_desc_info_dim_size_not4) {
-  DavinciModel model(0, g_label_call_back);
-
-  auto op_desc = CreateOpDesc("Data", "Data");
-  op_desc->SetOutputOffset({1});
-  op_desc->SetInputOffset({1});
-  op_desc->SetStreamId(0);
-
-  {
-    ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(in_desc, false);
-    ge::TensorUtils::SetInputTensor(in_desc, true);
-    op_desc->AddInputDesc(in_desc);
-  }
-
-  {
-    ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
-    ge::TensorUtils::SetOutputTensor(out_desc, true);
-    ge::TensorUtils::SetOutputTensor(out_desc, true);
-    ge::TensorUtils::SetInputTensor(out_desc, false);
-    op_desc->AddOutputDesc(out_desc);
-  }
-
-  op_desc->SetSrcName({"Pooling1", "Pooling0"});
-  op_desc->SetSrcIndex({0, 1});
-
-  auto compute_graph = make_shared<ComputeGraph>("g");
-  auto data_node = compute_graph->AddNode(op_desc);
-
-  model.data_op_list_.push_back(op_desc);
-
-  op_desc->SetType("NetOutput");
-
-  auto net_out_node = compute_graph->AddNode(op_desc);
-  model.op_list_[0] = op_desc;
-
-  model.output_op_list_.push_back(op_desc);
-  model.output_size_list_.push_back(32);
-
-  vector<InputOutputDescInfo> input_shapes;
-  vector<InputOutputDescInfo> output_shapes;
-  EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfo(input_shapes, output_shapes));
-}
-
-// test GetLabelList
-TEST_F(UtestModelManagerDavinciModel, get_label_list_success) {
-  DavinciModel model(0, g_label_call_back);
-  OmeTestOpUtils::InitModel(model);
-  vector<rtLabel_t> label_list;
-  model.label_list_ = label_list;
-  EXPECT_EQ(label_list, model.GetLabelList());
-}
-
-// test GetInputListSize
-TEST_F(UtestModelManagerDavinciModel, get_label_list_size_success) {
-  DavinciModel model(0, g_label_call_back);
-  OmeTestOpUtils::InitModel(model);
-  vector<OpDescPtr> data_op_list;
-  data_op_list.push_back(std::make_shared<OpDesc>());
-  model.data_op_list_ = data_op_list;
-}
-
-// test GetFlowctrlOpList
-TEST_F(UtestModelManagerDavinciModel, get_flow_ctrl_op_list_success) {
-  DavinciModel model(0, g_label_call_back);
-  OmeTestOpUtils::InitModel(model);
-  std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map;
-  flowctrl_op_index_internal_map.insert(pair<uint32_t, uint32_t>(1, 1));
-  model.flowctrl_op_index_internal_map_ = flowctrl_op_index_internal_map;
-}
-
-// test SetFlowctrlOpList
-TEST_F(UtestModelManagerDavinciModel, get_flow_ctrl_index_success) {
-  DavinciModel model(0, g_label_call_back);
-  OmeTestOpUtils::InitModel(model);
-  EXPECT_EQ(0, model.GetFlowctrlIndex(0));
-  EXPECT_EQ(1, model.GetFlowctrlIndex(0));
-  EXPECT_EQ(0, model.GetFlowctrlIndex(1));
-  EXPECT_EQ(1, model.GetFlowctrlIndex(1));
-  EXPECT_EQ(2, model.GetFlowctrlIndex(0));
-}
-
-// test GetRegisterStub
-TEST_F(UtestModelManagerDavinciModel, success_get_register_stub) {
-  DavinciModel model(0, g_label_call_back);
-  OmeTestOpUtils::InitModel(model);
-  std::string binfile = "tvmbin";
-  string ret = model.GetRegisterStub(binfile);
-  EXPECT_EQ("tvmbin", ret);
-  model.tvm_bin_kernel_.insert("tvmbin");
-  ret = model.GetRegisterStub(binfile);
-  EXPECT_EQ("tvmbin", ret);
-}
-
-// test InitTbeHandle
-TEST_F(UtestModelManagerDavinciModel, success_init_tbe_handle) {
-  DavinciModel model(0, g_label_call_back);
-  OmeTestOpUtils::InitModel(model);
-  std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>();
-  Status ret = model.InitTbeHandle(op_desc);
-  EXPECT_EQ(ge::INTERNAL_ERROR, ret);
-}
-
-// test InitTVMTask failed
-TEST_F(UtestModelManagerDavinciModel, init_tvm_task_failed1) {
-  DavinciModel model(0, g_label_call_back);
-  uint16_t offset = 0;
-  TaskDef *task_def = new TaskDef();
-  KernelDef *kernel_def = task_def->mutable_kernel();
-  map<uint32_t, OpDescPtr> op_list;
-  model.op_list_ = op_list;
-
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-  Status ret = kernel_task_info->InitTVMTask(&model, offset, kernel_def[0]);
-  EXPECT_EQ(INTERNAL_ERROR, ret);
-  task_def->clear_kernel();
-  delete kernel_task_info;
-  delete task_def;
-}
-
-TEST_F(UtestModelManagerDavinciModel, kernel_taskInfo_init_cce_task_failed1) {
-  DavinciModel model(0, g_label_call_back);
-
-  TaskDef *task_def = new TaskDef();
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-  KernelDef *kernel_def = task_def->mutable_kernel();
-  Status ret = kernel_task_info->InitCceTask(&model, kernel_def[0]);
-  EXPECT_EQ(ge::INTERNAL_ERROR, ret);
-  task_def->clear_kernel();
-  delete kernel_task_info;
-  delete task_def;
-}
-
-// test SetContext success
-TEST_F(UtestModelManagerDavinciModel, success_kernel_taskInfo_init_set_context) {
-  DavinciModel model(0, g_label_call_back);
-
-  TaskDef *task_def = new TaskDef();
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-  KernelDef *kernel_def = task_def->mutable_kernel();
-  KernelContext *context = kernel_def->mutable_context();
-  context->set_op_id(1);
-  context->set_kernel_func_id(1);
-  context->set_is_flowtable(true);
-  context->set_args_count(1);
-  context->set_args_offset("args111111", 10);
-
-  Status ret = kernel_task_info->SetContext(kernel_def[0]);
-  EXPECT_EQ(ge::SUCCESS, ret);
-
-  ret = kernel_task_info->Release();
-  EXPECT_EQ(ge::SUCCESS, ret);
-  kernel_def->clear_context();
-  task_def->clear_kernel();
-  delete kernel_task_info;
-  delete task_def;
-}
-
-// test SetContext failed
-TEST_F(UtestModelManagerDavinciModel, kernel_taskInfo_init_set_context_failed1) {
-  DavinciModel model(0, g_label_call_back);
-
-  TaskDef *task_def = new TaskDef();
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-  KernelDef *kernel_def = task_def->mutable_kernel();
-  KernelContext *context = kernel_def->mutable_context();
-  context->set_op_id(1);
-  context->set_kernel_func_id(1);
-  context->set_is_flowtable(true);
-  context->set_args_count(0);
-  Status ret = kernel_task_info->SetContext(kernel_def[0]);
-  EXPECT_EQ(ge::INTERNAL_ERROR, ret);
-
-  kernel_def->clear_context();
-  task_def->clear_kernel();
-  delete kernel_task_info;
-  delete task_def;
-}
-
-TEST_F(UtestModelManagerDavinciModel, kernel_taskInfo_init_set_context_failed2) {
-  DavinciModel model(0, g_label_call_back);
-
-  TaskDef *task_def = new TaskDef();
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-  KernelDef *kernel_def = task_def->mutable_kernel();
-  KernelContext *context = kernel_def->mutable_context();
-  context->set_op_id(1);
-  context->set_kernel_func_id(1);
-  context->set_is_flowtable(true);
-  context->set_args_count(5);
-  context->set_args_offset("\0\0");  // args_offset = 0
-
-  Status ret = kernel_task_info->SetContext(kernel_def[0]);
-  EXPECT_EQ(ge::PARAM_INVALID, ret);
-
-  kernel_def->clear_context();
-  task_def->clear_kernel();
-  delete kernel_task_info;
-  delete task_def;
-}
-
-// test success DistributeDumpTask
-TEST_F(UtestModelManagerDavinciModel, success_distribute_dump_task) {
-  DavinciModel model(0, g_label_call_back);
-  TaskDef *task_def = new TaskDef();
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-  KernelDef *kernel_def = task_def->mutable_kernel();
-
-  kernel_def->set_stub_func("kerneltaskinfo");
-  kernel_def->set_block_dim(10);
-  kernel_def->set_args("args111111", 10);
-  kernel_def->set_args_size(10);
-  rtSmDesc_t l2CtrlInfo;
-  l2CtrlInfo.data[0].L2_mirror_addr = 1024;
-  kernel_def->set_sm_desc((void *)&l2CtrlInfo, sizeof(rtSmDesc_t));
-
-  // for SetStream
-  rtStream_t stream = nullptr;
-  rtStreamCreate(&stream, 0);
-  std::vector<rtStream_t> stream_list;
-  stream_list.push_back(stream);
-  Status ret = kernel_task_info->SetStream(0, stream_list);
-  EXPECT_EQ(SUCCESS, ret);
-
-  ret = kernel_task_info->Release();
-  EXPECT_EQ(SUCCESS, ret);
-  rtStreamDestroy(stream);
-  task_def->clear_kernel();
-  delete kernel_task_info;
-  delete task_def;
-}
-
-// test success GetTaskID
-TEST_F(UtestModelManagerDavinciModel, success_get_task_id) {
-  ModelTaskDef *model_task_def = new ModelTaskDef();
-  TaskDef *task = model_task_def->add_task();
-  task->set_type(RT_MODEL_TASK_KERNEL);
-  TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task->type()));
-
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-  uint32_t ret = task_info->GetTaskID();
-  EXPECT_EQ(0, ret);
-  ret = kernel_task_info->GetTaskID();
-  EXPECT_EQ(0, ret);
-  HcclTaskInfo *hccl_task_info = new HcclTaskInfo();
-  ret = hccl_task_info->GetTaskID();
-  EXPECT_EQ(0, ret);
-
-  delete hccl_task_info;
-  delete kernel_task_info;
-  delete model_task_def;
-}
-
-// test StoreInputOutputTensor success
-TEST_F(UtestModelManagerDavinciModel, success_store_input_output_tensor) {
-  DavinciModel model(0, g_label_call_back);
-  TaskDef *task_def = new TaskDef();
-  KernelTaskInfo *kernel_task_info = new KernelTaskInfo();
-
-  std::vector<void *> input_data_addrs;
-  std::vector<void *> output_data_addrs;
-  std::vector<::tagCcAICPUTensor> input_descs;
-  std::vector<::tagCcAICPUTensor> output_descs;
-
-  int test = 1;
-  int *addr = &test;
-  void *input;
-  void *output;
-  input = addr;
-  output = addr;
-  input_data_addrs.push_back(&input);
-  output_data_addrs.push_back(output);
-
-  tagCcAICPUTensor input_desc;
-  tagCcAICPUTensor output_desc;
-  input_descs.push_back(input_desc);
-  output_descs.push_back(output_desc);
-
-  Status ret = kernel_task_info->StoreInputOutputTensor(input_data_addrs, output_data_addrs, input_descs, output_descs);
-  EXPECT_EQ(SUCCESS, ret);
-  ret = kernel_task_info->Release();
-  EXPECT_EQ(SUCCESS, ret);
-  delete kernel_task_info;
-  delete task_def;
-}
-
-// test init EventRecordTaskInfo
-TEST_F(UtestModelManagerDavinciModel, success_event_record_task_init) {
-  DavinciModel *model1 = nullptr;
-  TaskDef *task_def1 = new TaskDef();
-  EventRecordTaskInfo *eventRecordTaskInfo1 = new EventRecordTaskInfo();
-  Status ret1 = eventRecordTaskInfo1->Init(task_def1[0], model1);
-  EXPECT_EQ(PARAM_INVALID, ret1);
-
-  delete eventRecordTaskInfo1;
-  delete task_def1;
-  delete model1;
-  DavinciModel model(0, g_label_call_back);
-
-  ModelTaskDef *model_task_info = new ModelTaskDef();
-  TaskDef *task = model_task_info->add_task();
-  task->set_type(RT_MODEL_TASK_EVENT_RECORD);
-  TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task->type()));
-
-  task->stream_id_ = 0;
-  rtStream_t rt_stream;
-  rtStreamCreate(&rt_stream, 1);
-  vector<rtStream_t> stream_list;
-  stream_list.push_back(rt_stream);
-  model.stream_list_ = stream_list;
-
-  task->set_event_id(1);
-  model.runtime_param_.event_num = 1;
-  Status ret = task_info->Init(task[0], &model);
-  EXPECT_EQ(ge::INTERNAL_ERROR, ret);
-
-  model.runtime_param_.event_num = 2;
-  rtEvent_t event1;
-  rtEvent_t event2;
-  rtEventCreate(&event1);
-  rtEventCreate(&event2);
-  model.event_list_.push_back(event1);
-  model.event_list_.push_back(event2);
-
-  EventExDef *event_ex_def = task->mutable_event_ex();
-  event_ex_def->set_event_type(1);
-
-  ret = task_info->Init(task[0], &model);
-  EXPECT_EQ(SUCCESS, ret);
-
-  task->clear_event_ex();
-  task_info->Release();
-  delete model_task_info;
-}
-
-// test init EventWaitTaskInfo
-TEST_F(UtestModelManagerDavinciModel, success_event_wait_task_init) {
-  DavinciModel *model1 = nullptr;
-  TaskDef *task_def1 = new TaskDef();
-  EventWaitTaskInfo *event_wait_task_info1 = new EventWaitTaskInfo();
-  Status ret1 = event_wait_task_info1->Init(task_def1[0], model1);
-  EXPECT_EQ(PARAM_INVALID, ret1);
-
-  delete event_wait_task_info1;
-  delete task_def1;
-  delete model1;
-  DavinciModel model(0, g_label_call_back);
-
-  ModelTaskDef *model_task_info = new ModelTaskDef();
-  TaskDef *task = model_task_info->add_task();
-  task->set_type(RT_MODEL_TASK_EVENT_WAIT);
-  TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task->type()));
-
-  task->stream_id_ = 0;
-  rtStream_t rt_stream;
-  rtStreamCreate(&rt_stream, 1);
-  vector<rtStream_t> stream_list;
-  stream_list.push_back(rt_stream);
-  model.stream_list_ = stream_list;
-
-  task->set_event_id(1);
-  model.runtime_param_.event_num = 1;
-  Status ret = task_info->Init(task[0], &model);
-  EXPECT_EQ(ge::INTERNAL_ERROR, ret);
-
-  model.runtime_param_.event_num = 2;
-  rtEvent_t event1;
-  rtEvent_t event2;
-  rtEventCreate(&event1);
-  rtEventCreate(&event2);
-  model.event_list_.push_back(event1);
-  model.event_list_.push_back(event2);
-
-  EventExDef *event_ex_def = task->mutable_event_ex();
-  event_ex_def->set_event_type(1);
-
-  ret = task_info->Init(task[0], &model);
-  EXPECT_EQ(SUCCESS, ret);
-
-  task->clear_event_ex();
-  task_info->Release();
-  delete model_task_info;
-}
-
-// test fusion_start_task Init
-TEST_F(UtestModelManagerDavinciModel, success_fusion_start_task_init) {
-  DavinciModel *model1 = nullptr;
-  TaskDef *task_def1 = new TaskDef();
-  FusionStartTaskInfo *fusion_start_task_info1 = new FusionStartTaskInfo();
-  Status ret1 = fusion_start_task_info1->Init(task_def1[0], model1);
-  EXPECT_EQ(PARAM_INVALID, ret1);
-
-  delete fusion_start_task_info1;
-  delete task_def1;
-  delete model1;
-  DavinciModel model(0, g_label_call_back);
-  TaskDef *task_def = new TaskDef();
-  FusionStartTaskInfo *fusion_start_task_info = new FusionStartTaskInfo();
-  task_def->set_stream_id(0);
-  rtStream_t stream;
-  rtStreamCreate(&stream, 0);
-  model.stream_list_.push_back(stream);
-
-  Status ret = fusion_start_task_info->Init(task_def[0], &model);
-  EXPECT_EQ(SUCCESS, ret);
-  delete fusion_start_task_info;
-  delete task_def;
-}
-
-// test fusion_end_task Init
-TEST_F(UtestModelManagerDavinciModel, success_fusion_end_task_rinit) {
-  DavinciModel *model1 = nullptr;
-  TaskDef *task_def1 = new TaskDef();
-  FusionStopTaskInfo *fusion_stop_task_info1 = new FusionStopTaskInfo();
-  Status ret1 = fusion_stop_task_info1->Init(task_def1[0], model1);
-  EXPECT_EQ(PARAM_INVALID, ret1);
-
-  delete fusion_stop_task_info1;
-  delete task_def1;
-  delete model1;
-  DavinciModel model(0, g_label_call_back);
-  TaskDef *task_def = new TaskDef();
-  FusionStopTaskInfo *fusion_stop_task_info = new FusionStopTaskInfo();
-  task_def->set_stream_id(0);
-  rtStream_t stream;
-  rtStreamCreate(&stream, 0);
-  model.stream_list_.push_back(stream);
-
-  Status ret = fusion_stop_task_info->Init(task_def[0], &model);
-  EXPECT_EQ(SUCCESS, ret);
-  delete fusion_stop_task_info;
-  delete task_def;
-}
-
-// test kernel_ex_task_Release
-TEST_F(UtestModelManagerDavinciModel, success_kernel_ex_task_release) {
-  KernelExTaskInfo *kernel_ex_task_info = new KernelExTaskInfo();
-  Status ret = kernel_ex_task_info->Release();
-  EXPECT_EQ(SUCCESS, ret);
-
-  delete kernel_ex_task_info;
-}
-
-// test hccl_Distribute
-TEST_F(UtestModelManagerDavinciModel, success_Distribute7) {
-  DavinciModel model(0, g_label_call_back);
-
-  ModelTaskDef *model_task_def = new ModelTaskDef();
-  TaskDef *task7 = model_task_def->add_task();
-  task7->set_type(RT_MODEL_TASK_HCCL);
-  TaskInfoPtr task_info7 = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task7->type()));
-  Status ret = task_info7->Init(task7[0], &model);
-  EXPECT_EQ(FAILED, ret);
-
-  std::vector<TaskInfoPtr> task_list;
-  task_list.push_back(task_info7);
-  model.task_list_ = task_list;
-
-  task_info7->Release();
-  delete model_task_def;
-}
-
-// test hccl_GetPrivateDefByTaskDef
-TEST_F(UtestModelManagerDavinciModel, success_hccl_get_private_def_by_task_def) {
-  DavinciModel model(0, g_label_call_back);
-
-  ModelTaskDef *model_task_def = new ModelTaskDef();
-  TaskDef *task7 = model_task_def->add_task();
-  task7->set_type(RT_MODEL_TASK_HCCL);
-  // for SetStream
-  rtStream_t stream = nullptr;
-  rtStreamCreate(&stream, 0);
-  model.stream_list_.push_back(stream);
-  // for GetPrivateDefByTaskDef
-  task7->set_ops_kernel_store_ptr(10);
-  std::string value = "hccl_task";
-  task7->set_private_def(value);
-
-  TaskInfoPtr task_info7 = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task7->type()));
-  // for Distribute
-  Status ret = task_info7->Init(task7[0], &model);
-  EXPECT_EQ(ge::PARAM_INVALID, ret);
-
-  task_info7->Release();
-  delete model_task_def;
-}
-
-// test hccl_task_TransToGETaskInfo
-TEST_F(UtestModelManagerDavinciModel, success_hccl_trans_to_ge_task_info) {
-  DavinciModel model(0, g_label_call_back);
-
-  ModelTaskDef *model_task_def = new ModelTaskDef();
-  TaskDef *task7 = model_task_def->add_task();
-  // for type
-  task7->set_type(RT_MODEL_TASK_HCCL);
-  TaskInfoPtr task_info7 = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task7->type()));
-
-  GETaskInfo ge_task;
-  HcclTaskInfo *hccl_task_info = new HcclTaskInfo();
-  hccl_task_info->TransToGETaskInfo(ge_task);
-
-  delete hccl_task_info;
-  delete model_task_def;
-}
-
-// test stream_active_task Init
-TEST_F(UtestModelManagerDavinciModel, success_stream_active_task_init) {
-  DavinciModel *model1 = nullptr;
-  TaskDef *task_def1 = new TaskDef();
-  StreamActiveTaskInfo *stream_active_task_info1 = new StreamActiveTaskInfo();
-  Status ret1 = stream_active_task_info1->Init(task_def1[0], model1);
-  EXPECT_EQ(PARAM_INVALID, ret1);
-  delete stream_active_task_info1;
-  delete task_def1;
-  delete model1;
-
-  DavinciModel model(0, g_label_call_back);
-  TaskDef *task_def = new TaskDef();
-  task_def->set_stream_id(0);
-  rtStream_t stream1, stream2;
-  rtStreamCreate(&stream1, 0);
-  rtStreamCreate(&stream2, 0);
-  model.stream_list_.push_back(stream1);
-
-  StreamActiveTaskInfo *stream_active_task_info = new StreamActiveTaskInfo();
-
-  StreamActiveDef *stream_active_def = task_def->mutable_stream_active();
-  stream_active_def->set_op_index(0);
-  stream_active_def->set_active_stream_id(0);
-
-  std::map<uint32_t, uint32_t> flowctrl;
-  flowctrl.insert(pair<uint32_t, uint32_t>(1, 1));
-  model.flowctrl_op_index_internal_map_ = flowctrl;
-
-  auto opDef = CreateOpDesc("", "");
-  model.op_list_[0] = opDef;
-
-  Status ret = stream_active_task_info->Init(task_def[0], &model);
-  EXPECT_EQ(ge::INTERNAL_ERROR, ret);  // line 51
-
-  model.stream_list_.push_back(stream2);
-  ret = stream_active_task_info->Init(task_def[0], &model);
-  EXPECT_EQ(SUCCESS, ret);
-
-  task_def->clear_stream_active();
-  delete stream_active_task_info;
-  delete task_def;
-}
-
-// test label_set_task Init
-TEST_F(UtestModelManagerDavinciModel, success_label_set_task_init) {
-  DavinciModel *model1 = nullptr;
-  TaskDef *task_def1 = new TaskDef();
-  LabelSetTaskInfo *label_set_task_info1 = new LabelSetTaskInfo();
-  Status ret1 = label_set_task_info1->Init(task_def1[0], model1);
-  EXPECT_EQ(PARAM_INVALID, ret1);
-  delete label_set_task_info1;
-  delete task_def1;
-  delete model1;
-
-  DavinciModel model(0, g_label_call_back);
-  TaskDef *task_def = new TaskDef();
-  LabelSetTaskInfo *label_set_task_info = new LabelSetTaskInfo();
-  task_def->set_stream_id(0);
-  rtStream_t stream;
-  rtStreamCreate(&stream, 0);
-  model.stream_list_.push_back(stream);
-
-  task_def->set_label_id(1);
-  model.runtime_param_.batch_num = 0;
-  Status ret = label_set_task_info->Init(task_def[0], &model);
-  EXPECT_EQ(PARAM_INVALID, ret);
-
-  task_def->clear_label_id();
-  task_def->set_label_id(0);
-  model.runtime_param_.batch_num = 1;
-  rtLabel_t label;
-  rtLabelCreate(&label);
-  model.label_list_.push_back(label);
-
-  ret = label_set_task_info->Init(task_def[0], &model);
-  EXPECT_EQ(SUCCESS, ret);
-  delete label_set_task_info;
-  delete task_def;
-}
-
label_goto_task init -TEST_F(UtestModelManagerDavinciModel, success_label_goto_task_init) { - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - LabelGotoTaskInfo *label_goto_task_info = new LabelGotoTaskInfo(); - task_def->set_stream_id(0); - - rtStream_t stream; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - - rtLabel_t label; - rtLabelCreate(&label); - model.label_list_.push_back(label); - - Status ret = label_goto_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, ret); - - delete label_goto_task_info; - delete task_def; -} - -// test profiler_trace_task init -TEST_F(UtestModelManagerDavinciModel, success_profiler_trace_task_init) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - ProfilerTraceTaskInfo *profiler_trace_task_info1 = new ProfilerTraceTaskInfo(); - Status ret1 = profiler_trace_task_info1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - - delete profiler_trace_task_info1; - delete task_def1; - delete model1; - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - task_def->set_stream_id(0); - rtStream_t stream; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - LogTimeStampDef *logTimeStampDef = task_def->mutable_log_timestamp(); - logTimeStampDef->set_logid(1); - logTimeStampDef->set_notify(1); - logTimeStampDef->set_flat(1); - ProfilerTraceTaskInfo *profiler_trace_task_info = new ProfilerTraceTaskInfo(); - Status ret = profiler_trace_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, ret); - - task_def->clear_log_timestamp(); - delete profiler_trace_task_info; - delete task_def; -} - -TEST_F(UtestModelManagerDavinciModel, profiling_model_success) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - model.model_id_ = 1; - model.name_ = "test"; - model.version_ = 0x01; - - model.stream_list_.push_back(stream); - - ge::ModelData data; - rtMallocHost(&data.model_data, 128); - data.model_len = 128; - - ModelDef *model_def = new ModelDef(); - auto op_def = CreateOpDesc("", "Data"); - op_def->SetInputOffset({1}); - op_def->SetOutputOffset({100}); - - ge::GeTensorDesc descin(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(descin, 4); - op_def->AddInputDesc(descin); - ge::GeTensorDesc desc_out(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out, 32); - op_def->AddInputDesc(desc_out); - op_def->SetId(0); - - model.data_op_list_.push_back(op_def); - model.op_list_[0] = op_def; - - auto opdef1 = CreateOpDesc("", "Relu"); - opdef1->SetInputOffset({1}); - opdef1->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in1(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in1, 4); - opdef1->AddInputDesc(desc_in1); - ge::GeTensorDesc desc_out1(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out1, 32); - opdef1->AddInputDesc(desc_out1); - op_def->SetId(1); - - model.op_list_[1] = opdef1; - - auto opdef2 = CreateOpDesc("", "Relu"); - opdef2->SetInputOffset({1}); - opdef2->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in2(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in2, 4); - opdef2->AddInputDesc(desc_in2); - ge::GeTensorDesc desc_out2(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out2, 32); - 
opdef2->AddInputDesc(desc_out2); - op_def->SetId(2); - - model.op_list_[2] = opdef2; - - auto opdef3 = CreateOpDesc("", "Relu"); - opdef3->SetInputOffset({1}); - opdef3->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in3(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in3, 4); - opdef3->AddInputDesc(desc_in3); - ge::GeTensorDesc desc_out3(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out3, 32); - opdef3->AddInputDesc(desc_out3); - op_def->SetId(3); - - model.op_list_[3] = opdef3; - - auto opdef4 = CreateOpDesc("", "Relu"); - opdef4->SetInputOffset({1}); - opdef4->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in4(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in4, 4); - opdef4->AddInputDesc(desc_in4); - ge::GeTensorDesc desc_out4(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out4, 32); - opdef4->AddInputDesc(desc_out4); - op_def->SetId(4); - - model.op_list_[4] = opdef4; - - ge::InputData input_data; - ge::DataBuffer data_buffer; - data_buffer.data = new char[4]; - data_buffer.length = 4; - input_data.index = 0; - input_data.model_id = 1; - input_data.blobs.push_back(data_buffer); - - rtFreeHost(data.model_data); - delete[](char *) data_buffer.data; - delete model_def; -} - -TEST_F(UtestModelManagerDavinciModel, success_output_list_0) { - DavinciModel model(0, g_label_call_back); - - uint32_t version = 0; - uint64_t session_id = 0; - uint32_t device_id = 0; - uint64_t job_id = 0; - Status ret = VarManager::Instance(session_id)->Init(version, session_id, device_id, job_id); - EXPECT_EQ(ret, ge::SUCCESS); - - ret = model.ReturnNoOutput(1); - EXPECT_EQ(ret, ge::SUCCESS); - - VarManagerPool::Instance().Destroy(); -} - -// test dyncbatch_distributeTask_SUCCESS -TEST_F(UtestModelManagerDavinciModel, dyncbatch_distribute_task_success) { - DavinciModel model(0, g_label_call_back); - - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - rtLabel_t label = nullptr; - rtLabelCreate(&label); - model.label_list_.push_back(label); - rtLabelCreate(&label); - model.label_list_.push_back(label); - rtLabelCreate(&label); - model.label_list_.push_back(label); - - rtLabelDestroy(label); - rtStreamDestroy(stream); -} - -// test GetOutputDescInfo -TEST_F(UtestModelManagerDavinciModel, success_get_output_desc_info_with_netoutput) { - setenv("GE_TRAIN", "1", true); - DavinciModel model(0, g_label_call_back); - - auto op_desc = CreateOpDesc("Data", "Data"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - op_desc->SetStreamId(0); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - - op_desc->SetSrcName({"Pooling1", "Pooling0"}); - op_desc->SetSrcIndex({0, 1}); - - auto compute_graph = make_shared("g"); - - op_desc->SetType("NetOutput"); - - auto net_out_node = compute_graph->AddNode(op_desc); - model.op_list_[0] = op_desc; - - model.output_op_list_.push_back(op_desc); - model.output_size_list_.push_back(32); - model.output_memory_size_list_.push_back(64); - - vector output_shapes; 
- vector formats; - EXPECT_EQ(ge::SUCCESS, model.GetOutputDescInfo(output_shapes, formats)); - - setenv("GE_TRAIN", "0", true); -} - -TEST_F(UtestModelManagerDavinciModel, device_runtime_success_Run) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - - model.stream_list_.push_back(stream); - auto model_def = make_shared(); - - auto op_def = CreateOpDesc("", "Data"); - - auto compute_graph = make_shared("g"); - compute_graph->AddNode(op_def); - - model_def->SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph)); - - model.data_op_list_.push_back(op_def); - - model.data_inputer_ = new DataInputer(); - - model.ModelRunStart(); - - OutputData output_data; - ge::InputData input_data; - - ge::DataBuffer data_buffer; - data_buffer.data = new char[16]; - data_buffer.length = 16; - - input_data.index = 0; - input_data.model_id = 1; - input_data.blobs.push_back(data_buffer); - - model.ModelRunStop(); - - delete[](char *) data_buffer.data; -} - -TEST_F(UtestModelManagerDavinciModel, run_failed) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - - model.stream_list_.push_back(stream); - auto model_def = make_shared(); - - auto op_def = CreateOpDesc("", "Data"); - - auto compute_graph = make_shared("g"); - compute_graph->AddNode(op_def); - - model_def->SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph)); - - model.data_op_list_.push_back(op_def); - - model.data_inputer_ = new DataInputer(); - - model.ModelRunStart(); - - OutputData output_data; - ge::InputData input_data; - - ge::DataBuffer data_buffer; - data_buffer.data = new char[16]; - data_buffer.length = 16; - - input_data.index = 0; - input_data.model_id = 1; - input_data.blobs.push_back(data_buffer); - - model.ModelRunStop(); - delete[](char *) data_buffer.data; -} - -TEST_F(UtestModelManagerDavinciModel, run_failed01) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - - model.stream_list_.push_back(stream); - auto model_def = make_shared(); - - auto op_def = CreateOpDesc("", "Data"); - - auto compute_graph = make_shared("g"); - compute_graph->AddNode(op_def); - - model_def->SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph)); - - model.data_op_list_.push_back(op_def); - - model.data_inputer_ = nullptr; - model.ModelRunStart(); - - model.ModelRunStop(); -} - -TEST_F(UtestModelManagerDavinciModel, init_tbe_handle_fe_registered) { - DavinciModel::tvm_bin_kernel_.clear(); - DavinciModel model(0, g_label_call_back); - OpDescPtr op_desc = CreateOpDesc("MatMul", "MatMul"); - - std::vector kernelBin; - TBEKernelPtr tbe_kernel = std::make_shared("name/MatMul", std::move(kernelBin)); - op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); - - std::string kernel_name("kernel/MatMul"); - AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); - - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - - EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); - DavinciModel::tvm_bin_kernel_.clear(); -} - -TEST_F(UtestModelManagerDavinciModel, init_tbe_handle_ge_registered) { - DavinciModel::tvm_bin_kernel_.clear(); - DavinciModel model(0, g_label_call_back); - OpDescPtr op_desc = CreateOpDesc("MatMul", "MatMul"); - - std::vector kernelBin; - TBEKernelPtr tbe_kernel = std::make_shared("name/MatMul", std::move(kernelBin)); - 
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); - - std::string kernel_name("kernel/MatMul"); - AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); - - string session_graph_id; - AttrUtils::GetStr(op_desc, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); - const char *bin_file_key = DavinciModel::GetRegisterStub(op_desc->GetName(), session_graph_id); - model.used_tbe_handle_map_[bin_file_key] = 1; // test first register. - - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - - EXPECT_EQ(model.used_tbe_handle_map_.size(), 1); - - auto it = model.used_tbe_handle_map_.find(bin_file_key); - EXPECT_NE(it, model.used_tbe_handle_map_.end()); - EXPECT_EQ(it->second, 3); - DavinciModel::tvm_bin_kernel_.clear(); -} -} // namespace ge diff --git a/tests/ut/ge/graph/load/new_model_manager_event_manager_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_event_manager_unittest.cc deleted file mode 100644 index ee708501..00000000 --- a/tests/ut/ge/graph/load/new_model_manager_event_manager_unittest.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/types.h" - -#define private public -#include "graph/manager/model_manager/event_manager.h" -#undef private - -using namespace ge; -using namespace std; -using namespace testing; - -class UtestModelManagerEventManager : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -// test repeat initialize -TEST_F(UtestModelManagerEventManager, repeat_initialization) { - ge::EventManager event_manager; - size_t event_num = 1; - event_manager.Init(event_num); - Status ret = event_manager.Init(event_num); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UtestModelManagerEventManager, call_event_record_normal) { - ge::EventManager event_manager; - size_t event_num = 1; - Status ret = event_manager.Init(event_num); - EXPECT_EQ(SUCCESS, ret); - EXPECT_NE(event_manager.event_list_.size(), 0); - - ret = event_manager.EventRecord(0, NULL); - EXPECT_EQ(SUCCESS, ret); -} - -// test load EventRecore when uninited -TEST_F(UtestModelManagerEventManager, call_event_record_while_uninited) { - ge::EventManager event_manager; - Status ret = event_manager.EventRecord(1, NULL); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); -} - -// test with invalid param when load EventRecord -TEST_F(UtestModelManagerEventManager, call_event_record_with_invalid_param) { - ge::EventManager event_manager; - Status ret = event_manager.Init(1); - EXPECT_EQ(SUCCESS, ret); - ret = event_manager.EventRecord(1, NULL); - EXPECT_EQ(ge::PARAM_INVALID, ret); -} - -// test load EventElapsedTime when uninited -TEST_F(UtestModelManagerEventManager, call_event_elapsed_time_while_uninited) { - ge::EventManager event_manager; - float time = .0f; - Status ret = event_manager.EventElapsedTime(1, 2, time); - 
EXPECT_EQ(ge::INTERNAL_ERROR, ret); -} - -// test with invalid param when load EventElapsedTime -TEST_F(UtestModelManagerEventManager, call_event_elapsed_time_with_invalid_param) { - ge::EventManager *event_manager = new ge::EventManager; - size_t event_num = 2; - Status ret = event_manager->Init(event_num); - EXPECT_EQ(SUCCESS, ret); - float time = .0f; - - // normal load - ret = event_manager->EventElapsedTime(0, 1, time); - EXPECT_EQ(SUCCESS, ret); - - // startevent_idx overstep boundary - ret = event_manager->EventElapsedTime(2, 1, time); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - // stopevent_idx overstep boundary - ret = event_manager->EventElapsedTime(1, 2, time); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - // startevent_idx > stopevent_idx - ret = event_manager->EventElapsedTime(1, 0, time); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - delete event_manager; -} -TEST_F(UtestModelManagerEventManager, call_get_event) { - ge::EventManager event_manager; - size_t event_num = 1; - event_manager.Init(event_num); - rtEvent_t event = nullptr; - Status ret = event_manager.GetEvent(2, event); - EXPECT_EQ(ge::PARAM_INVALID, ret); - ret = event_manager.GetEvent(0, event); - EXPECT_EQ(SUCCESS, ret); -} diff --git a/tests/ut/ge/graph/load/new_model_manager_task_build_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_task_build_unittest.cc deleted file mode 100644 index f10ccd7f..00000000 --- a/tests/ut/ge/graph/load/new_model_manager_task_build_unittest.cc +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/types.h" -#include "new_op_test_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/attr_utils.h" -#include "graph/detail/model_serialize_imp.h" -#include "proto/ge_ir.pb.h" - -#define private public -#define protected public -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "graph/model_serialize.h" -#include "graph/load/model_manager/davinci_model.h" -#include "common/properties_manager.h" -#include "common/op/ge_op_utils.h" -#include -#include "runtime/dev.h" -#include "runtime/kernel.h" -#include "cce/fwk_adpt_struct.h" -#undef private -#undef protected - -using namespace std; -using namespace testing; - -namespace ge { -class UtestModelManagerTaskBuilder : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} - - /// data weight - /// | | | | - /// |-conv-| | | - /// | | | - /// conv2d | - /// | | - /// |-resApply - - void BuildGraph(ComputeGraphPtr graph) { - OpDescPtr data = std::make_shared("DATA1", "data"); - OpDescPtr weight = std::make_shared("WEIGHT", "weight"); - OpDescPtr conv_op = std::make_shared("conv", "conv"); - OpDescPtr conv_2D = std::make_shared("conv_2D", "conv2d"); - OpDescPtr res_apply_op = std::make_shared("res_apply_op", "resapply"); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - int32_t blockSize = 4096; - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); - data->AddOutputDesc(out_desc); - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); - weight->AddOutputDesc(out_desc); - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); - conv_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); - conv_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 3); - conv_op->AddOutputDesc(out_desc); - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 3); - conv_2D->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); - conv_2D->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 4); - conv_2D->AddOutputDesc(out_desc); - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 4); - res_apply_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); - res_apply_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 5); - res_apply_op->AddOutputDesc(out_desc); - - NodePtr data_node = graph->AddNode(data); - NodePtr weigth_node = graph->AddNode(weight); - NodePtr conv_node = graph->AddNode(conv_op); - NodePtr conv_2D_node = graph->AddNode(conv_2D); - NodePtr res_node = graph->AddNode(res_apply_op); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv_2D_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), conv_2D_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_2D_node->GetOutDataAnchor(0), res_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), res_node->GetInDataAnchor(1)); - return; - } -}; -} // namespace ge diff --git a/tests/ut/ge/graph/load/output_net_output_unittest.cc b/tests/ut/ge/graph/load/output_net_output_unittest.cc deleted file mode 100644 index 97246dad..00000000 --- 
a/tests/ut/ge/graph/load/output_net_output_unittest.cc +++ /dev/null @@ -1,300 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "securec.h" - -#define protected public -#define private public -#include "common/debug/memory_dumper.h" -#include "common/op/ge_op_utils.h" -#include "graph/load/model_manager/davinci_model.h" -#include "graph/load/model_manager/model_utils.h" -#include "graph/manager/graph_var_manager.h" -#include "new_op_test_utils.h" -#include "proto/om.pb.h" - -using namespace std; - -namespace ge { -class UtestNetOutput : public testing::Test { - protected: - void TearDown() {} - shared_ptr GenOpdef(OpDescPtr &op_desc, int flag) { - shared_ptr builder = make_shared(op_desc); - builder->SetStreamId(0); - builder->AddInput(1); - builder->SetType("NetOutput"); - - if (flag == 1) { - auto input_desc_1 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); - } - auto input_desc_1 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); - - if (flag == 2) { - auto input_desc_2 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); - } - if (flag == 3) { - builder->AddInput(10); - } - - return builder; - } - shared_ptr GenOpdef2(OpDescPtr &op_desc) { - shared_ptr builder = make_shared(op_desc); - builder->SetStreamId(0); - builder->SetType("NetOutput"); - builder->AddInput(10); - - auto input_desc_1 = builder->AddInputDesc({64, 32, 5, 5}, FORMAT_FRACTAL_Z, DT_FLOAT); - - builder->AddInput(1000000); - auto input_desc_2 = builder->AddInputDesc({1, 10, 10, 1}, FORMAT_NHWC, DT_FLOAT); - - builder->AddOutput(2000000); - auto output_desc_1 = builder->AddOutputDesc({64, 32, 5, 5}, FORMAT_NCHW, DT_FLOAT); - - builder->AddOutput(2100000); - output_desc_1 = builder->AddOutputDesc({1, 10, 10, 1}, FORMAT_NHWC, DT_FLOAT); - - return builder; - } - - public: - shared_ptr dav_model_; -}; - -TEST_F(UtestNetOutput, test_get_input_size) { - shared_ptr custom_op_desc = make_shared(); - OmeTestOpDescBuilder builder(custom_op_desc); - builder.SetName("netoutput"); - builder.SetStreamId(0); - builder.SetType("NetOutput"); - - auto input_desc_1 = builder.AddInputDesc({1, 1, 1, 1}, FORMAT_FRACTAL_Z, DT_FLOAT); - builder.AddInput(1); - auto output_desc = builder.AddOutputDesc({1, 1, 1, 1}, FORMAT_NCHW, DT_FLOAT); - builder.AddOutput(1); - builder.Finish(); - - vector v_output_size = ModelUtils::GetInputSize(custom_op_desc); - EXPECT_EQ(v_output_size.size(), 1); -} - -// test ModelUtils::IsOutput -TEST_F(UtestNetOutput, success_is_output) { - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - OmeTestOpDescBuilder builder(op_desc); - builder.SetType("NetOutput"); - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - bool ret = model_utils->IsOutput(op_desc); - EXPECT_EQ(false, ret); - - delete model_utils; -} - -// test 
ModelUtils::IsOutput -TEST_F(UtestNetOutput, true_is_output) { - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - OmeTestOpDescBuilder builder(op_desc); - builder.SetType("NetOutput"); - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - ge::TensorUtils::SetOutputTensor(*(outputs_desc[0].get()), true); - bool ret = model_utils->IsOutput(op_desc); - EXPECT_EQ(true, ret); - - delete model_utils; -} - -// test ModelUtils::IsInputTensorNeedTrans -TEST_F(UtestNetOutput, success_is_output_tensor_need_trans) { - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - OmeTestOpDescBuilder builder(op_desc); - builder.SetType("NetOutput"); - size_t tensor_index = 1; - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - op_desc->inputs_desc_ = outputs_desc; - - bool ret = model_utils->IsInputTensorNeedTrans(op_desc, tensor_index); - EXPECT_EQ(false, ret); - - delete model_utils; -} - -// test ModelUtils::GetOutputSize -TEST_F(UtestNetOutput, success_get_output_size) { - vector v_output_size; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - EXPECT_EQ(v_output_size, model_utils->GetOutputSize(op_desc)); - - vector output = {1}; - op_desc->SetOutputOffset(output); - uint32_t tensor_size = 0; - v_output_size.push_back(tensor_size); - EXPECT_EQ(v_output_size, model_utils->GetOutputSize(op_desc)); - delete model_utils; -} - -// test ModelUtils::GetWorkspaceSize -TEST_F(UtestNetOutput, success_get_workspace_size) { - vector v_workspace_size; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - vector workspace = {1}; - op_desc->SetWorkspace(workspace); - EXPECT_EQ(v_workspace_size, model_utils->GetWorkspaceSize(op_desc)); - - op_desc->SetWorkspaceBytes(workspace); - v_workspace_size.push_back(1); - EXPECT_EQ(v_workspace_size, model_utils->GetWorkspaceSize(op_desc)); - delete model_utils; -} - -// test ModelUtils::GetWeightSize -TEST_F(UtestNetOutput, success_get_weight_size) { - vector v_weight_size; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - op_desc->SetType("Const"); - EXPECT_EQ(v_weight_size, model_utils->GetWeightSize(op_desc)); - - op_desc->SetType("NetOutput"); - vector inputs_desc; - std::shared_ptr desc = std::make_shared(); - inputs_desc.push_back(desc); - op_desc->inputs_desc_ = inputs_desc; - - vector is_input_const = {true}; - op_desc->SetIsInputConst(is_input_const); - v_weight_size.push_back(0); - EXPECT_EQ(v_weight_size, model_utils->GetWeightSize(op_desc)); - - delete model_utils; -} - -// test ModelUtils::GetWeights -TEST_F(UtestNetOutput, success_get_weights) { - vector v_weights; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - op_desc->SetType("Const"); - EXPECT_EQ(v_weights, model_utils->GetWeights(op_desc)); - - op_desc->SetType("NetOutput"); - vector inputs_desc; - std::shared_ptr desc = std::make_shared(); - inputs_desc.push_back(desc); - op_desc->inputs_desc_ = inputs_desc; - - vector is_input_const = {true}; - op_desc->SetIsInputConst(is_input_const); - GeTensorDesc tensor_desc; - 
EXPECT_EQ(v_weights, model_utils->GetWeights(op_desc)); - - delete model_utils; -} - -// test ModelUtils::GetInputDescs -TEST_F(UtestNetOutput, success_get_input_descs) { - vector<::opTensor_t> v_input_descs; - vector<::tagCcAICPUTensor> ret; - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - ret = model_utils->GetInputDescs(op_desc); - EXPECT_EQ(v_input_descs.size(), ret.size()); - - vector inputs_desc; - std::shared_ptr desc = std::make_shared(); - inputs_desc.push_back(desc); - op_desc->inputs_desc_ = inputs_desc; - vector is_input_const = {false}; - op_desc->SetIsInputConst(is_input_const); - - opTensor_t tmp; - tmp.format = OP_TENSOR_FORMAT_NC1HWC0; - tmp.dim_cnt = 0; - tmp.data_type = OP_DATA_FLOAT; - v_input_descs.push_back(tmp); - ret = model_utils->GetInputDescs(op_desc); - EXPECT_EQ(v_input_descs.size(), ret.size()); - - delete model_utils; -} - -// test ModelUtils::GetOutputDescs -TEST_F(UtestNetOutput, success_get_output_descs) { - vector<::opTensor_t> v_output_descs; - vector<::tagCcAICPUTensor> ret; - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - ret = model_utils->GetOutputDescs(op_desc); - EXPECT_EQ(v_output_descs.size(), ret.size()); - - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - - opTensor_t tmp; - tmp.format = OP_TENSOR_FORMAT_NC1HWC0; - tmp.dim_cnt = 0; - tmp.data_type = OP_DATA_FLOAT; - v_output_descs.push_back(tmp); - ret = model_utils->GetOutputDescs(op_desc); - EXPECT_EQ(v_output_descs.size(), ret.size()); - - delete model_utils; -} - -// test Output::GetOutputData -TEST_F(UtestNetOutput, success_get_output_data) { - Output *output = new Output(nullptr, nullptr); - output->v_input_data_addr_.push_back((void *)1); - output->v_input_size_.push_back(1); - output->input_num_ = 1; - - vector v_data_addr; - vector v_data_size; - output->GetOutputData(v_data_addr, v_data_size); - - EXPECT_EQ(output->v_input_data_addr_, v_data_addr); - EXPECT_EQ(output->v_input_size_, v_data_size); - delete output; -} -} // namespace ge diff --git a/tests/ut/ge/graph/manager/graph_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_manager_unittest.cc index 9bae10eb..b40690e2 100644 --- a/tests/ut/ge/graph/manager/graph_manager_unittest.cc +++ b/tests/ut/ge/graph/manager/graph_manager_unittest.cc @@ -15,20 +15,9 @@ */ #include + #include #include -#define protected public -#define private public -#include "graph/manager/graph_manager.h" -#include "graph/load/model_manager/model_manager.h" -#include "graph/load/model_manager/davinci_model.h" -#define const -#include "common/helper/model_cache_helper.h" -#undef const -#include "init/gelib.h" -#undef private -#undef public - #include #include #include @@ -38,13 +27,18 @@ #include #include +#define protected public +#define private public +#include "graph/manager/graph_manager.h" +#include "init/gelib.h" + #include "common/math/math_util.h" #include "common/thread_pool.h" #include "common/dump/dump_manager.h" #include "analyzer/analyzer.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" -#include "graph/common/transop_util.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" +#include "common/transop_util.h" #include "graph/ge_context.h" #include "graph/ge_global_options.h" #include "graph/manager/util/rt_context_util.h" @@ -111,8 +105,8 @@ #include "graph/utils/tensor_adapter.h" #include 
"inc/pass_manager.h" #include "ir_build/option_utils.h" -#include "graph/common/local_context.h" -#include "graph/common/omg_util.h" +#include "common/local_context.h" +#include "common/omg_util.h" #include "common/formats/utils/formats_trans_utils.h" #include "../passes/graph_builder_utils.h" #include "register/custom_pass_helper.h" @@ -121,7 +115,6 @@ using namespace std; using namespace testing; -using namespace ge; using namespace domi; namespace { @@ -129,6 +122,8 @@ const uint32_t kNotAdded = 0; const uint32_t kStartAdd = 1; const uint32_t kDoneAdded = 2; } + +namespace ge { class UtestGraphManagerTest : public testing::Test { protected: void SetUp() {} @@ -136,6 +131,31 @@ class UtestGraphManagerTest : public testing::Test { void TearDown() {} }; +class StubExecutor : public Executor { + public: + Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { + return SUCCESS; + } + + Status UnloadGraph(const GeRootModelPtr &ge_root_model, uint32_t graph_id) { + return SUCCESS; + } + + Status PushGraph(const RunArgs &args) { + return SUCCESS; + } + + Status RunGraph(const GraphNodePtr &graph_node, GraphId graph_id, + const std::vector &inputs, std::vector &outputs) { + return SUCCESS; + } + + Status RunGraphWithStream(const GraphNodePtr &graph_node, GraphId graph_id, rtStream_t stream, + const std::vector &inputs, std::vector &outputs){ + return SUCCESS; + } +}; + void CreateGraph(Graph &graph) { TensorDesc desc(ge::Shape({1, 3, 224, 224})); uint32_t size = desc.GetShape().GetShapeSize(); @@ -288,26 +308,20 @@ TEST_F(UtestGraphManagerTest, test_remove_graph_1) { TEST_F(UtestGraphManagerTest, test_remove_graph_2) { GraphId graph_id = 1; GraphManager graph_manager; + StubExecutor stub_executor; + graph_manager.executor_ = &stub_executor; + GraphNodePtr graph_node = MakeShared(graph_id); Graph graph("test_graph"); CreateGraph(graph); auto compute_graph = GraphUtils::GetComputeGraph(graph); GeRootModelPtr ge_root_model = MakeShared(compute_graph); - auto model_manager = ModelManager::GetInstance(); - auto listener = MakeShared(); - shared_ptr davinci_model1 = MakeShared(1, listener); - davinci_model1->SetId(1); - shared_ptr davinci_model2 = MakeShared(2, listener); - davinci_model1->SetId(2); - model_manager->InsertModel(1, davinci_model1); - model_manager->InsertModel(2, davinci_model2); ge_root_model->SetModelId(1); ge_root_model->SetModelId(2); graph_node->SetGeRootModel(ge_root_model); graph_node->SetLoadFlag(true); graph_manager.AddGraphNode(graph_id, graph_node); - Status status = graph_manager.RemoveGraph(graph_id); - EXPECT_EQ(status, ge::SUCCESS); + EXPECT_EQ(graph_manager.RemoveGraph(graph_id), SUCCESS); } TEST_F(UtestGraphManagerTest, test_pre_run_thread) { @@ -327,7 +341,7 @@ TEST_F(UtestGraphManagerTest, test_pre_run_thread) { GraphNodePtr graph_node = MakeShared(graph_id); graph_manager.AddGraphNode(graph_id, graph_node); - graph_manager.PreRunThread(&graph_manager); + graph_manager.PreRunThread(); // end with failed } @@ -355,48 +369,10 @@ TEST_F(UtestGraphManagerTest, test_pre_run_thread_2) { graph_manager.AddGraphNode(graph_id, graph_node_2); ret = graph_manager.prerun_args_q_.Push({graph_id, input_tensor, session_id, error_context, context, callback}); EXPECT_EQ(ret, true); - graph_manager.PreRunThread(&graph_manager); + graph_manager.PreRunThread(); // end with failed } -TEST_F(UtestGraphManagerTest, test_check_and_release_memory) { - - GraphManager graph_manager; - GeModelPtr ge_model = make_shared(); - int64_t memory_size = 25 * 1024UL * 
1024UL * 1024UL; - int64_t weight_size = 25 * 1024UL * 1024UL * 1024UL; - uint64_t session_id = 0; - ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, memory_size); - ge::AttrUtils::SetInt(ge_model, ATTR_MODEL_WEIGHT_SIZE, weight_size); - ge::AttrUtils::SetInt(ge_model, MODEL_ATTR_SESSION_ID, session_id); - - - GraphId graph_id = 1; - GraphNodePtr graph_node = MakeShared(graph_id); - graph_manager.AddGraphNode(graph_id, graph_node); - graph_manager.IncreaseGraphCount(graph_id); - graph_manager.IncreaseGraphCount(graph_id); - - auto model_manager = ModelManager::GetInstance(); - auto listener = MakeShared(); - shared_ptr davinci_model1 = MakeShared(1, listener); - davinci_model1->SetId(1); - shared_ptr davinci_model2 = MakeShared(2, listener); - davinci_model1->SetId(2); - model_manager->InsertModel(1, davinci_model1); - model_manager->InsertModel(2, davinci_model2); - ComputeGraphPtr compute_graph = MakeShared("test_graph"); - bool is_dynamic_shape = false; - (void)AttrUtils::GetBool(compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); - GeRootModelPtr ge_root_model = MakeShared(compute_graph); - ge_root_model->SetModelId(1); - ge_root_model->SetModelId(2); - graph_node->SetGeRootModel(ge_root_model); - graph_node->SetLoadFlag(true); - Status status = graph_manager.CheckAndReleaseMemory(ge_model, graph_node); - EXPECT_EQ(status, ge::SUCCESS); -} - TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_1) { // no need to build GraphId graph_id = 1; @@ -406,7 +382,7 @@ TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_1) { GraphManager::PreRunArgs arg; GraphNodePtr graph_node = MakeShared(graph_id); graph_node->SetBuildFlag(true); - Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); + Status status = graph_manager.CheckIncreBuildAndPreRun(arg, graph_node, ge_root_model); EXPECT_EQ(status, ge::SUCCESS); } @@ -422,7 +398,7 @@ TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_2) { graph_node->SetBuildFlag(true); graph_node->Lock(); graph_manager.var_acc_ctrl_.graph_ids_need_rebuild_.insert(graph_id); - Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); + Status status = graph_manager.CheckIncreBuildAndPreRun(arg, graph_node, ge_root_model); EXPECT_EQ(status, ge::PARAM_INVALID); } @@ -437,7 +413,7 @@ TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) { GraphNodePtr graph_node = MakeShared(graph_id); graph_node->SetBuildFlag(false); graph_node->Lock(); - Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); + Status status = graph_manager.CheckIncreBuildAndPreRun(arg, graph_node, ge_root_model); EXPECT_NE(status, ge::SUCCESS); } @@ -471,14 +447,6 @@ TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_fail) { EXPECT_NE(status, ge::SUCCESS); } -TEST_F(UtestGraphManagerTest, ParseInputsDimsForData_success) { - GraphManager graph_manager; - std::vector input_tensors; - ge::Tensor tensor; - input_tensors.emplace_back(tensor); - graph_manager.ParseInputsDimsForData(input_tensors); -} - TEST_F(UtestGraphManagerTest, test_prerunthread_failed_1) { GraphId graph_id = 1; GraphManager graph_manager; @@ -509,7 +477,7 @@ TEST_F(UtestGraphManagerTest, test_prerunthread_failed_1) { graph_node->SetRunFlag(false); // function return. 
graph_manager.prerun_args_q_.Push(args); - auto t1 = std::thread(GraphManager::PreRunThread, &graph_manager); + auto t1 = std::thread(&GraphManager::PreRunThread, &graph_manager); if (t1.joinable()) { t1.join(); } @@ -549,7 +517,7 @@ TEST_F(UtestGraphManagerTest, test_prerunthread_failed_2) { int ret = setenv("ENABLE_NETWORK_ANALYSIS_DEBUG", "1", 1); EXPECT_EQ(ret, 0); graph_manager.prerun_args_q_.Push(args); - auto t1 = std::thread(GraphManager::PreRunThread, &graph_manager); + auto t1 = std::thread(&GraphManager::PreRunThread, &graph_manager); if (t1.joinable()) { t1.join(); } @@ -593,3 +561,4 @@ TEST_F(UtestGraphManagerTest, ChangeAndDeleteConst_success) { auto all_nodes = graph->GetDirectNode(); EXPECT_EQ(all_nodes.size(), 3); } +} // namespace ge diff --git a/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc new file mode 100644 index 00000000..c20e786d --- /dev/null +++ b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <gtest/gtest.h> +#include + +#define protected public +#define private public +#include "graph/manager/graph_var_manager.h" +#include "graph/ge_context.h" +#undef protected +#undef private + +namespace ge { +class UtestGraphVarManagerTest : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestGraphVarManagerTest, test_get_total_memory_size) { + size_t total_mem_size = 0; + Status ret = VarManager::Instance(0)->GetTotalMemorySize(total_mem_size); + EXPECT_EQ(total_mem_size, 1024UL * 1024UL * 1024UL); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_no_related_option) { + const map<string, string> options{}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f))); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f))); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_graph_mem_max_size) { + const map<string, string> options{{"ge.graphMemoryMaxSize", "536870912"}}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2)); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f))); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_var_mem_max_size) { + const map<string, string> options{{"ge.variableMemoryMaxSize", "536870912"}}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f))); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2));
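+  // 536870912 bytes is 2^29, i.e. exactly half of the 1 GiB total memory size used by these tests, so only the variable-memory budget moves here while graph memory keeps its default 26/32 share of the total.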
+ EXPECT_EQ(ret, SUCCESS); +} +} // namespace ge diff --git a/tests/ut/ge/graph/manager/hcom_util_unittest.cc b/tests/ut/ge/graph/manager/hcom_util_unittest.cc index 9f104f5f..4aeeddb9 100644 --- a/tests/ut/ge/graph/manager/hcom_util_unittest.cc +++ b/tests/ut/ge/graph/manager/hcom_util_unittest.cc @@ -94,4 +94,15 @@ TEST_F(UtestHcomUtil, test_GetHcomCount_succ) { auto ret = hcom_ome_util.GetHcomCount(op_desc, HCCL_DATA_TYPE_FP32, true, count); EXPECT_EQ(ret, 0); } + +TEST_F(UtestHcomUtil, test_GetHcomCount_succ_2) { + ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); + NodePtr node = NodeBuilder("node", HCOMSEND).AddInputDesc({1, 1, 224, 224}).Build(graph); + auto op_desc = node->GetOpDesc(); + HcomOmeUtil hcom_util; + int count = 0; + auto ret = hcom_util.GetHcomCount(op_desc, HCCL_DATA_TYPE_FP32, true, count); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(count, 224 * 224); +} } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc b/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc index 5468ec97..7f26aa8c 100644 --- a/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc +++ b/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc @@ -32,14 +32,14 @@ using namespace ge; namespace { const char *const kVectorCore = "VectorCore"; const char *const kAicoreEngine = "AIcoreEngine"; -string CreateEngineConfigJson() { +void CreateEngineConfigJson(string &dir_path, string &file_path) { GELOGI("Begin to create engine config json file."); string base_path = PluginManager::GetPath(); GELOGI("Base path is %s.", base_path.c_str()); - string dir_path = base_path.substr(0, base_path.rfind('/') + 1) + "plugin/nnengine/ge_config"; + dir_path = base_path.substr(0, base_path.rfind('/') + 1) + "plugin/nnengine/ge_config"; string cmd = "mkdir -p " + dir_path; system(cmd.c_str()); - string file_path = dir_path + "/engine_conf.json"; + file_path = dir_path + "/engine_conf.json"; GELOGI("Begin to write into the config file: %s.", file_path.c_str()); ofstream ofs(file_path, ios::out); EXPECT_EQ(!ofs, false); @@ -56,7 +56,6 @@ string CreateEngineConfigJson() { "}"; ofs.close(); GELOGI("Json config file %s has been written.", file_path.c_str()); - return file_path; } void DeleteFile(const string &file_name) { @@ -69,14 +68,16 @@ class UtestGraphOptimizeTest : public testing::Test { protected: void SetUp() { - config_file_ = CreateEngineConfigJson(); + CreateEngineConfigJson(config_dir_, config_file_); } void TearDown() { DeleteFile(config_file_); + DeleteFile(config_dir_); } private: + string config_dir_; string config_file_; }; diff --git a/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc b/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc index c8abadb5..1d19a8bd 100644 --- a/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc +++ b/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc @@ -20,9 +20,11 @@ #define protected public #include "graph/partition/dynamic_shape_partition.h" #include "compute_graph.h" +#include "graph/compute_graph_impl.h" #include "inc/framework/common/types.h" #include "utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" +#include "common/omg_util.h" namespace ge { namespace { @@ -37,33 +39,33 @@ GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format fo } class NodeBuilder { - public: - NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); } - - NodeBuilder 
&AddInputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, - DataType data_type = DT_FLOAT) { - op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); - return *this; - } - - NodeBuilder &AddOutputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, - DataType data_type = DT_FLOAT) { - op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); - return *this; - } - - NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { - op_desc_->AddOutputDesc(tensor_desc->Clone()); - return *this; - } - - NodePtr Build(const ComputeGraphPtr &graph) { - NodePtr node = graph->AddNode(op_desc_); - return node; - } - - private: - OpDescPtr op_desc_; + public: + NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared(name, type); } + + NodeBuilder &AddInputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, + DataType data_type = DT_FLOAT) { + op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + NodeBuilder &AddOutputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, + DataType data_type = DT_FLOAT) { + op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { + op_desc_->AddOutputDesc(tensor_desc->Clone()); + return *this; + } + + NodePtr Build(const ComputeGraphPtr &graph) { + NodePtr node = graph->AddNode(op_desc_); + return node; + } + + private: + OpDescPtr op_desc_; }; } // namespace @@ -92,28 +94,137 @@ TEST_F(UtestDynamicShapePartition, single_op_scene_success) { EXPECT_EQ(partitioner.Partition(), SUCCESS); } +/******************************************************************************* + * | + * Merge1 + * Active / \ Active + * / \. + * / \. + * Merge2 \. + * Active/ \Active \. + * / \ \. 
+ * Add Sub Relu + * | | | + * | | | + * Switch_f2 Switch_t2 | + * \ / | + * \ / | + * Less2 | + * | | + * | | + * Switch_f Switch_t + * | \ / | + * | Active | + * | | | + * | Less1 | + * | / \ | + * | / \ | + * Data Data + ******************************************************************************/ TEST_F(UtestDynamicShapePartition, merge_control_flow_group) { ComputeGraphPtr graph = std::make_shared("default"); AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, "session_graph_id"); - NodePtr data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); - NodePtr data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); - NodePtr merge = NodeBuilder("node2", MERGE).AddInputDesc({1}).AddInputDesc({1}) - .AddOutputDesc({1}).AddOutputDesc({}).Build(graph); - - GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2->GetOutDataAnchor(0), merge->GetInDataAnchor(1)); - - (void)AttrUtils::SetBool(data1->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); - (void)AttrUtils::SetBool(data2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - (void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); - (void)AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - (void)AttrUtils::SetInt(merge->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); - - EXPECT_EQ(graph->sub_graph_.size(), 0); + auto data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + + auto less1 = NodeBuilder("less1", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active1 = NodeBuilder("active1", STREAMACTIVE).Build(graph); + auto switch_t = NodeBuilder("switch_t", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto switch_f = NodeBuilder("switch_f", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto const_01 = NodeBuilder("const_01", CONSTANT).AddOutputDesc({1}).Build(graph); + auto const_11 = NodeBuilder("const_11", CONSTANT).AddOutputDesc({1}).Build(graph); + + + auto less2 = NodeBuilder("less2", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active2 = NodeBuilder("active2", STREAMACTIVE).Build(graph); + auto switch_t2 = NodeBuilder("switch_t2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto switch_f2 = NodeBuilder("switch_f2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto const_02 = NodeBuilder("const_02", CONSTANT).AddOutputDesc({1}).Build(graph); + auto const_12 = NodeBuilder("const_12", CONSTANT).AddOutputDesc({1}).Build(graph); + + auto add2 = NodeBuilder("add2", ADD).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto sub2 = NodeBuilder("sub2", SUB).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto merge2 = NodeBuilder("merge2", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active_f2 = NodeBuilder("active_f2", STREAMACTIVE).Build(graph); + auto active_t2 = NodeBuilder("active_t2", STREAMACTIVE).Build(graph); + + auto relu1 = NodeBuilder("relu1", RELU).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto merge1 = NodeBuilder("merge1", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active_f1 = NodeBuilder("active_f1", 
STREAMACTIVE).Build(graph); + auto active_t1 = NodeBuilder("active_t1", STREAMACTIVE).Build(graph); + + auto output1 = NodeBuilder("noutput1", NETOUTPUT).AddInputDesc({1}).Build(graph); + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(const_01->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1)); + GraphUtils::AddEdge(const_11->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor()); + GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); + GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor()); + + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less2->GetInDataAnchor(0)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), less2->GetInDataAnchor(1)); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(0)); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(0)); + GraphUtils::AddEdge(const_02->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(1)); + GraphUtils::AddEdge(const_12->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(1)); + GraphUtils::AddEdge(less2->GetOutControlAnchor(), active2->GetInControlAnchor()); + GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t2->GetInControlAnchor()); + GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f2->GetInControlAnchor()); + + + GraphUtils::AddEdge(switch_f2->GetOutControlAnchor(), add2->GetInControlAnchor()); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); + GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); + GraphUtils::AddEdge(add2->GetOutControlAnchor(), active_f2->GetInControlAnchor()); + GraphUtils::AddEdge(active_f2->GetOutControlAnchor(), merge2->GetInControlAnchor()); + + GraphUtils::AddEdge(switch_t2->GetOutControlAnchor(), sub2->GetInControlAnchor()); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), sub2->GetInDataAnchor(0)); + GraphUtils::AddEdge(sub2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); + GraphUtils::AddEdge(sub2->GetOutControlAnchor(), active_t2->GetInControlAnchor()); + GraphUtils::AddEdge(active_t2->GetOutControlAnchor(), merge2->GetInControlAnchor()); + + GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), less2->GetInControlAnchor()); + GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), relu1->GetInControlAnchor()); + + + GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge2->GetOutControlAnchor(), active_f1->GetInControlAnchor()); + GraphUtils::AddEdge(active_f1->GetOutControlAnchor(), merge1->GetInControlAnchor()); + + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), relu1->GetInDataAnchor(1)); + GraphUtils::AddEdge(relu1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); + GraphUtils::AddEdge(relu1->GetOutControlAnchor(), active_t1->GetInControlAnchor()); + GraphUtils::AddEdge(active_t1->GetOutControlAnchor(), merge1->GetInControlAnchor()); + + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + + AttrUtils::SetBool(merge2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); + EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); + + SetControlFlowGroup(merge2, 
merge2->GetOpDesc()->GetId()); + SetControlFlowGroup(switch_f2, merge2->GetOpDesc()->GetId()); + SetControlFlowGroup(switch_t2, merge2->GetOpDesc()->GetId()); + SetControlFlowGroup(active2, merge2->GetOpDesc()->GetId()); + SetControlFlowGroup(active_t2, merge2->GetOpDesc()->GetId()); + SetControlFlowGroup(active_f2, merge2->GetOpDesc()->GetId()); + + SetControlFlowGroup(merge1, merge1->GetOpDesc()->GetId()); + SetControlFlowGroup(switch_f, merge1->GetOpDesc()->GetId()); + SetControlFlowGroup(switch_t, merge1->GetOpDesc()->GetId()); + SetControlFlowGroup(active1, merge1->GetOpDesc()->GetId()); + SetControlFlowGroup(active_f1, merge1->GetOpDesc()->GetId()); + SetControlFlowGroup(active_t1, merge1->GetOpDesc()->GetId()); + + EXPECT_EQ(graph->impl_->sub_graph_.size(), 0); DynamicShapePartitioner partitioner(graph); EXPECT_EQ(partitioner.Partition(), SUCCESS); - EXPECT_EQ(graph->sub_graph_.size(), 1); + EXPECT_EQ(graph->impl_->sub_graph_.size(), 3); // input less1 unknown } } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/graph/passes/addn_pass_unittest.cc b/tests/ut/ge/graph/passes/addn_pass_unittest.cc index 6107a7d8..39029e8c 100644 --- a/tests/ut/ge/graph/passes/addn_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/addn_pass_unittest.cc @@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) { AddNPass *addn_pass = nullptr; NamesToPass names_to_pass; names_to_pass.emplace_back("Test", addn_pass); - EXPECT_EQ(pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR); } TEST(UtestGraphPassesAddnPass, null_graph) { diff --git a/tests/ut/ge/graph/passes/assert_pass_unittest.cc b/tests/ut/ge/graph/passes/assert_pass_unittest.cc index 4aa133d3..9247681c 100644 --- a/tests/ut/ge/graph/passes/assert_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/assert_pass_unittest.cc @@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { }; /// D E -/// | \ | \ +/// | \ | \. /// F C G /// : | : /// H A I @@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { EXPECT_EQ(graph->FindNode("D"), nullptr); } -/// E F -/// | \ | \ +/// E F +/// | \ | \. /// H C -> D G /// \ | : /// A I diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index 9bba5d77..3b0235f5 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -1,523 +1,903 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include - -#include "gtest/gtest.h" - -#define protected public -#include "graph/passes/base_pass.h" -#undef protected - -#include "external/graph/ge_error_codes.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" -#include "graph/node.h" -#include "graph/utils/graph_utils.h" -#include "graph_builder_utils.h" - -template class std::unordered_set; - -namespace ge { -class UtestTestPass : public BaseNodePass { - public: - UtestTestPass() = default; - UtestTestPass(bool dead_loop) : dead_loop_(dead_loop), run_times_(0) {} - - Status Run(NodePtr &node) override { - ++run_times_; - iter_nodes_.push_back(node); - auto iter = names_to_add_del_.find(node->GetName()); - if (iter != names_to_add_del_.end()) { - for (const auto &node_name : iter->second) { - auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name); - GraphUtils::IsolateNode(del_node, {0}); - AddNodeDeleted(del_node); - } - } - iter = names_to_add_repass_.find(node->GetName()); - if (iter != names_to_add_repass_.end()) { - auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); - for (const auto &node_name : iter->second) { - for (auto &node_re_pass : all_nodes) { - if (node_re_pass->GetName() == node_name) { - AddRePassNode(node_re_pass); - break; - } - } - } - if (!dead_loop_) { - names_to_add_repass_.erase(iter); - } - } - // simulate infershape pass - if(node->GetType() == WHILE){ - bool need_repass = false; - AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); - if(!OptionExists(kOptimizeAfterSubGraph)){ - return SUCCESS; - } - if(need_repass){ - AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); - AddImmediateRePassNode(node); - } - else{ - // clear attr on while - node->GetOpDesc()->DelAttr("_need_infer_again"); - } - } - return SUCCESS; - } - void clear() { iter_nodes_.clear(); } - std::vector GetIterNodes() { return iter_nodes_; } - - void AddRePassNodeName(const std::string &iter_node, const std::string &re_pass_node) { - names_to_add_repass_[iter_node].insert(re_pass_node); - } - void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { - names_to_add_del_[iter_node].insert(del_node); - } - unsigned int GetRunTimes() { return run_times_; } - - private: - std::vector iter_nodes_; - std::map> names_to_add_del_; - std::map> names_to_add_repass_; - bool dead_loop_; - unsigned int run_times_; -}; - -class TestDelPass : public BaseNodePass { - public: - Status Run(NodePtr &node) override { return SUCCESS; } -}; - -class UTESTGraphPassesBasePass : public testing::Test { - protected: - UTESTGraphPassesBasePass() { - auto p1 = new UtestTestPass; - names_to_pass_.push_back(std::make_pair("test1", p1)); - } - void SetUp() override { - for (auto &name_to_pass : names_to_pass_) { - dynamic_cast(name_to_pass.second)->clear(); - } - } - ~UTESTGraphPassesBasePass() override { - for (auto &name_to_pass : names_to_pass_) { - delete name_to_pass.second; - } - } - NamesToPass names_to_pass_; -}; -/// reshape1 -/// | -/// add1 -/// / \ -/// | | -/// data1 const1 -ComputeGraphPtr BuildGraph1() { - auto builder = ut::GraphBuilder("g1"); - auto data = builder.AddNode("data1", DATA, 0, 1); - auto a1 = builder.AddNode("add1", ADD, 2, 1); - auto c1 = builder.AddNode("const1", CONSTANT, 0, 1); - auto r1 = builder.AddNode("reshape1", RESHAPE, 1, 1); - - builder.AddDataEdge(data, 0, a1, 0); - builder.AddDataEdge(c1, 0, a1, 1); - builder.AddDataEdge(a1, 0, r1, 0); - - return builder.GetGraph(); 
-} - -/// sum1 -/// / \ -/// / \ -/// / \ -/// reshape1 addn1 -/// | c | -/// add1 <--- shape1 -/// / \ | -/// | | | -/// data1 const1 const2 -ComputeGraphPtr BuildGraph2() { - auto builder = ut::GraphBuilder("g1"); - auto data1 = builder.AddNode("data1", DATA, 0, 1); - auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); - auto const2 = builder.AddNode("const2", CONSTANT, 0, 1); - auto add1 = builder.AddNode("add1", ADD, 2, 1); - auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); - auto reshape1 = builder.AddNode("reshape1", RESHAPE, 1, 1); - auto addn1 = builder.AddNode("addn1", ADDN, 1, 1); - auto sum1 = builder.AddNode("sum1", SUM, 2, 1); - - builder.AddDataEdge(data1, 0, add1, 0); - builder.AddDataEdge(const1, 0, add1, 1); - builder.AddDataEdge(const2, 0, shape1, 0); - builder.AddControlEdge(shape1, add1); - builder.AddDataEdge(add1, 0, reshape1, 0); - builder.AddDataEdge(shape1, 0, addn1, 0); - builder.AddDataEdge(reshape1, 0, sum1, 0); - builder.AddDataEdge(addn1, 0, sum1, 1); - - return builder.GetGraph(); -} - -/// rnextiteration -/// | | -/// merge -/// | -/// data1 -ComputeGraphPtr BuildGraph3() { - auto builder = ut::GraphBuilder("g1"); - auto data1 = builder.AddNode("data1", DATA, 0, 1); - auto merge1 = builder.AddNode("merge1", MERGE, 2, 1); - auto next1 = builder.AddNode("next1", NEXTITERATION, 1, 1); - - builder.AddDataEdge(data1, 0, merge1, 0); - builder.AddDataEdge(merge1, 0, next1, 0); - builder.AddDataEdge(next1, 0, merge1, 1); - builder.AddControlEdge(merge1, next1); - builder.AddControlEdge(next1, merge1); - - return builder.GetGraph(); -} - -void CheckIterOrder(UtestTestPass *pass, std::vector> &nodes_layers) { - std::unordered_set layer_nodes; - size_t layer_index = 0; - for (const auto &node : pass->GetIterNodes()) { - layer_nodes.insert(node->GetName()); - EXPECT_LT(layer_index, nodes_layers.size()); - if (layer_nodes == nodes_layers[layer_index]) { - layer_index++; - layer_nodes.clear(); - } - } - EXPECT_EQ(layer_index, nodes_layers.size()); -} - -/// Op1 -/// | -/// Merge -/// / \ -/// Op2 Op3 -TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { - auto builder = ut::GraphBuilder("g1"); - auto merge_node = builder.AddNode("Merge", MERGE, 1, 1); - auto node1 = builder.AddNode("Op1", RELU, 1, 1); - auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); - auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); - - GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); - - EXPECT_EQ(node1->GetOutDataNodes().size(), 1); - - TestDelPass del_pass; - auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); - EXPECT_EQ(ret, FAILED); - - OpDescPtr op_desc = std::make_shared("merge", MERGE); - NodePtr node = shared_ptr(new (std::nothrow) Node(op_desc, nullptr)); - ret = del_pass.IsolateAndDeleteNode(node, {0, -1}); - EXPECT_EQ(ret, FAILED); -} - -/// Op1 -/// | -/// Merge -/// / \ -/// Op2 Op3 -TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { - auto builder = ut::GraphBuilder("g1"); - auto merge_node = builder.AddNode("Merge", MERGE, 1, 2); - auto node1 = builder.AddNode("Op1", RELU, 1, 1); - auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); - auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); - - GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), 
node2->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); - - EXPECT_EQ(node1->GetOutDataNodes().size(), 1); - - TestDelPass del_pass; - auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UTESTGraphPassesBasePass, data_graph) { - auto graph = BuildGraph1(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); - auto *pass = dynamic_cast(names_to_pass_[0].second); - - EXPECT_EQ(pass->GetIterNodes().size(), 4); - std::vector> layers; - layers.push_back({"data1", "const1"}); - layers.push_back({"add1"}); - layers.push_back({"reshape1"}); - CheckIterOrder(pass, layers); -} - -TEST_F(UTESTGraphPassesBasePass, graph_with_control_link) { - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); - auto *pass = dynamic_cast(names_to_pass_[0].second); - - EXPECT_EQ(pass->GetIterNodes().size(), 8); - EXPECT_EQ(pass->GetIterNodes().at(3)->GetName(), "shape1"); - - std::vector> layers; - layers.push_back({"data1", "const1", "const2"}); - layers.push_back({"shape1"}); - layers.push_back({"add1", "addn1", "reshape1"}); - layers.push_back({"sum1"}); - CheckIterOrder(pass, layers); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_after) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "sum1"); - test_pass.AddRePassNodeName("shape1", "sum1"); - test_pass.AddRePassNodeName("shape1", "add1"); - test_pass.AddRePassNodeName("data1", "add1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 8); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_before) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "data1"); - - auto graph = BuildGraph1(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 5); - EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); - EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); - EXPECT_EQ(test_pass.GetIterNodes().at(4)->GetName(), "data1"); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_before_multi_times) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "data1"); - test_pass.AddRePassNodeName("add1", "const1"); - test_pass.AddRePassNodeName("reshape1", "data1"); - - auto graph = BuildGraph1(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 6); - EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); - EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); -} - -TEST_F(UTESTGraphPassesBasePass, del_after) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("add1", "sum1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 7); -} - -TEST_F(UTESTGraphPassesBasePass, del_after_multiple) { - NamesToPass names_to_pass; - auto test_pass = 
UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("add1", "sum1"); - test_pass.AddDelNodeName("add1", "reshape1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 6); -} - -TEST_F(UTESTGraphPassesBasePass, del_after_break_link) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("shape1", "add1"); - test_pass.AddDelNodeName("shape1", "addn1"); - test_pass.AddRePassNodeName("shape1", "shape1"); - test_pass.AddRePassNodeName("shape1", "reshape1"); - test_pass.AddRePassNodeName("shape1", "sum1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 7); -} - -TEST_F(UTESTGraphPassesBasePass, del_self_and_after) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("shape1", "add1"); - test_pass.AddDelNodeName("shape1", "addn1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 4); -} - -TEST_F(UTESTGraphPassesBasePass, del_before) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("reshape1", "add1"); - test_pass.AddDelNodeName("sum1", "addn1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 8); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_and_del) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "sum1"); - test_pass.AddDelNodeName("reshape1", "sum1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 7); -} -/* -TEST_F(UTESTGraphPassesBasePass, dead_loop) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(true); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "sum1"); - test_pass.AddRePassNodeName("sum1", "add1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetRunTimes(), 1007); -} -*/ - -TEST_F(UTESTGraphPassesBasePass, while_loop) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(true); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - auto graph = BuildGraph3(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); -} - -/// data1 const -/// \ / -/// while -/// / \ -/// | | -/// cast1 cast2 -ComputeGraphPtr BuildWhileGraph1() { - // build sub graph - auto builder_sub = ut::GraphBuilder("sub"); - auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); - auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); - auto add = builder_sub.AddNode("add", ADD, 2, 1); - - builder_sub.AddDataEdge(data_1, 0, add, 0); - builder_sub.AddDataEdge(data_2, 0, add, 1); - auto sub_graph = builder_sub.GetGraph(); - sub_graph->SetName("while_sub"); - // build root 
graph - auto builder = ut::GraphBuilder("g1"); - auto data = builder.AddNode("data1", DATA, 0, 1); - auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); - auto c1 = builder.AddNode("cast1", CAST, 1, 1); - auto c2 = builder.AddNode("cast2", CAST, 1, 1); - // add while op - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1,1,1,1})); - tensor_desc->SetFormat(FORMAT_ND); - tensor_desc->SetDataType(DT_INT32); - - auto op_desc = std::make_shared("while", WHILE); - for (int i = 0; i < 2; ++i) { - op_desc->AddInputDesc(tensor_desc->Clone()); - } - for (int i = 0; i < 2; ++i) { - op_desc->AddOutputDesc(tensor_desc->Clone()); - } - AttrUtils::SetBool(op_desc,"_need_infer_again", true); - op_desc->AddSubgraphName(sub_graph->GetName()); - op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); - auto root_graph = builder.GetGraph(); - auto while_op = root_graph->AddNode(op_desc); - - builder.AddDataEdge(data, 0, while_op, 0); - builder.AddDataEdge(const_op, 0, while_op, 1); - builder.AddDataEdge(while_op, 0, c1, 0); - builder.AddDataEdge(while_op, 1, c2, 0); - sub_graph->SetParentGraph(root_graph); - sub_graph->SetParentNode(while_op); - root_graph->AddSubgraph(sub_graph); - return root_graph; -} - -TEST_F(UTESTGraphPassesBasePass, while_infershape) { -NamesToPass names_to_pass; -auto test_pass = UtestTestPass(); -names_to_pass.push_back(std::make_pair("test", &test_pass)); - -auto graph = BuildWhileGraph1(); -auto ge_pass = GEPass(graph); -auto while_node = graph->FindNode("while"); -EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); -EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); -} - -} // namespace ge +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <map>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#define protected public
+#include "graph/passes/base_pass.h"
+#undef protected
+
+#include "framework/common/types.h"
+#include "graph/node.h"
+#include "graph/utils/graph_utils.h"
+#include "graph_builder_utils.h"
+
+template class std::unordered_set<std::string>;
+
+namespace ge {
+class UtestTestPass : public BaseNodePass {
+ public:
+  UtestTestPass() = default;
+  explicit UtestTestPass(bool dead_loop) : dead_loop_(dead_loop), run_times_(0) {}
+
+  Status Run(NodePtr &node) override {
+    ++run_times_;
+    iter_nodes_.push_back(node);
+    auto iter = names_to_add_del_.find(node->GetName());
+    if (iter != names_to_add_del_.end()) {
+      for (const auto &node_name : iter->second) {
+        auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name);
+        GraphUtils::IsolateNode(del_node, {0});
+        AddNodeDeleted(del_node);
+      }
+    }
+    iter = names_to_add_repass_.find(node->GetName());
+    if (iter != names_to_add_repass_.end()) {
+      auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
+      for (const auto &node_name : iter->second) {
+        for (auto &node_re_pass : all_nodes) {
+          if (node_re_pass->GetName() == node_name) {
+            AddRePassNode(node_re_pass);
+            break;
+          }
+        }
+      }
+      if (!dead_loop_) {
+        names_to_add_repass_.erase(iter);
+      }
+    }
+
+    iter = names_to_add_repass_immediate_.find(node->GetName());
+    if (iter != names_to_add_repass_immediate_.end()) {
+      auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
+      for (const auto &node_name : iter->second) {
+        for (auto &node_re_pass : all_nodes) {
+          if (node_re_pass->GetName() == node_name) {
+            AddImmediateRePassNode(node_re_pass);
+            break;
+          }
+        }
+      }
+      if (!dead_loop_) {
+        names_to_add_repass_immediate_.erase(iter);
+      }
+    }
+
+    iter = names_to_add_suspend_.find(node->GetName());
+    if (iter != names_to_add_suspend_.end()) {
+      auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
+      for (const auto &node_name : iter->second) {
+        for (auto &node_re_pass : all_nodes) {
+          if (node_re_pass->GetName() == node_name) {
+            AddNodeSuspend(node_re_pass);
+            break;
+          }
+        }
+      }
+      if (!dead_loop_) {
+        names_to_add_suspend_.erase(iter);
+      }
+    }
+
+    iter = names_to_add_resume_.find(node->GetName());
+    if (iter != names_to_add_resume_.end()) {
+      auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
+      for (const auto &node_name : iter->second) {
+        for (auto &node_re_pass : all_nodes) {
+          if (node_re_pass->GetName() == node_name) {
+            AddNodeResume(node_re_pass);
+            break;
+          }
+        }
+      }
+      if (!dead_loop_) {
+        names_to_add_resume_.erase(iter);
+      }
+    }
+    // simulate the infershape pass
+    if (node->GetType() == WHILE) {
+      bool need_repass = false;
+      AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass);
+      if (!OptionExists(kOptimizeAfterSubGraph)) {
+        return SUCCESS;
+      }
+      if (need_repass) {
+        AttrUtils::SetBool(node->GetOpDesc(), "_need_infer_again", false);
+        AddImmediateRePassNode(node);
+      } else {
+        // clear the attr on the while node
+        node->GetOpDesc()->DelAttr("_need_infer_again");
+      }
+    }
+    return SUCCESS;
+  }
+
+  Status OnSuspendNodesLeaked() override {
+    // resume all nodes remaining in suspend_nodes when a leak is detected
+    auto compute_graph = (!iter_nodes_.empty()) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr;
+    if (compute_graph == nullptr) {
+      return SUCCESS;
+    }
+
+    for (const auto &node_name : names_to_add_resume_onleaked_) {
+      auto node_to_resume = compute_graph->FindNode(node_name);
+      AddNodeResume(node_to_resume);
+    }
+    return SUCCESS;
+  }
+  void clear() { iter_nodes_.clear(); }
+  std::vector<NodePtr> GetIterNodes() { return iter_nodes_; }
+
+  void AddRePassNodeName(const std::string &iter_node, const std::string &re_pass_node) {
+    names_to_add_repass_[iter_node].insert(re_pass_node);
+  }
+  void AddDelNodeName(const std::string &iter_node, const std::string &del_node) {
+    names_to_add_del_[iter_node].insert(del_node);
+  }
+  void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) {
+    names_to_add_repass_immediate_[iter_node].insert(re_pass_node);
+  }
+
+  void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) {
+    names_to_add_suspend_[iter_node].insert(suspend_node);
+  }
+  void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) {
+    names_to_add_resume_[iter_node].insert(resume_node);
+  }
+  void AddResumeNodeNameOnLeaked(const std::string &resume_node) {
+    names_to_add_resume_onleaked_.insert(resume_node);
+  }
+
+  unsigned int GetRunTimes() { return run_times_; }
+
+ private:
+  std::vector<NodePtr> iter_nodes_;
+  std::map<std::string, std::unordered_set<std::string>> names_to_add_del_;
+  std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_;
+  std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_immediate_;
+  std::map<std::string, std::unordered_set<std::string>> names_to_add_suspend_;
+  std::map<std::string, std::unordered_set<std::string>> names_to_add_resume_;
+  std::unordered_set<std::string> names_to_add_resume_onleaked_;
+
+  bool dead_loop_ = false;
+  unsigned int run_times_ = 0U;
+};
+
+class TestDelPass : public BaseNodePass {
+ public:
+  Status Run(NodePtr &node) override { return SUCCESS; }
+};
+
+class UTESTGraphPassesBasePass : public testing::Test {
+ protected:
+  UTESTGraphPassesBasePass() {
+    auto p1 = new UtestTestPass;
+    names_to_pass_.push_back(std::make_pair("test1", p1));
+  }
+  void SetUp() override {
+    for (auto &name_to_pass : names_to_pass_) {
+      dynamic_cast<UtestTestPass *>(name_to_pass.second)->clear();
+    }
+  }
+  ~UTESTGraphPassesBasePass() override {
+    for (auto &name_to_pass : names_to_pass_) {
+      delete name_to_pass.second;
+    }
+  }
+  NamesToPass names_to_pass_;
+};
+/// reshape1
+///    |
+///   add1
+///  /    \.
+///  |     |
+/// data1 const1
+ComputeGraphPtr BuildGraph1() {
+  auto builder = ut::GraphBuilder("g1");
+  auto data = builder.AddNode("data1", DATA, 0, 1);
+  auto a1 = builder.AddNode("add1", ADD, 2, 1);
+  auto c1 = builder.AddNode("const1", CONSTANT, 0, 1);
+  auto r1 = builder.AddNode("reshape1", RESHAPE, 1, 1);
+
+  builder.AddDataEdge(data, 0, a1, 0);
+  builder.AddDataEdge(c1, 0, a1, 1);
+  builder.AddDataEdge(a1, 0, r1, 0);
+
+  return builder.GetGraph();
+}
+
+///      sum1
+///     /    \.
+///    /      \.
+///   /        \.
+/// reshape1  addn1
+///    |   c    |
+///   add1 <--- shape1
+///  /    \     |
+///  |     |    |
+/// data1 const1 const2
+ComputeGraphPtr BuildGraph2() {
+  auto builder = ut::GraphBuilder("g1");
+  auto data1 = builder.AddNode("data1", DATA, 0, 1);
+  auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
+  auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
+  auto add1 = builder.AddNode("add1", ADD, 2, 1);
+  auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1);
+  auto reshape1 = builder.AddNode("reshape1", RESHAPE, 1, 1);
+  auto addn1 = builder.AddNode("addn1", ADDN, 1, 1);
+  auto sum1 = builder.AddNode("sum1", SUM, 2, 1);
+
+  builder.AddDataEdge(data1, 0, add1, 0);
+  builder.AddDataEdge(const1, 0, add1, 1);
+  builder.AddDataEdge(const2, 0, shape1, 0);
+  builder.AddControlEdge(shape1, add1);
+  builder.AddDataEdge(add1, 0, reshape1, 0);
+  builder.AddDataEdge(shape1, 0, addn1, 0);
+  builder.AddDataEdge(reshape1, 0, sum1, 0);
+  builder.AddDataEdge(addn1, 0, sum1, 1);
+
+  return builder.GetGraph();
+}
+
+/// rnextiteration
+///   |  |
+///  merge
+///   |
+/// data1
+ComputeGraphPtr BuildGraph3() {
+  auto builder = ut::GraphBuilder("g1");
+  auto data1 = builder.AddNode("data1", DATA, 0, 1);
+  auto merge1 = builder.AddNode("merge1", MERGE, 2, 1);
+  auto next1 = builder.AddNode("next1", NEXTITERATION, 1, 1);
+
+  builder.AddDataEdge(data1, 0, merge1, 0);
+  builder.AddDataEdge(merge1, 0, next1, 0);
+  builder.AddDataEdge(next1, 0, merge1, 1);
+  builder.AddControlEdge(merge1, next1);
+  builder.AddControlEdge(next1, merge1);
+
+  return builder.GetGraph();
+}
+
+/// cast1--shape1
+///      /
+/// data1
+///      \
+/// transdata1--shape2
+ComputeGraphPtr BuildGraph4() {
+  auto builder = ut::GraphBuilder("g1");
+  auto data1 = builder.AddNode("data1", DATA, 0, 1);
+  auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
+  auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1);
+  auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1);
+  auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1);
+
+  builder.AddDataEdge(data1, 0, cast1, 0);
+  builder.AddDataEdge(data1, 0, transdata1, 0);
+  builder.AddDataEdge(cast1, 0, shape1, 0);
+  builder.AddDataEdge(transdata1, 0, shape2, 0);
+  return builder.GetGraph();
+}
+
+void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::string>> &nodes_layers) {
+  std::unordered_set<std::string> layer_nodes;
+  size_t layer_index = 0;
+  for (const auto &node : pass->GetIterNodes()) {
+    layer_nodes.insert(node->GetName());
+    EXPECT_LT(layer_index, nodes_layers.size());
+    if (layer_nodes == nodes_layers[layer_index]) {
+      layer_index++;
+      layer_nodes.clear();
+    }
+  }
+  EXPECT_EQ(layer_index, nodes_layers.size());
+}
+
+/// Op1
+///  |
+/// Merge
+///  /  \.
+/// Op2 Op3 +TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { + auto builder = ut::GraphBuilder("g1"); + auto merge_node = builder.AddNode("Merge", MERGE, 1, 1); + auto node1 = builder.AddNode("Op1", RELU, 1, 1); + auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); + auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); + + GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); + + EXPECT_EQ(node1->GetOutDataNodes().size(), 1); + + TestDelPass del_pass; + auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); + EXPECT_EQ(ret, FAILED); + + OpDescPtr op_desc = std::make_shared("merge", MERGE); + NodePtr node = shared_ptr(new (std::nothrow) Node(op_desc, nullptr)); + ret = del_pass.IsolateAndDeleteNode(node, {0, -1}); + EXPECT_EQ(ret, FAILED); +} + +/// Op1 +/// | +/// Merge +/// / \. +/// Op2 Op3 +TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { + auto builder = ut::GraphBuilder("g1"); + auto merge_node = builder.AddNode("Merge", MERGE, 1, 2); + auto node1 = builder.AddNode("Op1", RELU, 1, 1); + auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); + auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); + + GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); + + EXPECT_EQ(node1->GetOutDataNodes().size(), 1); + + TestDelPass del_pass; + auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UTESTGraphPassesBasePass, data_graph) { + auto graph = BuildGraph1(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + auto *pass = dynamic_cast(names_to_pass_[0].second); + + EXPECT_EQ(pass->GetIterNodes().size(), 4); + std::vector> layers; + layers.push_back({"data1", "const1"}); + layers.push_back({"add1"}); + layers.push_back({"reshape1"}); + CheckIterOrder(pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, graph_with_control_link) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + auto *pass = dynamic_cast(names_to_pass_[0].second); + + EXPECT_EQ(pass->GetIterNodes().size(), 8); + EXPECT_EQ(pass->GetIterNodes().at(3)->GetName(), "shape1"); + + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1", "reshape1"}); + layers.push_back({"sum1"}); + CheckIterOrder(pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_after) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "sum1"); + test_pass.AddRePassNodeName("shape1", "sum1"); + test_pass.AddRePassNodeName("shape1", "add1"); + test_pass.AddRePassNodeName("data1", "add1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 8); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_before) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "data1"); + + auto graph = 
BuildGraph1(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 5); + EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); + EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); + EXPECT_EQ(test_pass.GetIterNodes().at(4)->GetName(), "data1"); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_before_multi_times) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "data1"); + test_pass.AddRePassNodeName("add1", "const1"); + test_pass.AddRePassNodeName("reshape1", "data1"); + + auto graph = BuildGraph1(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 6); + EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); + EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); +} + +TEST_F(UTESTGraphPassesBasePass, del_after) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("add1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 7); +} + +TEST_F(UTESTGraphPassesBasePass, del_after_multiple) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("add1", "sum1"); + test_pass.AddDelNodeName("add1", "reshape1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 6); +} + +TEST_F(UTESTGraphPassesBasePass, del_after_break_link) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("shape1", "add1"); + test_pass.AddDelNodeName("shape1", "addn1"); + test_pass.AddRePassNodeName("shape1", "shape1"); + test_pass.AddRePassNodeName("shape1", "reshape1"); + test_pass.AddRePassNodeName("shape1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 7); +} + +TEST_F(UTESTGraphPassesBasePass, del_self_and_after) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("shape1", "add1"); + test_pass.AddDelNodeName("shape1", "addn1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 6); +} + +TEST_F(UTESTGraphPassesBasePass, del_before) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("reshape1", "add1"); + test_pass.AddDelNodeName("sum1", "addn1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 8); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_and_del) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "sum1"); + 
test_pass.AddDelNodeName("reshape1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 7); +} +/* +TEST_F(UTESTGraphPassesBasePass, dead_loop) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(true); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "sum1"); + test_pass.AddRePassNodeName("sum1", "add1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetRunTimes(), 1007); +} +*/ + +TEST_F(UTESTGraphPassesBasePass, while_loop) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(true); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + auto graph = BuildGraph3(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + +/// data1 const +/// \ / +/// while +/// / \. +/// | | +/// cast1 cast2 +ComputeGraphPtr BuildWhileGraph1() { + // build sub graph + auto builder_sub = ut::GraphBuilder("sub"); + auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); + auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); + auto add = builder_sub.AddNode("add", ADD, 2, 1); + + builder_sub.AddDataEdge(data_1, 0, add, 0); + builder_sub.AddDataEdge(data_2, 0, add, 1); + auto sub_graph = builder_sub.GetGraph(); + sub_graph->SetName("while_sub"); + // build root graph + auto builder = ut::GraphBuilder("g1"); + auto data = builder.AddNode("data1", DATA, 0, 1); + auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); + auto c1 = builder.AddNode("cast1", CAST, 1, 1); + auto c2 = builder.AddNode("cast2", CAST, 1, 1); + // add while op + auto tensor_desc = std::make_shared(); + tensor_desc->SetShape(GeShape({1,1,1,1})); + tensor_desc->SetFormat(FORMAT_ND); + tensor_desc->SetDataType(DT_INT32); + + auto op_desc = std::make_shared("while", WHILE); + for (int i = 0; i < 2; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < 2; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + AttrUtils::SetBool(op_desc,"_need_infer_again", true); + op_desc->AddSubgraphName(sub_graph->GetName()); + op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); + auto root_graph = builder.GetGraph(); + auto while_op = root_graph->AddNode(op_desc); + + builder.AddDataEdge(data, 0, while_op, 0); + builder.AddDataEdge(const_op, 0, while_op, 1); + builder.AddDataEdge(while_op, 0, c1, 0); + builder.AddDataEdge(while_op, 1, c2, 0); + sub_graph->SetParentGraph(root_graph); + sub_graph->SetParentNode(while_op); + root_graph->AddSubgraph(sub_graph); + return root_graph; +} + +TEST_F(UTESTGraphPassesBasePass, while_infershape) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + auto graph = BuildWhileGraph1(); + auto ge_pass = GEPass(graph); + auto while_node = graph->FindNode("while"); + EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass pre_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo + 
std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "add1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // re-pass the current node immediately
+  test_pass->AddRePassImmediateNodeName("reshape1", "reshape1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // re-pass the next node immediately
+  test_pass->AddRePassImmediateNodeName("reshape1", "sum1");
+  // re-pass a node behind the next node immediately
+  test_pass->AddRePassImmediateNodeName("add1", "sum1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
+/**
+ * A->B->C
+ * If node B suspends its predecessor A and C later resumes A, the suspension has no
+ * lasting effect: iteration follows the normal order, and once C resumes A, A is passed again.
+ */
+TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // add1->reshape1->sum1
+  test_pass->AddSuspendNodeName("reshape1", "add1");
+  test_pass->AddResumeNodeName("sum1", "add1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  layers.push_back({"add1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+/**
+ * A->B->C
+ * If node B suspends its predecessor A and B itself resumes A, the suspension is again
+ * without effect: iteration follows the normal order, and once B resumes A, A is passed again.
+ */
+TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // add1->reshape1->sum1
+  test_pass->AddSuspendNodeName("reshape1", "add1");
+  test_pass->AddResumeNodeName("reshape1", "add1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1", "add1"});
+  CheckIterOrder(test_pass, layers);
+}
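+
+// A minimal usage sketch for the two re-pass flavours exercised above, assuming a
+// hypothetical infer-style pass (ShapeChanged/PeerShapeUnknown are illustrative
+// helpers, not GE APIs):
+//
+//   Status MyInferPass::Run(NodePtr &node) {
+//     if (ShapeChanged(node)) {
+//       AddImmediateRePassNode(node);  // revisited before iteration moves on
+//     } else if (PeerShapeUnknown(node)) {
+//       AddRePassNode(node);           // queued and revisited in a later sweep
+//     }
+//     return SUCCESS;
+//   }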
+/**
+ * A->B->C
+ * If node B resumes C, which was never suspended, the resume is a no-op: C is not
+ * passed an extra time and iteration follows the normal order.
+ */
+TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // add1->reshape1->sum1
+  test_pass->AddResumeNodeName("reshape1", "sum1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+/**
+ * A->B->C
+ * If node B suspends its predecessor A and nothing ever resumes it, iteration still
+ * follows the normal order, but A is a leaked node, so the run returns a failure.
+ */
+TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) {
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+  // suspend the predecessor node
+  test_pass.AddSuspendNodeName("reshape1", "add1");
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR);
+}
+
+/**
+ * A->B->C
+ * If node B suspends its predecessor A, iteration still follows the normal order;
+ * resuming A in OnSuspendNodesLeaked means A is passed once more at the end.
+ */
+TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) {
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend the predecessor node
+  test_pass->AddSuspendNodeName("reshape1", "add1");
+  test_pass->AddResumeNodeNameOnLeaked("add1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  layers.push_back({"add1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+/// cast1--shape1
+///      /
+/// data1
+///      \
+/// transdata1--shape2
+/**
+ * suspend the current node
+ * cast1 suspends itself, shape2 resumes cast1
+ * iteration order: data1; cast1, transdata1; shape2; cast1, shape1
+ */
+TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend the current node
+  test_pass->AddSuspendNodeName("cast1", "cast1");
+  test_pass->AddResumeNodeName("shape2", "cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"cast1", "transdata1"});
+  layers.push_back({"shape2"});
+  layers.push_back({"cast1", "shape1"});
+  CheckIterOrder(test_pass, layers);
+}
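+
+// A sketch of the suspend/resume protocol these tests pin down (peer_ready_ and
+// parked_node_ are illustrative members, not GE APIs): a node parked with
+// AddNodeSuspend must eventually reach AddNodeResume, either from a later Run()
+// or from OnSuspendNodesLeaked(); otherwise GEPass::Run fails with INTERNAL_ERROR,
+// as the leak tests in this file check.
+//
+//   Status MyPass::Run(NodePtr &node) {
+//     if (!peer_ready_) {
+//       AddNodeSuspend(node);   // park until something resumes it
+//     }
+//     return SUCCESS;
+//   }
+//   Status MyPass::OnSuspendNodesLeaked() {
+//     AddNodeResume(parked_node_);  // last chance before the run is failed
+//     return SUCCESS;
+//   }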
+/**
+ * suspend the current node
+ * cast1 suspends itself and then resumes itself
+ * iteration order: data1; then cast1 (passed twice), transdata1, shape1, shape2
+ */
+TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itself_then_resume_itself) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend the current node, then resume it from the same node
+  test_pass->AddSuspendNodeName("cast1", "cast1");
+  test_pass->AddResumeNodeName("cast1", "cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"cast1", "transdata1", "cast1", "shape1", "shape2"});
+  CheckIterOrder(test_pass, layers);
+}
+/**
+ * suspend the current node
+ * cast1 suspends itself and is resumed in OnSuspendNodesLeaked
+ * iteration order: data1; cast1, transdata1, shape2; cast1, shape1
+ */
+TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itself_then_resume_onleaked) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend the current node, resume it only on leak
+  test_pass->AddSuspendNodeName("cast1", "cast1");
+  test_pass->AddResumeNodeNameOnLeaked("cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"cast1", "transdata1", "shape2"});
+  layers.push_back({"cast1", "shape1"});
+  CheckIterOrder(test_pass, layers);
+}
+/**
+ * suspend the next node
+ * data1 suspends cast1, which is resumed in OnSuspendNodesLeaked
+ * iteration order: data1; transdata1, shape2; cast1, shape1
+ */
+TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend the next node
+  test_pass->AddSuspendNodeName("data1", "cast1");
+  test_pass->AddResumeNodeNameOnLeaked("cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 5);
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1"});
+  layers.push_back({"transdata1", "shape2"});
+  layers.push_back({"cast1", "shape1"});
+  CheckIterOrder(test_pass, layers);
+}
+
+/**
+ * suspend the next node
+ * data1 suspends cast1 and nothing resumes it
+ * iteration order: data1; transdata1, shape2
+ * the run fails because the suspended node leaked
+ */
+TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) {
+  auto graph = BuildGraph4();
+  auto ge_pass = GEPass(graph);
+  auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
+  // suspend the next node
+  test_pass->AddSuspendNodeName("data1", "cast1");
+  EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR);
+  EXPECT_EQ(test_pass->GetIterNodes().size(), 3);
+}
+
+/*
+TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) {
+  NamesToPass names_to_pass;
+  auto test_pass = UtestTestPass();
+  names_to_pass.push_back(std::make_pair("test", &test_pass));
+
+  // re-pass the next node immediately
+  test_pass.AddRePassNodeName("reshape1", "sum1");
+  // re-pass a node behind the next node immediately
+  test_pass.AddRePassNodeName("add1", "sum1");
+
+  auto graph = BuildGraph2();
+  auto ge_pass = GEPass(graph);
+  EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
+  EXPECT_EQ(test_pass.GetIterNodes().size(), 8);  // todo
+  std::vector<std::unordered_set<std::string>> layers;
+  layers.push_back({"data1", "const1", "const2"});
+  layers.push_back({"shape1"});
+  layers.push_back({"add1", "addn1"});
+  layers.push_back({"reshape1", "sum1"});
+  CheckIterOrder(&test_pass, layers);
+}*/
+} // namespace ge
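
Note: the fixture above owns its UtestTestPass instances through raw pointers in names_to_pass_, but a run can also be reproduced standalone; a hedged sketch, using only the GEPass/NamesToPass types already shown in this file:

    UtestTestPass pass;                  // stack-allocated, no fixture needed
    NamesToPass passes;
    passes.emplace_back("test", &pass);
    auto ge_pass = GEPass(BuildGraph4());
    ge_pass.Run(passes);                 // SUCCESS, or INTERNAL_ERROR on leaked suspends

diff --git 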
a/tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc b/tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc index 0927aec4..a6109d51 100644 --- a/tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc +++ b/tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc @@ -34,11 +34,11 @@ namespace { /// net_output /// | /// merge -/// / \ +/// / \. /// square add -/// F| T/ T\ +/// F| T/ T\. /// switch1 switch2 -/// / \ / \ +/// / \ / \. /// var1 var2 var3 /// ComputeGraphPtr BuildGraph1() { diff --git a/tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc b/tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc index 96788b53..08944cf1 100644 --- a/tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc @@ -173,8 +173,8 @@ namespace { /// shapeNo1 /// | /// addnYes1 -/// / \ -/// / \ +/// / \. +/// / \. /// const1 const2 ComputeGraphPtr BuildGraph1() { auto builder = ut::GraphBuilder("test"); @@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() { /// shapeNo1 /// | c /// addnYes1 <----- dataNo1 -/// / \ -/// / \ +/// / \. +/// / \. /// const1 const2 ComputeGraphPtr BuildGraph3() { auto builder = ut::GraphBuilder("test"); @@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() { /// shapeNo1 /// | c /// addnYes1 <--------- -/// / \ \ -/// / \ c \ +/// / \ \. +/// / \ c \. /// const1 const2 <----- dataNo1 ComputeGraphPtr BuildGraph4() { auto builder = ut::GraphBuilder("test"); @@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() { /// shapeNo1 /// | c /// addnYes1 <----- dataNo1 -/// / \ +/// / \. /// / \ c /// const1 const2 <----- dataNo2 ComputeGraphPtr BuildGraph5() { @@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() { /// addYes1 <---- const3 /// | /// addnYes1 <- -/// / \ \ -/// / \ \ +/// / \ \. +/// / \ \. /// const1 const2 const4 ComputeGraphPtr BuildGraph6() { auto builder = ut::GraphBuilder("test"); @@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() { } /// netoutput1 -/// / \ +/// / \. /// shapeNo1 ShpaeNo2 /// \ / /// huberLoss1 -/// / | \ -/// / | \ +/// / | \. +/// / | \. /// const1 const2 const3 ComputeGraphPtr BuildGraph7() { auto builder = ut::GraphBuilder("test"); @@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() { /// shapeNo1 /// | /// addnNo1 -/// / \ -/// / \ +/// / \. +/// / \. /// const1 const2 ComputeGraphPtr BuildGraph8() { auto builder = ut::GraphBuilder("test"); @@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() { /// shapeNo1 /// | /// addnYes1 -/// / \ -/// / \ +/// / \. +/// / \. /// const1 data1 ComputeGraphPtr BuildGraph9() { auto builder = ut::GraphBuilder("test"); @@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() { } /// netoutput1 -/// / \ +/// / \. /// addDim sqrt1 /// \ / /// switch1 -/// / \ -/// / \ +/// / \. +/// / \. /// const1 const2 ComputeGraphPtr BuildGraph10() { auto builder = ut::GraphBuilder("test"); diff --git a/tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc b/tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc index d2d736b6..3936b525 100644 --- a/tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc @@ -63,8 +63,8 @@ namespace { /// shapeNo1 /// | /// addnNo1 -/// / \ -/// / \ +/// / \. +/// / \. /// const1 const2 ComputeGraphPtr BuildGraph8() { auto builder = ut::GraphBuilder("test"); @@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() { /// shapeNo1 /// | /// addnYes1 -/// / \ -/// / \ +/// / \. +/// / \. 
///const1 data1 ComputeGraphPtr BuildGraph9() { auto builder = ut::GraphBuilder("test"); diff --git a/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc b/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc index f58d6d9b..c0cce260 100644 --- a/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc +++ b/tests/ut/ge/graph/passes/folding_kernel/fill_kernel_unittest.cc @@ -64,6 +64,7 @@ class UtestGraphPassesFoldingKernelFillKernel : public testing::Test { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -124,6 +125,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillBoolShape2And3) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -230,6 +232,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsHaveNegativeNumber) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -284,6 +287,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsTypeNotSupport) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -310,6 +314,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsOverflow) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -336,6 +341,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { op_desc_ptr->AddInputDesc(dims_tensor_desc); op_desc_ptr->AddInputDesc(value_tensor_desc); + op_desc_ptr->AddOutputDesc(dims_tensor_desc); std::vector input = {dim_tensor, value_tensor}; std::vector outputs; @@ -343,3 +349,33 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { EXPECT_EQ(PARAM_INVALID, status); } + +TEST_F(UtestGraphPassesFoldingKernelFillKernel, OutputdescUnknown) { + ge::OpDescPtr op_dims = std::make_shared(); + vector dims_vec = {2}; + vector dims_value_vec = {2, 3}; + GeTensorDesc dims_tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32); + GeTensorPtr dim_tensor = std::make_shared(dims_tensor_desc, (uint8_t *) dims_value_vec.data(), + dims_value_vec.size() * sizeof(int32_t)); + OpDescUtils::SetWeights(op_dims, dim_tensor); + + ge::OpDescPtr op_value = std::make_shared(); + vector data_vec = {1}; + GeTensorDesc value_tensor_desc(GeShape(), FORMAT_NCHW, DT_BOOL); + GeTensorPtr value_tensor = + std::make_shared(value_tensor_desc, (uint8_t *) data_vec.data(), data_vec.size() * sizeof(bool)); + OpDescUtils::SetWeights(op_value, value_tensor); + + op_desc_ptr->AddInputDesc(dims_tensor_desc); + op_desc_ptr->AddInputDesc(value_tensor_desc); + + vector out_vec = {-1, -1}; + GeTensorDesc out_tensor_desc(GeShape(out_vec), FORMAT_NCHW, DT_INT32); + op_desc_ptr->AddOutputDesc(out_tensor_desc); + + std::vector input = {dim_tensor, value_tensor}; + std::vector outputs; + Status status = kernel->Compute(op_desc_ptr, input, outputs); + + EXPECT_EQ(NOT_CHANGED, status); +} \ No newline at end 
of file diff --git a/tests/ut/ge/graph/passes/folding_kernel/gather_v2_kernel_unittest.cc b/tests/ut/ge/graph/passes/folding_kernel/gather_v2_kernel_unittest.cc index 0083146b..ad165d25 100644 --- a/tests/ut/ge/graph/passes/folding_kernel/gather_v2_kernel_unittest.cc +++ b/tests/ut/ge/graph/passes/folding_kernel/gather_v2_kernel_unittest.cc @@ -92,7 +92,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT32Axis0VersionA) { GeTensorPtr tensor_out = outputs[0]; int32_t *data_buf = (int32_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -139,7 +139,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT32Axis0VersionB) { GeTensorPtr tensor_out = outputs[0]; int32_t *data_buf = (int32_t *)tensor_out->GetData().data(); vector expect_out = {3, 3}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -186,7 +186,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT64Axis0) { GeTensorPtr tensor_out = outputs[0]; int64_t *data_buf = (int64_t *)tensor_out->GetData().data(); vector expect_out = {3, 3}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -233,7 +233,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT32Axis0) { GeTensorPtr tensor_out = outputs[0]; int32_t *data_buf = (int32_t *)tensor_out->GetData().data(); vector expect_out = {11, 12, 13, 14, 15, 16, 17, 18, 19, 11, 12, 13, 14, 15, 16, 17, 18, 19}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -279,7 +279,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT32Axis0And1) { GeTensorPtr tensor_out = outputs[0]; int32_t *data_buf = (int32_t *)tensor_out->GetData().data(); vector expect_out = {11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -327,7 +327,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT32Axis1) { GeTensorPtr tensor_out = outputs[0]; int32_t *data_buf = (int32_t *)tensor_out->GetData().data(); vector expect_out = {4, 5, 6, 4, 5, 6, 14, 15, 16, 14, 15, 16}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -374,7 +374,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT32Axis2) { GeTensorPtr tensor_out = outputs[0]; int32_t *data_buf = (int32_t *)tensor_out->GetData().data(); vector expect_out = {1, 1, 4, 4, 7, 7, 11, 11, 14, 14, 17, 17}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -422,7 +422,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT32Axis3) { GeTensorPtr tensor_out = outputs[0]; int32_t *data_buf = (int32_t *)tensor_out->GetData().data(); vector expect_out = {1, 2, 4, 5, 7, 8, 11, 12, 14, 15, 17, 18, 1, 2, 4, 5, 7, 8, 11, 12, 14, 15, 17, 18}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -470,7 +470,7 @@ 
TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT8Axis0) { GeTensorPtr tensor_out = outputs[0]; int8_t *data_buf = (int8_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -517,7 +517,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, INT16Axis0) { GeTensorPtr tensor_out = outputs[0]; int16_t *data_buf = (int16_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -564,7 +564,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, UINT8Axis0) { GeTensorPtr tensor_out = outputs[0]; uint8_t *data_buf = (uint8_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -611,7 +611,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, UINT16Axis0) { GeTensorPtr tensor_out = outputs[0]; uint16_t *data_buf = (uint16_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -658,7 +658,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, UINT32Axis0) { GeTensorPtr tensor_out = outputs[0]; uint32_t *data_buf = (uint32_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -705,7 +705,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, UINT64Axis0) { GeTensorPtr tensor_out = outputs[0]; uint64_t *data_buf = (uint64_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { EXPECT_EQ(*(data_buf + i), expect_out[i]); } } @@ -753,7 +753,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, DoubleAxis0) { GeTensorPtr tensor_out = outputs[0]; double *data_buf = (double *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { double diff = *(data_buf + i) - expect_out[i]; bool is_same = fabs(diff) < 0.0001 ? true : false; EXPECT_EQ(is_same, true); @@ -802,7 +802,7 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, Float16Axis0) { GeTensorPtr tensor_out = outputs[0]; fp16_t *data_buf = (fp16_t *)tensor_out->GetData().data(); vector expect_out = {2, 2}; - for (int i = 0; i < expect_out.size(); i++) { + for (size_t i = 0; i < expect_out.size(); i++) { double diff = (double)*(data_buf + i) - (double)expect_out[i]; bool is_same = fabs(diff) < 0.0001 ? true : false; EXPECT_EQ(is_same, true); diff --git a/tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc b/tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc index ccc90afb..65e816cc 100644 --- a/tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc +++ b/tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc @@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test { /// convolution data /// | / /// ssdpriorbox -/// \ +/// \. 
/// reshape class NodeBuilder { public: diff --git a/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc b/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc index 8c3469c8..8638be5c 100644 --- a/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc @@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { /// graph with subgraph /// const -/// / \ +/// / \. /// cast1 cast1 /// \ / /// case diff --git a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc index 9da2565d..cc9a4077 100644 --- a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc @@ -34,7 +34,6 @@ #include "graph/tuning_utils.h" #include "graph_builder_utils.h" #include "graph/ge_context.h" -#include "graph/ge_local_context.h" #include "inc/pass_manager.h" #undef protected #undef private @@ -62,13 +61,9 @@ static ComputeGraphPtr BuildGraph1() { TEST_F(UtestGlobalStepInsertPass, skip_insert) { auto graph = BuildGraph1(); - std::string build_mode; - std::map options_map; - options_map.insert({ge::RUN_FLAG, "0"}); - ge::GetThreadLocalContext().SetGraphOption(options_map); GlobalStepInsertPass pass; Status status = pass.Run(graph); EXPECT_EQ(status, SUCCESS); NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP); - EXPECT_EQ(found_node, nullptr); + EXPECT_NE(found_node, nullptr); } diff --git a/tests/ut/ge/graph/passes/infer_base_pass_unittest.cc b/tests/ut/ge/graph/passes/infer_base_pass_unittest.cc new file mode 100644 index 00000000..24cc5c1b --- /dev/null +++ b/tests/ut/ge/graph/passes/infer_base_pass_unittest.cc @@ -0,0 +1,359 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "graph/passes/infer_base_pass.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph_builder_utils.h" + +using namespace std; +using namespace testing; +namespace ge { +class ChildPassBuilder; +static const char *kInferTimes = "infer_times"; +class InferBasePassStub : public InferBasePass { + public: + friend class ChildPassBuilder; + graphStatus Infer(NodePtr &node) override{ + call_infer_times++; + for (size_t i = 0; i < node->GetOutDataNodesSize(); ++i) { + auto output_td = node->GetOpDesc()->MutableOutputDesc(i); + int times = 0; + AttrUtils::GetInt(output_td, kInferTimes, times); + AttrUtils::SetInt(output_td, kInferTimes, times + 1); + } + return infer_result_; + }; + + int32_t call_infer_times = 0; + int32_t call_update_tensor_desc_times = 0; + int32_t call_update_from_subgraph_times = 0; + int32_t call_update_from_subgraph_multi_dims_times = 0; + std::vector> update_td_pairs; + + private: + bool NeedInfer(const NodePtr &node) const override { + return need_infer_; + }; + std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override { return "test SerialTensorInfo"; }; + graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override { + call_update_tensor_desc_times++; + changed = td_changed_; + int times = 0; + if (AttrUtils::GetInt(src, kInferTimes, times)) { + AttrUtils::SetInt(dst, kInferTimes, times); + } + update_td_pairs.emplace_back(src, dst); + return GRAPH_SUCCESS; + }; + graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override { + call_update_from_subgraph_times++; + return GRAPH_SUCCESS; + }; + graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) override { + call_update_from_subgraph_multi_dims_times++; + return GRAPH_SUCCESS; + }; + bool td_changed_; + bool need_infer_; + graphStatus infer_result_; +}; + +class ChildPassBuilder { + public: + ChildPassBuilder &SetNeedInferFlag(bool flag) { + need_infer_ = flag; + return *this; + } + + ChildPassBuilder &SetInferResult(graphStatus ret) { + infer_result_ = ret; + return *this; + } + + ChildPassBuilder &SetTdChangedFlag(bool changed_flag) { + td_changed_ = changed_flag; + return *this; + } + + InferBasePassStub Build() { + InferBasePassStub ib; + ib.td_changed_ = td_changed_; + ib.need_infer_ = need_infer_; + ib.infer_result_ = infer_result_; + return ib; + } + + private: + bool td_changed_ = false; + bool need_infer_ = true; + graphStatus infer_result_ = GRAPH_SUCCESS; +}; + +class UtestGraphInferBasePassStub : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +/* + * data1 data2 + * \ / + * sub1 + * | + * netoutput + */ +ut::GraphBuilder TestSubgraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("branch_graph"); + std::vector shape1 = {1,1}; + auto data1 = builder.AddNode("data1_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1); + auto data1_desc = data1->GetOpDesc(); + EXPECT_NE(data1_desc, nullptr); + AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); + std::vector shape2 = {2,2}; + auto data2 = builder.AddNode("data2_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2); + auto data2_desc = data2->GetOpDesc(); + EXPECT_NE(data2_desc, nullptr); + AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); + + auto sub1 = builder.AddNode("Sub", "Sub", 2, 1); + std::vector shape7 = {8,8}; + auto netoutput = builder.AddNode("output", 
NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7); + auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); + EXPECT_NE(input0_desc, nullptr); + AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); + + builder.AddDataEdge(data1, 0, sub1, 0); + builder.AddDataEdge(data2, 0, sub1, 1); + builder.AddDataEdge(sub1, 0, netoutput, 0); + return builder; +} + +/* + * data1 data2 + * \ / + * case1 + * | + * netoutput + */ +ut::GraphBuilder RootGraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("root_graph"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + auto data2 = builder.AddNode("data2", "Data", 0, 1); + auto case1 = builder.AddNode("case1", CASE, 2, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + builder.AddDataEdge(data1, 0, case1, 0); + builder.AddDataEdge(data2, 0, case1, 1); + builder.AddDataEdge(case1, 0, netoutput, 0); + + auto parent_graph = builder.GetGraph(); + auto subgraph_builder = TestSubgraphBuilder(); + auto subgraph = subgraph_builder.GetGraph(); + case1->GetOpDesc()->AddSubgraphName(subgraph->GetName()); + case1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName()); + subgraph->SetParentNode(case1); + subgraph->SetParentGraph(parent_graph); + EXPECT_EQ(parent_graph->AddSubgraph(subgraph->GetName(), subgraph), GRAPH_SUCCESS); + return builder; +} + +/* + * data1 data2 + * \ / + * add1 + * | + * netoutput + */ +ut::GraphBuilder NoSubgraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("no_subgraph"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + auto data2 = builder.AddNode("data2", "Data", 0, 1); + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + builder.AddDataEdge(data1, 0, add1, 0); + builder.AddDataEdge(data2, 0, add1, 1); + builder.AddDataEdge(add1, 0, netoutput, 0); + return builder; +} + +TEST_F(UtestGraphInferBasePassStub, CallInfer_WhenNeedInferReturnTrue) { + auto builder = NoSubgraphBuilder(); + auto test_graph = builder.GetGraph(); + auto add_node = test_graph->FindNode("add1"); + EXPECT_NE(add_node, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.Build(); + + // NeedInfer return true + EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); + EXPECT_EQ(stub_base_pass.call_infer_times, 1); + int times = -1; + EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times)); + EXPECT_EQ(times, 1); +} + +TEST_F(UtestGraphInferBasePassStub, NotCallInfer_WhenNeedInferReturnFalse) { + auto builder = NoSubgraphBuilder(); + auto test_graph = builder.GetGraph(); + auto add_node = test_graph->FindNode("add1"); + EXPECT_NE(add_node, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.SetNeedInferFlag(false).Build(); + + // NeedInfer return false + EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); + EXPECT_EQ(stub_base_pass.call_infer_times, 0); + int times = -1; + EXPECT_FALSE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times)); +} + +TEST_F(UtestGraphInferBasePassStub, NotAddCurNodeRepass_CallUpdatePeerNode_WhenInferReturnSuccess) { + auto builder = NoSubgraphBuilder(); + auto test_graph = builder.GetGraph(); + auto add_node = test_graph->FindNode("add1"); + auto netoutput = test_graph->FindNode("netoutput"); + EXPECT_NE(add_node, nullptr); + EXPECT_NE(netoutput, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.Build(); + + EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); + 
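// On a plain GRAPH_SUCCESS from Infer(), the base pass is expected to do three
// things, all asserted below: call Infer() exactly once, propagate the node's
// output desc to each peer input desc via UpdateTensorDesc (here add1's output
// 0 into netoutput's input 0), and schedule no immediate re-pass.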
EXPECT_EQ(stub_base_pass.call_infer_times, 1); + EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1); + std::vector> expected_updated_tensor_desc_pairs = { + {add_node->GetOpDesc()->MutableOutputDesc(0), netoutput->GetOpDesc()->MutableInputDesc(0)}}; + EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs); + EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({})); +} + +TEST_F(UtestGraphInferBasePassStub, AddCurNodeRepass_NotCallUpdatePeerNode_WhenInferReturnNeedRepass) { + auto builder = NoSubgraphBuilder(); + auto test_graph = builder.GetGraph(); + auto add_node = test_graph->FindNode("add1"); + EXPECT_NE(add_node, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build(); + + // do re_pass + EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); + EXPECT_EQ(stub_base_pass.call_infer_times, 1); + EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 0); +// EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({add_node})); +} + +TEST_F(UtestGraphInferBasePassStub, NotAddPeerNodeRepass_AfterUpdatePeerNode_WhenUnchanged) { + auto builder = NoSubgraphBuilder(); + auto test_graph = builder.GetGraph(); + auto add_node = test_graph->FindNode("add1"); + auto netoutput = test_graph->FindNode("netoutput"); + EXPECT_NE(add_node, nullptr); + EXPECT_NE(netoutput, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.Build(); + + EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); + EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1); + EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({})); + int times = -1; + EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times)); + EXPECT_EQ(times, 1); + times = -1; + EXPECT_TRUE(AttrUtils::GetInt(netoutput->GetOpDesc()->GetInputDescPtr(0), kInferTimes, times)); + EXPECT_EQ(times, 1); +} + +TEST_F(UtestGraphInferBasePassStub, AddPeerNodeRepass_AfterUpdatePeerNode_WhenChanged) { + auto builder = NoSubgraphBuilder(); + auto test_graph = builder.GetGraph(); + auto add_node = test_graph->FindNode("add1"); + auto netoutput = test_graph->FindNode("netoutput"); + EXPECT_NE(add_node, nullptr); + EXPECT_NE(netoutput, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.SetTdChangedFlag(true).Build(); + + EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS); + EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1); +// EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set({netoutput})); +} + +TEST_F(UtestGraphInferBasePassStub, TestUpdateSubgraphData_WhenBeforeSubgraph) { + auto builder = RootGraphBuilder(); + auto parent_graph = builder.GetGraph(); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 1); + + auto case_node = parent_graph->FindNode("case1"); + auto data1 = subgraphs[0]->FindNode("data1_1"); + auto data2 = subgraphs[0]->FindNode("data2_1"); + EXPECT_NE(case_node, nullptr); + EXPECT_NE(data1, nullptr); + EXPECT_NE(data2, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build(); + + EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS); + // when GRAPH_NODE_NEED_REPASS, not update peer node, only update two data, update input and output, 2*2 + EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 4); + std::vector> expected_updated_tensor_desc_pairs = { + 
{case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableInputDesc(0)}, + {case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableOutputDesc(0)}, + {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableInputDesc(0)}, + {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableOutputDesc(0)}, + }; + EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs); +} + +TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutput_WhenAfterSubgraph) { + auto builder = RootGraphBuilder(); + auto parent_graph = builder.GetGraph(); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 1); + + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.Build(); + stub_base_pass.SetOption(kOptimizeAfterSubGraph, ""); + + EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS); + EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 1); + EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 0); +} + +TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutputForMultiDims_WhenAfterSubgraph) { + auto builder = RootGraphBuilder(); + auto parent_graph = builder.GetGraph(); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 1); + + auto case_node = parent_graph->FindNode("case1"); + auto set_ret = AttrUtils::SetInt(case_node->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); + EXPECT_EQ(set_ret, true); + EXPECT_NE(case_node, nullptr); + ChildPassBuilder pass_builder; + auto stub_base_pass = pass_builder.Build(); + stub_base_pass.SetOption(kOptimizeAfterSubGraph, ""); + + EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS); + EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 0); + EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 1); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc new file mode 100644 index 00000000..014d87dc --- /dev/null +++ b/tests/ut/ge/graph/passes/infer_value_range_pass_unittest.cc @@ -0,0 +1,705 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define protected public +#define private public +#include "graph/passes/infer_value_range_pass.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph_builder_utils.h" + +#include "inc/external/graph/operator_reg.h" +#include "inc/external/graph/operator.h" +#include "inc/external/graph/operator_factory.h" +#include "inc/graph/operator_factory_impl.h" +#include "inc/kernel.h" +#include "inc/kernel_factory.h" + +using namespace std; +using namespace testing; +namespace ge { +class UtestGraphInferValueRangePass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +/* + * data1 const1 + * \ / + * case1 + * | + * relu10 + * | + * netoutput + */ +ut::GraphBuilder ParentGraphBuilder() { + ut::GraphBuilder builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", "Data", 0, 1); + std::vector const_shape = {1}; + auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_INT32, const_shape); + auto case1 = builder.AddNode("case1", CASE, 2, 1); + auto relu1 = builder.AddNode("relu10", "Relu", 1, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + int32_t weight[1] = {1}; + GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); + GeTensorPtr tensor = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); + OpDescUtils::SetWeights(const1, {tensor}); + auto case_in0_shape = GeShape({1, 1,-1, 224}); + auto case_in1_shape = GeShape({1,1}); + std::vector> in0_range = {make_pair(1, 1), make_pair(1, 1), + make_pair(1, -1), make_pair(1, 224)}; + std::vector> in1_range = {make_pair(1, 100), make_pair(1, 10)}; + case1->GetOpDesc()->MutableInputDesc(0)->SetShape(case_in0_shape); + case1->GetOpDesc()->MutableInputDesc(0)->SetValueRange(in0_range); + case1->GetOpDesc()->MutableInputDesc(1)->SetShape(case_in1_shape); + case1->GetOpDesc()->MutableInputDesc(1)->SetValueRange(in1_range); + + builder.AddDataEdge(data1, 0, case1, 0); + builder.AddDataEdge(const1, 0, case1, 1); + builder.AddDataEdge(case1, 0, relu1, 0); + builder.AddDataEdge(relu1, 0, netoutput, 0); + return builder; +} + +/* + * data1 data2 + * \ / + * switch + * / \ + * relu1 relu2 + * \ / + * merge + * | + * netoutput + */ +ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) { + ut::GraphBuilder builder = ut::GraphBuilder(graph_name); + + std::vector shape1 = {2,2}; + string data1_name = "data1_" + std::to_string(num); + auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1); + auto data1_desc = data1->GetOpDesc(); + EXPECT_NE(data1_desc, nullptr); + AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); + + std::vector shape2 = {3,3}; + string data2_name = "data2_" + std::to_string(num); + auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2); + auto data2_desc = data2->GetOpDesc(); + EXPECT_NE(data2_desc, nullptr); + AttrUtils::SetInt(data2_desc, "_parent_node_index", 1); + + string switch_name = "switch_" + std::to_string(num); + auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2); + + string relu1_name = "relu1_" + std::to_string(num); + auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1); + + string relu2_name = "relu2_" + std::to_string(num); + auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1); + + string merge_name = "merge_" + std::to_string(num); + auto merge = builder.AddNode(merge_name, "Merge", 2, 1); + + std::vector shape7 = {8,8}; + string output_name = "output_" + 
std::to_string(num); + auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7); + auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0); + EXPECT_NE(input0_desc, nullptr); + AttrUtils::SetInt(input0_desc, "_parent_node_index", 0); + std::vector> range = {make_pair(1, -1), make_pair(1, -1)}; + input0_desc->SetValueRange(range); + + builder.AddDataEdge(data1, 0, switch1, 0); + builder.AddDataEdge(data2, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, relu1, 0); + builder.AddDataEdge(switch1, 1, relu2, 0); + builder.AddDataEdge(relu1, 0, merge, 0); + builder.AddDataEdge(relu2, 0, merge, 1); + builder.AddDataEdge(merge, 0, netoutput, 0); + + return builder; +} + +void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) { + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + + for (uint32_t i = 0; i < branch_num; ++i) { + string name = "Branch_Graph_" + std::to_string(i); + + auto builder_subgraph = SwitchSubgraphBuilder(name, i); + auto switch_subgraph = builder_subgraph.GetGraph(); + + case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName()); + case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName()); + + switch_subgraph->SetParentNode(case_node); + switch_subgraph->SetParentGraph(parent_graph); + EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS); + } +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UnregisteredNodeType) { + auto graph = std::make_shared("test_graph"); + GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_FLOAT16); + auto addn_op_desc = std::make_shared("AddN", "AddN"); + addn_op_desc->AddInputDesc(ge_tensor_desc); + addn_op_desc->AddOutputDesc(ge_tensor_desc); + auto addn_op_node = graph->AddNode(addn_op_desc); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(addn_op_node), SUCCESS); +} + +auto ShapeValueInfer = [&](Operator &op) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_tensor_desc = op_desc->MutableOutputDesc(0); + std::vector> in_shape_range; + op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); + if (!in_shape_range.empty()) { + output_tensor_desc->SetValueRange(in_shape_range); + } + return SUCCESS; +}; +REG_OP(Shape) + .OP_END_FACTORY_REG(Shape) +IMPL_INFER_VALUE_RANGE_FUNC(Shape, ShapeValueRangeFunc){ + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_tensor_desc = op_desc->MutableOutputDesc(0); + std::vector> in_shape_range; + op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); + if (!in_shape_range.empty()) { + output_tensor_desc->SetValueRange(in_shape_range); + } + return GRAPH_SUCCESS; +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseRegistedFunc_NotInfer) { + INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc); + auto graph = std::make_shared("test_graph"); + GeTensorDesc ge_tensor_desc(GeShape({1, 1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); + std::vector> shape_range = {make_pair(1, 1), make_pair(1, 1), + make_pair(4, 4), make_pair(192, 192)}; + ge_tensor_desc.SetShapeRange(shape_range); + GeTensorDesc output_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); + auto op_desc = std::make_shared("Shape", "Shape"); + op_desc->AddInputDesc(ge_tensor_desc); + op_desc->AddOutputDesc(output_tensor_desc); + auto op_node = graph->AddNode(op_desc); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(op_node), SUCCESS); + + 
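// The custom func above is registered under INPUT_IS_DYNAMIC, but every dim of
// the input shape {1, 1, 4, 192} is static, so the pass should skip value-range
// inference for this node entirely; the output's value range must stay empty.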
auto output_0_desc = op_node->GetOpDesc()->GetOutputDesc(0); + std::vector> value_range; + output_0_desc.GetValueRange(value_range); + EXPECT_EQ(value_range.empty(), true); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseRegistedFunc_DoInfer) { + // sqrt -> shape -> Output + INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueRangeFunc); + auto graph = std::make_shared("test_graph"); + GeTensorDesc sqrt_tensor_desc(GeShape({-1, -1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); + std::vector> shape_range = {make_pair(1, 100), make_pair(1, 240), + make_pair(4, 4), make_pair(192, 192)}; + sqrt_tensor_desc.SetShapeRange(shape_range); + auto sqrt_op_desc = std::make_shared("Sqrt", "Sqrt"); + sqrt_op_desc->AddInputDesc(sqrt_tensor_desc); + sqrt_op_desc->AddOutputDesc(sqrt_tensor_desc); + auto sqrt_node = graph->AddNode(sqrt_op_desc); + + GeTensorDesc shape_output_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddInputDesc(sqrt_tensor_desc); + shape_op_desc->AddOutputDesc(shape_output_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc Output_in_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32); + auto Output_op_desc = std::make_shared("Output", "Output"); + Output_op_desc->AddInputDesc(Output_in_tensor_desc); + auto Output_node = graph->AddNode(Output_op_desc); + + ge::GraphUtils::AddEdge(sqrt_node->GetOutDataAnchor(0), shape_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), Output_node->GetInDataAnchor(0)); + EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); + + + InferValueRangePass infer_pass; + auto ret = infer_pass.Run(shape_node); + EXPECT_EQ(ret, SUCCESS); + + auto output_0_desc = shape_node->GetOpDesc()->GetOutputDesc(0); + std::vector> value_range; + output_0_desc.GetValueRange(value_range); + EXPECT_EQ(value_range.size(), 4); + std::vector target_value_range = {1, 100, 1, 240, 4, 4, 192, 192}; + std::vector output_value_range; + for (auto pair : value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range); + + auto in_0_desc = Output_node->GetOpDesc()->GetInputDesc(0); + value_range.clear(); + in_0_desc.GetValueRange(value_range); + EXPECT_EQ(value_range.size(), 4); + output_value_range.clear(); + for (auto pair : value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range); + +} + +class AddKernel : public Kernel { + public: + Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, + std::vector &v_output) override { + if (input[0]->GetTensorDesc().GetDataType() == DT_INT64 || input[0]->GetTensorDesc().GetDataType() == DT_UINT64) { + vector data_vec; + auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); + auto x1_data = reinterpret_cast(input[0]->GetData().data()); + auto x2_data = reinterpret_cast(input[1]->GetData().data()); + for (size_t i = 0; i < data_num; i++) { + auto x_index = *(x1_data + i); + auto y_index = *(x2_data + i); + data_vec.push_back(x_index + y_index); + } + GeTensorPtr const_tensor = std::make_shared(input[0]->GetTensorDesc(), (uint8_t *)data_vec.data(), + data_num * sizeof(int64_t)); + v_output.emplace_back(const_tensor); + return SUCCESS; + } else if (input[0]->GetTensorDesc().GetDataType() == DT_INT32 || input[0]->GetTensorDesc().GetDataType() == DT_UINT32) { + 
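// The 32-bit branch mirrors the 64-bit one, with one extra guard: for scalar
// inputs GetShapeSize() may not report 1, hence the explicit clamp of data_num
// to 1 below. Also worth noting: dtypes other than the four handled in this
// stub fall out of Compute() without reaching a return statement, so the tests
// are assumed to feed only INT32/UINT32/INT64/UINT64 tensors.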
vector data_vec; + auto data_num = input[0]->GetTensorDesc().GetShape().GetShapeSize(); + if (input[0]->GetTensorDesc().GetShape().IsScalar()) { + data_num = 1; + } + auto x1_data = reinterpret_cast(input[0]->GetData().data()); + auto x2_data = reinterpret_cast(input[1]->GetData().data()); + for (size_t i = 0; i < data_num; i++) { + auto x_index = *(x1_data + i); + auto y_index = *(x2_data + i); + data_vec.push_back(x_index + y_index); + } + GeTensorPtr const_tensor = std::make_shared(input[0]->GetTensorDesc(), (uint8_t *)data_vec.data(), + data_num * sizeof(int32_t)); + v_output.emplace_back(const_tensor); + return SUCCESS; + } + } +}; +REGISTER_KERNEL(ADD, AddKernel); +INFER_VALUE_RANGE_DEFAULT_REG(Add); +INFER_VALUE_RANGE_DEFAULT_REG(Sqrt); + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + + vector dims_vec = {4}; + vector data_vec = {1, 1, 1, 1}; + GeTensorDesc const_tensor_desc(ge::GeShape(dims_vec), ge::FORMAT_NCHW, ge::DT_INT64); + GeTensorPtr const_tensor = + std::make_shared(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t)); + + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_tensor_desc); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); + std::vector> unknown_value_range = {make_pair(1, -1), make_pair(1, 240), + make_pair(4, 4), make_pair(192, 192)}; + shape_tensor_desc.SetValueRange(unknown_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(const_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + // test unknown value range + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + + std::vector unknown_target_value_range = {1, -1, 1, -1, 1, -1, 1, -1}; + std::vector output_value_range; + for (auto pair : out_value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(unknown_target_value_range, output_value_range); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveZeroInValueRange) { + // shape --- add --- sqrt + auto graph = std::make_shared("test_graph"); + GeTensorDesc shape_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64); + std::vector> unknown_value_range = {make_pair(1, -1), make_pair(0, 240)}; + shape_tensor_desc.SetValueRange(unknown_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + 
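// With a 0 inside the input's value range (the (0, 240) pair above), folding
// through the Add kernel is expected to be abandoned rather than attempted:
// the assertions below require the output to carry no value range at all.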
GeTensorDesc add_tensor_desc(GeShape({2}), ge::FORMAT_NCHW, ge::DT_INT64); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + // test unknown value range + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 0); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsHaveUnKnownValueRange_ScalarOutput) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + vector data_vec = {1}; + GeTensorDesc const_tensor_desc(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + GeTensorPtr const_tensor = + std::make_shared(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t)); + + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_tensor_desc); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + std::vector> unknown_value_range = {make_pair(1, -1)}; + shape_tensor_desc.SetValueRange(unknown_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(const_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + // test unknown value range + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 1); + + std::vector unknown_target_value_range = {1, -1}; + std::vector output_value_range; + for (auto pair : out_value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(unknown_target_value_range, output_value_range); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_ScalarOutput) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + vector data_vec = {2}; + GeTensorDesc const_td(ge::GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + GeTensorPtr const_tensor = std::make_shared(const_td, (uint8_t *)data_vec.data(), sizeof(int32_t)); + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_td); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = 
graph->AddNode(const_op_desc); + + GeTensorDesc shape_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + std::vector> known_value_range = {make_pair(1, 100)}; + shape_td.SetValueRange(known_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_td); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_td(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_td); + add_op_desc->AddInputDesc(const_td); + add_op_desc->AddOutputDesc(add_td); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 1); + + std::vector target_value_range = {3, 102}; + std::vector output_value_range = {out_value_range[0].first, out_value_range[0].second}; + EXPECT_EQ(output_value_range, target_value_range); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int64) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + + vector dims_vec = {4}; + vector data_vec = {1, 1, 1, 1}; + GeTensorDesc const_tensor_desc(ge::GeShape(dims_vec), ge::FORMAT_NCHW, ge::DT_INT64); + GeTensorPtr const_tensor = + std::make_shared(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int64_t)); + + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_tensor_desc); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); + std::vector> unknown_value_range = {make_pair(1, 100), make_pair(1, 240), + make_pair(4, 4), make_pair(192, 192)}; + shape_tensor_desc.SetValueRange(unknown_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT64); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(const_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + auto sqrt_op_desc = std::make_shared("Sqrt", "Sqrt"); + sqrt_op_desc->AddInputDesc(GeTensorDesc()); + auto sqrt_node = graph->AddNode(sqrt_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + ge::GraphUtils::AddEdge(add_node->GetOutDataAnchor(0), sqrt_node->GetInDataAnchor(1)); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(sqrt_node), SUCCESS); + + // test known value range + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + + std::vector target_value_range = 
{2, 101, 2, 241, 5, 5, 193, 193}; + std::vector output_value_range; + for (auto pair : out_value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_NoSubgraph_UseCpuKernel_InputsAreKnownValueRange_Int32) { + // shape --- add --- sqrt + // constant / + auto graph = std::make_shared("test_graph"); + vector data_vec = {1, 100, 2, 200}; + GeTensorDesc const_tensor_desc(ge::GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32); + GeTensorPtr const_tensor = + std::make_shared(const_tensor_desc, (uint8_t *)data_vec.data(), data_vec.size() * sizeof(int32_t)); + auto const_op_desc = std::make_shared("Constant", "Constant"); + const_op_desc->AddOutputDesc(const_tensor_desc); + EXPECT_EQ(OpDescUtils::SetWeights(const_op_desc, const_tensor), GRAPH_SUCCESS); + auto const_node = graph->AddNode(const_op_desc); + + GeTensorDesc shape_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32); + std::vector> known_value_range = {make_pair(1, 100), make_pair(1, 240), + make_pair(4, 4), make_pair(192, 192)}; + shape_tensor_desc.SetValueRange(known_value_range); + auto shape_op_desc = std::make_shared("Shape", "Shape"); + shape_op_desc->AddOutputDesc(shape_tensor_desc); + auto shape_node = graph->AddNode(shape_op_desc); + + GeTensorDesc add_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, ge::DT_INT32); + auto add_op_desc = std::make_shared("Add", "Add"); + add_op_desc->AddInputDesc(shape_tensor_desc); + add_op_desc->AddInputDesc(const_tensor_desc); + add_op_desc->AddOutputDesc(add_tensor_desc); + auto add_node = graph->AddNode(add_op_desc); + + ge::GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), add_node->GetInDataAnchor(1)); + + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(add_node), SUCCESS); + auto output_0_desc = add_node->GetOpDesc()->GetOutputDesc(0); + std::vector> out_value_range; + output_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + + std::vector target_value_range = {2, 101, 101, 340, 6, 6, 392, 392}; + std::vector output_value_range; + for (auto pair : out_value_range) { + output_value_range.push_back(pair.first); + output_value_range.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range); +} + +REG_OP(Case) + .OP_END_FACTORY_REG(Case) +IMPL_INFER_VALUE_RANGE_FUNC(Case, ValueRangeFunc){ + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_tensor_desc = op_desc->MutableOutputDesc(0); + std::vector> in_value_range; + output_tensor_desc->GetValueRange(in_value_range); + if (in_value_range.empty()) { + std::vector> out_value_range = {make_pair(1, 2), make_pair(1, 3), + make_pair(1, 4), make_pair(1, 5)};; + output_tensor_desc->SetValueRange(out_value_range); + } + return GRAPH_SUCCESS; +} +INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Case, INPUT_HAS_VALUE_RANGE, ValueRangeFunc); + +TEST_F(UtestGraphInferValueRangePass, CallRun_HasCaeSubgraph_WhenBeforeSubgraph) { + auto builder = ParentGraphBuilder(); + auto parent_graph = builder.GetGraph(); + AddCaseSubgraph(parent_graph, 2); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 2); + + // check before subgraph + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + InferValueRangePass infer_pass; + EXPECT_EQ(infer_pass.Run(case_node), SUCCESS); + + auto case_out_0_desc = 
case_node->GetOpDesc()->MutableOutputDesc(0); + std::vector> out_value_range; + case_out_0_desc->GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 4); + std::vector target_value_range = {1,2,1,3,1,4,1,5}; + std::vector output_value_range_list; + for (auto pair : out_value_range) { + output_value_range_list.push_back(pair.first); + output_value_range_list.push_back(pair.second); + } + EXPECT_EQ(target_value_range, output_value_range_list); + + auto data_node = subgraphs[0]->FindNode("data1_0"); + auto data_output_0_desc = data_node->GetOpDesc()->GetOutputDesc(0); + std::vector target_value_range_list = {1, 1, 1, 1, 1, -1, 1, 224}; + std::vector> output_value_range; + data_output_0_desc.GetValueRange(output_value_range); + EXPECT_EQ(output_value_range.size(), 4); + std::vector data_value_range_list; + for (auto pair : output_value_range) { + data_value_range_list.push_back(pair.first); + data_value_range_list.push_back(pair.second); + } + EXPECT_EQ(data_value_range_list, target_value_range_list); + + data_node = subgraphs[0]->FindNode("data2_0"); + auto data2_input_0_desc = data_node->GetOpDesc()->GetInputDesc(0); + std::vector target_value_range_list2 = {1, 100, 1, 10}; + out_value_range.clear(); + data2_input_0_desc.GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 2); + data_value_range_list.clear(); + for (auto pair : out_value_range) { + data_value_range_list.push_back(pair.first); + data_value_range_list.push_back(pair.second); + } + EXPECT_EQ(data_value_range_list, target_value_range_list2); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_HasCaeSubgraph_WhenAfterSubgraph) { + auto builder = ParentGraphBuilder(); + auto parent_graph = builder.GetGraph(); + AddCaseSubgraph(parent_graph, 2); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 2); + + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + InferValueRangePass infer_pass; + // check after subgraph + infer_pass.options_[kOptimizeAfterSubGraph] = "yes"; + EXPECT_EQ(infer_pass.Run(case_node), SUCCESS); + + std::vector out_target_dims = {1, -1, 1, -1}; + auto case_out = case_node->GetOpDesc()->GetOutputDescPtr(0); + std::vector> out_value_range; + case_out->GetValueRange(out_value_range); + EXPECT_EQ(out_value_range.size(), 2); + + std::vector output_value_range_list; + for (auto pair : out_value_range) { + output_value_range_list.push_back(pair.first); + output_value_range_list.push_back(pair.second); + } + EXPECT_EQ(out_target_dims, output_value_range_list); +} + +TEST_F(UtestGraphInferValueRangePass, CallRun_HasSubgraph_WhenAfterSubgraph_ForMultiDims) { + auto builder = ParentGraphBuilder(); + auto parent_graph = builder.GetGraph(); + AddCaseSubgraph(parent_graph, 2); + auto subgraphs = parent_graph->GetAllSubgraphs(); + EXPECT_EQ(subgraphs.size(), 2); + + auto case_node = parent_graph->FindNode("case1"); + EXPECT_NE(case_node, nullptr); + InferValueRangePass infer_pass; + infer_pass.options_[kOptimizeAfterSubGraph] = "yes"; + + // check after subgraph for multi-batch + auto set_ret = AttrUtils::SetInt(case_node->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); + EXPECT_EQ(set_ret, true); + EXPECT_EQ(infer_pass.Run(case_node), GRAPH_FAILED); +} +} // namespace ge diff --git a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc index 13e66c50..d84aff50 100644 --- a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc @@ 
-1,161 +1,262 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define protected public -#define private public -#include "graph/passes/infershape_pass.h" - -#include "graph/utils/tensor_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/operator_factory.h" -#include "graph/operator_reg.h" -#include "graph_builder_utils.h" - -using namespace std; -using namespace testing; -namespace ge { -class UtestGraphInfershapePass : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { - OpDescPtr op_desc = std::make_shared(name, type); - op_desc->SetStreamId(0); - static int32_t index = 0; - op_desc->SetId(index++); - - GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); - TensorUtils::SetSize(tensor, 512); - vector input_offset; - for (int i = 0; i < in_num; i++) { - op_desc->AddInputDesc(tensor); - input_offset.emplace_back(1024); - } - op_desc->SetInputOffset(input_offset); - - vector output_offset; - for (int i = 0; i < out_num; i++) { - op_desc->AddOutputDesc(tensor); - output_offset.emplace_back(1024); - } - op_desc->SetOutputOffset(output_offset); - - op_desc->SetWorkspace({}); - op_desc->SetWorkspaceBytes({}); - op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); - - const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; - op_desc->AddInferFunc(stub_func); - op_desc->AddInferFormatFunc(stub_func); - op_desc->AddVerifierFunc(stub_func); - - return graph.AddNode(op_desc); -} - -TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { - GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); - string type = "AddN"; - auto addn_op_desc = std::make_shared("AddN", type); - addn_op_desc->AddInputDesc(ge_tensor_desc); - addn_op_desc->AddOutputDesc(ge_tensor_desc); - auto graph = std::make_shared("test"); - auto addn_node = std::make_shared(addn_op_desc, graph); - addn_node->Init(); - - InferShapePass infershape_pass; - EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); -} - -TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { - auto graph = std::make_shared("test"); - - auto no_op_desc = std::make_shared("No", "NoOp"); - auto no_op_node = graph->AddNode(no_op_desc); - AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); - - InferShapePass infershape_pass; - infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; - EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); -} - -TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { -/******************************************************************************* - * Exit Identify - * \ / \. - * \ / \. 
- * Switch Add - * / | | - * / | | - * / | | - * LoopCond | | - * \ | | - * \ | | - * \ | | - * Less | | - * \ | NextIteration - * \ | | - * \ | | - * Merge <---------| - * | - * | - * Enter - ******************************************************************************/ - auto graph = std::make_shared("test_infer_shape"); - auto data1 = CreateNode(*graph, "data", DATA, 1, 1); - auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); - auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); - auto less1 = CreateNode(*graph, "less", LESS, 2, 1); - auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); - auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); - auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); - auto add1 = CreateNode(*graph, "add", ADD, 2, 1); - auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); - auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); - auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); - auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); - auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); - - GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); - GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); - GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); - GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); - - GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); - - GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); - GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); - - GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); - GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); - GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); - - GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); - GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); - - GEPass ge_passes(graph); - NamesToPass names_to_passes; - InferShapePass infer_shape_pass; - names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); - - EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS); -} -} // namespace ge +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define protected public +#define private public +#include "graph/passes/infershape_pass.h" + +#include "graph/utils/tensor_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/operator_factory.h" +#include "graph/operator_reg.h" +#include "graph_builder_utils.h" + +using namespace std; +using namespace testing; +namespace ge { +class UtestGraphInfershapePass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { + OpDescPtr op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + TensorUtils::SetSize(tensor, 512); + vector input_offset; + for (int i = 0; i < in_num; i++) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(1024); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; i++) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(1024); + } + op_desc->SetOutputOffset(output_offset); + + op_desc->SetWorkspace({}); + op_desc->SetWorkspaceBytes({}); + op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); + + const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; + op_desc->AddInferFunc(stub_func); + op_desc->AddInferFormatFunc(stub_func); + op_desc->AddVerifierFunc(stub_func); + + return graph.AddNode(op_desc); +} + +TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { + GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + string type = "AddN"; + auto addn_op_desc = std::make_shared("AddN", type); + addn_op_desc->AddInputDesc(ge_tensor_desc); + addn_op_desc->AddOutputDesc(ge_tensor_desc); + auto graph = std::make_shared("test"); + auto addn_node = std::make_shared(addn_op_desc, graph); + addn_node->Init(); + + InferShapePass infershape_pass; + EXPECT_EQ(infershape_pass.Run(addn_node), GRAPH_FAILED); +} + +TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { +/******************************************************************************* + * Exit Identify + * \ / \. + * \ / \. 
+ * Switch Add + * / | | + * / | | + * / | | + * LoopCond | | + * \ | | + * \ | | + * \ | | + * Less | | + * \ | NextIteration + * \ | | + * \ | | + * Merge <---------| + * | + * | + * Enter + ******************************************************************************/ + auto graph = std::make_shared("test_infer_shape"); + auto data1 = CreateNode(*graph, "data", DATA, 1, 1); + auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); + auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); + auto less1 = CreateNode(*graph, "less", LESS, 2, 1); + auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); + auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); + auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); + auto add1 = CreateNode(*graph, "add", ADD, 2, 1); + auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); + auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); + auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); + GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); + GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); + GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); + GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); + GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + + GEPass ge_passes(graph); + NamesToPass names_to_passes; + InferShapePass infer_shape_pass; + names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); + + EXPECT_EQ(infer_shape_pass.Run(switch1), SUCCESS); + auto suspend_nodes = infer_shape_pass.GetNodesSuspend(); + auto exit_node = graph->FindNode("exit"); + EXPECT_EQ(suspend_nodes.count(exit_node), 1); + infer_shape_pass.OnSuspendNodesLeaked(); + auto resume_nodes = infer_shape_pass.GetNodesResume(); + EXPECT_EQ(resume_nodes.count(exit_node), 1); +} +TEST_F(UtestGraphInfershapePass, update_tensordesc_when_changed) { + GeTensorDesc src_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDesc dst_ge_tensor_desc(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDescPtr src_tensor_desc_ptr = std::make_shared(src_ge_tensor_desc); + GeTensorDescPtr dst_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + bool changed = false; + infershape_pass.UpdateTensorDesc(src_tensor_desc_ptr, dst_tensor_desc_ptr, changed); + EXPECT_EQ(changed, true); + EXPECT_EQ(dst_tensor_desc_ptr->GetShape().GetDims(), std::vector({1, 2, 3, 4})); +} + +TEST_F(UtestGraphInfershapePass, update_tensordesc_when_not_changed) { + GeTensorDesc src_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, 
DT_FLOAT16); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDescPtr src_tensor_desc_ptr = std::make_shared(src_ge_tensor_desc); + GeTensorDescPtr dst_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + bool changed = false; + infershape_pass.UpdateTensorDesc(src_tensor_desc_ptr, dst_tensor_desc_ptr, changed); + EXPECT_EQ(changed, false); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_failed) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDesc ge_tensor_desc2(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphs({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, GRAPH_FAILED); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_rank) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc ge_tensor_desc2(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphs({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(dst_ge_tensor_desc_ptr->GetShape().GetDims(), UNKNOWN_RANK); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_shape) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc ge_tensor_desc2(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphs({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(dst_ge_tensor_desc_ptr->GetShape().GetDims(), std::vector({-1,2,3,4})); + // todo shape range? 
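// Read together with the two cases above, the merge rule for a parent output
// appears to be: dims equal across subgraph outputs are kept; dims that differ
// collapse to -1 (unknown dim); outputs of different rank collapse to
// UNKNOWN_RANK; and a dtype mismatch fails hard with GRAPH_FAILED. The shape
// range accompanying the -1 dim is left unasserted (see the todo above).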
+}
+
+TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed) {
+  // ref outputs have different dtypes
+  GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
+  GeTensorDesc ge_tensor_desc2(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared<GeTensorDesc>(ge_tensor_desc1);
+  GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared<GeTensorDesc>(ge_tensor_desc2);
+  GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared<GeTensorDesc>(dst_ge_tensor_desc);
+  InferShapePass infershape_pass;
+  auto ret = infershape_pass.UpdateOutputFromSubgraphsForMultiDims({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr},
+                                                                   dst_ge_tensor_desc_ptr);
+  EXPECT_EQ(ret, GRAPH_FAILED);
+}
+
+TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed_shape_size_overflow) {
+  // shape size of one ref output overflows
+  GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDesc ge_tensor_desc2(GeShape({INT64_MAX, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared<GeTensorDesc>(ge_tensor_desc1);
+  GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared<GeTensorDesc>(ge_tensor_desc2);
+  GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared<GeTensorDesc>(dst_ge_tensor_desc);
+  InferShapePass infershape_pass;
+  auto ret = infershape_pass.UpdateOutputFromSubgraphsForMultiDims({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr},
+                                                                   dst_ge_tensor_desc_ptr);
+  EXPECT_EQ(ret, PARAM_INVALID);
+}
+
+TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_success) {
+  // ref outputs have different dims
+  GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDesc ge_tensor_desc2(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT);
+  GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared<GeTensorDesc>(ge_tensor_desc1);
+  GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared<GeTensorDesc>(ge_tensor_desc2);
+  GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared<GeTensorDesc>(dst_ge_tensor_desc);
+  InferShapePass infershape_pass;
+  auto ret = infershape_pass.UpdateOutputFromSubgraphsForMultiDims({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr},
+                                                                   dst_ge_tensor_desc_ptr);
+  EXPECT_EQ(ret, SUCCESS);
+  EXPECT_EQ(dst_ge_tensor_desc_ptr->GetShape().GetDims(), std::vector<int64_t>({2, 2, 3, 4}));
+}
+}  // namespace ge
diff --git a/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc
index b416d958..557359b7 100644
--- a/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc
+++ b/tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc
@@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string
   return graph.AddNode(op_desc);
 }
 
-static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) {
+static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) {
 /*******************************************************************************
- *      Exit         Identify
- *       \           /       \.
- *        \         /         \.
- * Switch Add - * / | | - * / | | - * / | | - * LoopCond | | - * \ | | - * \ | | - * \ | | - * Less | | - * \ | NextIteration - * \ | | - * \ | | - * Merge <---------| - * | - * | - * Enter + * | + * +--------------------- Merge ----------------------+ + * / | + * / | + * / | + * / | + * Exit Identify | + * \ / \. | + * \ / \. | + * Switch Add Add + * / | | | + * / | | | + * / | | | + * LoopCond | | | + * \ | | | + * \ | | | + * \ | | | + * Less | | | + * \ | NextIteration | + * \ | | | + * \ | | | + * Merge <---------| | + * | | + * | | + * Enter | + * \ | + * \ | + * Switch Switch + * | | + * +-----------------Equal----------------------+ + * | ******************************************************************************/ - auto data1 = CreateNode(*graph, "data", DATA, 1, 1); + auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); + auto data2 = CreateNode(*graph, "data2", DATA, 1, 1); + + auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1); + auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2); + auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2); + auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); - auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); - auto less1 = CreateNode(*graph, "less", LESS, 2, 1); + auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2); + auto less1 = CreateNode(*graph, "less1", LESS, 2, 1); auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); - auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); + auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2); auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); - auto add1 = CreateNode(*graph, "add", ADD, 2, 1); + auto add1 = CreateNode(*graph, "add1", ADD, 2, 1); auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); - auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); - auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1); + + auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1); + auto add2 = CreateNode(*graph, "add2", ADD, 2, 1); + auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2); auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); - GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1)); + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); + GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); + cond.emplace_back(switch1); + cond.emplace_back(switch2); + + GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); - GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); + 
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); + loop.emplace_back(merge1); - GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); - GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false + GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true + loop.emplace_back(switch3); GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); - GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); - GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); - merge = merge1; + GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true + GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); + + GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); + GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); + GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + + cond.emplace_back(merge2); + merge = merge2; } static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { @@ -197,12 +235,24 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { auto graph = std::make_shared("test_graph"); NodePtr merge; - CreateLoopGraph(graph, merge); - - AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); + vector loop; + vector cond; + CreateLoopGraph(graph, merge, loop, cond); MarkForceUnknownForCondPass mark_force_unknown_pass; EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond + + EXPECT_EQ(loop.size(), 2); + for (const auto &node : loop) { + EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)); + } + + EXPECT_EQ(cond.size(), 3); + for (const auto &node : cond) { + int64_t group_index = -1; + EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)); + EXPECT_EQ(group_index, merge->GetOpDesc()->GetId()); + } } TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { diff --git a/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc b/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc index 5157e510..7d4663b3 100644 --- a/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc @@ -24,7 +24,7 @@ #include "common/ge_inner_error_codes.h" #include "inc/pass_manager.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #undef private namespace ge { @@ -33,7 +33,7 @@ protected: void SetUp() {} void TearDown() {} public: - NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + NodePtr MakeNode(const ComputeGraphPtr &graph, int in_num, int out_num, string name, string type) { GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); auto op_desc = std::make_shared(name, type); for (auto i = 0; i < in_num; ++i) { diff --git a/tests/ut/ge/graph/passes/merge_pass_unittest.cc b/tests/ut/ge/graph/passes/merge_pass_unittest.cc index 75fdb21b..f8f0afea 100644 --- a/tests/ut/ge/graph/passes/merge_pass_unittest.cc +++ 
b/tests/ut/ge/graph/passes/merge_pass_unittest.cc @@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) { } /// Merge -/// | \ -/// | \ +/// | \. +/// | \. /// Op1 Op2 Merge2 /// \ | | /// \ | Op3 @@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da } /// Merge -/// | \ -/// | \ +/// | \. +/// | \. /// Op1 Op2 Merge2 -/// \ | | \ +/// \ | | \. /// \ | Op3 /// \ | : /// NetOutput @@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { /// Merge - /// | \ - /// | \ + /// | \. + /// | \. /// Op1 Op2 Merge2 /// \ | | /// \ | Op3 @@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { /// Op1 Op2 Merge2 /// \ | /// \ Op3 - /// \ + /// \. /// Merge3 ret = pass_.Run(merge_node2); @@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) { /// Op1 /// | /// Merge - /// / \ + /// / \. /// Op2 Op3 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto node1 = NewNode("Op1", RELU, 1, 1); @@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) { /// Const /// | /// Merge Pass Const - /// / \ ===> / \ + /// / \ ===> / \. /// Op1 Op2 Op1 Op2 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto const_node = NewNode("Const", CONSTANT, 1, 1); @@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) /// / | ===> / \(control anchor) /// Op1 | \ Op1 Constant /// Op2 Op3 | - /// / \ + /// / \. /// Op2 Op3 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto const_node = NewNode("Const", CONSTANT, 1, 1); @@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) /// / | ===> / \(control anchor) /// Op1 | \ Op1 Constant /// Op2 Op3 | - /// / \ + /// / \. /// Op2 Op3 auto merge_node = NewNode("Merge", MERGE, 1, 2); auto const_node = NewNode("Const", CONSTANT, 1, 1); @@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { /// C /// | /// Merge - /// / \ + /// / \. /// Op1 Op2 auto switch_node = NewNode("Switch", SWITCH, 1, 2); auto identity_node = NewNode("Identity", SWITCH, 1, 1); @@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { /// . /// . /// C - /// / \ + /// / \. 
/// Op1 Op2 auto ret = pass_.Run(merge_node); EXPECT_EQ(ret, SUCCESS); diff --git a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc index 1b75a613..9ec254d7 100644 --- a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc @@ -22,7 +22,7 @@ #include "inc/pass_manager.h" #include "graph/utils/tensor_utils.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/multi_batch_pass.h" #include "graph/preprocess/multi_batch_copy_graph.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" @@ -45,7 +45,7 @@ protected: } public: - NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + NodePtr MakeNode(const ComputeGraphPtr &graph, int in_num, int out_num, string name, string type) { GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); auto op_desc = std::make_shared(name, type); for (auto i = 0; i < in_num; ++i) { diff --git a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc index d5b1db41..a6c3ff6a 100644 --- a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc @@ -19,7 +19,8 @@ #include #define private public - +#include "inc/graph/ge_local_context.h" +#include "inc/external/ge/ge_api_types.h" #include "common/ge_inner_error_codes.h" #include "inc/pass_manager.h" #include "utils/graph_utils.h" @@ -66,11 +67,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { void BuildDefaultGraph() { /// input - /// \ + /// \. /// sqrt pred /// \ / /// cast - /// / \ + /// / \. /// switch_t switch_f /// | | /// F T @@ -118,13 +119,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { void BuildDefaultGraph1() { /// input - /// \ + /// \. /// sqrt pred /// \ / /// Switch /// | | /// ----F T---- - /// \ | / \ + /// \ | / \. /// \ Merge1 Merge2 /// \_________| input_node_ = NewNode("input", RELU, 0, 1); @@ -164,14 +165,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { void BuildDefaultGraph2() { /// input input1 - /// \ \ + /// \ \. /// sqrt pred sqrt1 pred1 /// \ / \ / /// Switch Switch1 /// | | _______| /// | | / /// ____F T____ - /// \ | / \ + /// \ | / \. 
 ///      \ Merge1   Merge2
 ///       \_________|
     input_node_ = NewNode("input", RELU, 0, 2);
@@ -225,6 +226,70 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {
     output_true_node_->GetOpDesc()->SetIsInputConst({false});
   }
 
+  void BuildDefaultGraph3() {
+    ///       input
+    ///         \.
+    ///        sqrt   pred
+    ///          \    /
+    ///          Switch
+    ///          |    |
+    ///          F    T ------
+    ///         / \_/_       \.
+    ///        / /    \       \.
+    ///     Merge     sqrt2  sqrt3
+    ///      /          \       \.
+    ///   sqrt1          \      relu
+    ///                   \       \.
+    ///                    \     sqrt4
+    ///                     \    /
+    ///                     Merge1
+    input_node_ = NewNode("input", RELU, 0, 1);
+    AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
+    pred_node_ = NewNode("pred", GREATER, 2, 1);
+    sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
+    cast_node_ = NewNode("cast", CAST, 2, 2);
+
+    switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
+    AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
+    switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
+    AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
+    output_false_node_ = NewNode("false_output", RELU, 1, 2);
+    AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
+    output_true_node_ = NewNode("true_output", RELU, 1, 2);
+    AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
+    merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
+    sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1);
+    AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
+    sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1);
+    AttrUtils::SetStr(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
+    sqrt_node3_ = NewNode("sqrt3", SQRT, 1, 1);
+    relu_node_ = NewNode("relu", RELU, 1, 1);
+    sqrt_node4_ = NewNode("sqrt4", SQRT, 1, 1);
+    AttrUtils::SetStr(sqrt_node4_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
+    merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
+
+    GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
+    GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
+    GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
+    GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
+
+    GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
+    GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), sqrt_node2_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), sqrt_node3_->GetInDataAnchor(0));
+
+    GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(sqrt_node3_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node4_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(sqrt_node2_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(0));
+    GraphUtils::AddEdge(sqrt_node4_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(1));
+    output_false_node_->GetOpDesc()->SetIsInputConst({false});
+    output_true_node_->GetOpDesc()->SetIsInputConst({false});
+  }
+
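+  // (BuildDefaultGraph3 above appears intended to check that the pass also
+  //  serializes parallel-group-"1" nodes that sit on different Switch/Merge
+  //  branches; the normal_graph3 test below asserts the resulting control edge
+  //  from merge1 to sqrt1.)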
   ComputeGraphPtr graph_;
   ComputeGraphPtr sub_graph_;
   GeTensorDescPtr default_tensor_desc_;
@@ -235,6 +300,9 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {
   NodePtr cast_node1_;
   NodePtr sqrt_node_;
   NodePtr sqrt_node1_;
+  NodePtr sqrt_node2_;
+  NodePtr sqrt_node3_;
+  NodePtr sqrt_node4_;
   NodePtr input_node_;
   NodePtr input_node1_;
   NodePtr switch_node_t;
@@ -278,6 +346,16 @@ TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) {
   EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor()));
 }
 
+TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph3) {
+  std::map<std::string, std::string> options;
+  options.emplace(OPTION_GRAPH_RUN_MODE, "1");
+  GetThreadLocalContext().SetGraphOption(options);
+  BuildDefaultGraph3();
+  auto ret = pass_.Run(graph_);
+  EXPECT_EQ(ret, GRAPH_SUCCESS);
+  EXPECT_EQ(true, merge_node1_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor()));
+}
+
 TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) {
   BuildDefaultGraph1();
   NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);
diff --git a/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc b/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc
index 6711b0d3..d353498c 100644
--- a/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc
+++ b/tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc
@@ -57,6 +57,36 @@ ut::GraphBuilder Graph1Builder() {
   builder.AddDataEdge(cast1, 0, conv2d, 0);
   return builder;
 }
+
+/// data1  const1
+///    \   /
+///     add1
+///      |
+/// data2 -> switch1 (empty)
+///      |
+///   conv2d
+ut::GraphBuilder Graph2Builder() {
+  ut::GraphBuilder builder = ut::GraphBuilder("graph2");
+  auto data1 = builder.AddNode("data1", "Data", 0, 1);
+  auto data2 = builder.AddNode("data2", "Data", 0, 1);
+  auto const1 = builder.AddNode("const1", "Const", 0, 1);
+  auto add1 = builder.AddNode("add1", "Add", 2, 1);
+  auto switch1 = builder.AddNode("switch1", "Switch", 2, 1);
+  auto conv2d = builder.AddNode("conv2d", "Conv2D", 1, 0);
+
+  add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}), FORMAT_NCHW));
+  add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}), FORMAT_NCHW));
+  add1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}), FORMAT_NCHW));
+  GeTensorDesc empty_tensor(GeShape({1, 0, 8, 8}), FORMAT_NCHW);
+  switch1->GetOpDesc()->UpdateOutputDesc(0, empty_tensor);
+
+  builder.AddDataEdge(data1, 0, add1, 0);
+  builder.AddDataEdge(const1, 0, add1, 1);
+  builder.AddDataEdge(add1, 0, switch1, 0);
+  builder.AddDataEdge(data2, 0, switch1, 1);
+  builder.AddDataEdge(switch1, 0, conv2d, 0);
+  return builder;
+}
 }  // namespace
@@ -85,4 +115,19 @@ TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_const_success) {
   auto conv2d = graph->FindNode("conv2d");
   EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const");
 }
+
+TEST_F(UtestReplaceWithEmptyConstPass, replace_with_empty_switch_skip) {
+  auto builder = Graph2Builder();
+  auto graph = builder.GetGraph();
+  graph->SetSessionID(0);
+  ReplaceWithEmptyConstPass replace_with_empty_const_pass;
+
+  EXPECT_EQ(graph->GetDirectNodesSize(), 6);
+  // run the pass on switch1: its output is empty, but a Switch must be skipped,
+  // so the graph still has 6 nodes
+  auto switch1 = graph->FindNode("switch1");
+  EXPECT_NE(switch1, nullptr);
+  Status ret = replace_with_empty_const_pass.Run(switch1);
+  EXPECT_EQ(ret, SUCCESS);
+  EXPECT_EQ(graph->GetDirectNodesSize(), 6);
+}
 }  // namespace ge
diff --git a/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc b/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc
index
3be11452..f941645e 100644 --- a/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc @@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test { namespace { /// netoutput1 -/// | \ -///transdata1 \ -/// | \ +/// | \. +///transdata1 \. +/// | \. /// | transdata2 /// | / /// var1 const1 diff --git a/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc index 351e96d7..ca0cac86 100644 --- a/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc @@ -35,7 +35,7 @@ namespace { /// transdata1 /// | /// reshape1 -/// | \ +/// | \. /// var1 const1 ut::GraphBuilder Graph1Builder() { ut::GraphBuilder builder = ut::GraphBuilder("g1"); @@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() { } /// netoutput1 -/// | \ -///transdata1 \ -/// | \ +/// | \. +///transdata1 \. +/// | \. /// reshape1 reshape2 -/// | \ / \ +/// | \ / \. /// var1 const1 var2 ut::GraphBuilder Graph2Builder() { ut::GraphBuilder builder = ut::GraphBuilder("g2"); @@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() { } /// netoutput1 -/// | \ -///transdata1 \ -/// | \ +/// | \. +///transdata1 \. +/// | \. /// reshape1 transdata2 /// | \ / /// var1 const1 diff --git a/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc b/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc index 6d12a49d..8cdfd0c7 100644 --- a/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc @@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test { namespace { /// netoutput1 -/// | \ +/// | \. /// StackPush StackPop /// | | /// var1 const1 diff --git a/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc b/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc index 00157395..6565295c 100644 --- a/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc @@ -20,7 +20,7 @@ #include #include "framework/omg/omg_inner_types.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/subgraph_const_migration_pass.h" #include "inc/pass_manager.h" #include "register/op_registry.h" @@ -32,7 +32,7 @@ class UtestSubgraphConstMigrationPass : public testing::Test { void TearDown() {} public: - NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { + NodePtr MakeNode(const ComputeGraphPtr &graph, int in_num, int out_num, string name, string type) { GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); auto op_desc = std::make_shared(name, type); for (auto i = 0; i < in_num; ++i) { diff --git a/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc index dcad318c..22734047 100644 --- a/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc @@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() { /// netoutput1 /// | /// merge1 -/// / \ +/// / \. /// / add1 -/// / F| \ +/// / F| \. /// addn1 swtich2 var3 /// \F T/ | /// switch1 | @@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() { /// add1 /// / \T /// var3 swtich2 -/// T/ \ -/// switch1 \ -/// / \ \ +/// T/ \. +/// switch1 \. +/// / \ \. 
/// var1 var2 var4 ComputeGraphPtr BuildGraph3() { auto builder = ut::GraphBuilder("g3"); @@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() { /// netoutput1 /// | /// merge1 -/// / \ +/// / \. /// add1 addn1 /// / \T F/ /// var3 swtich2 diff --git a/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc b/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc index dbb163e1..d05bd695 100644 --- a/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc @@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) { } /// ----> netoutput1 -/// / | \ +/// / | \. /// transdata1 transdata2 transdata3 /// \ / | /// var1-------------- @@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() { } /// ---------> netoutput1 -/// / | \ +/// / | \. /// transdata1 transdata2(l1) transdata3(l1) /// \ / | /// var1------------------ diff --git a/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc b/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc index a9ea41ea..dbac3246 100644 --- a/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc @@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) /// -->transpose1 -->transpose3-->sinh2 /// | \ / /// | -->transpose2 - /// | \ + /// | \. /// / -->cast3-->cast4-->sinh3 /// / /// / -->transpose4-->transpose5-->sinh4 /// / / /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 - /// \ \ + /// \ \. /// \ -->sinh6 - /// \ + /// \. /// \ -->transpose6-->transpose7-->sinh9 /// \ / /// -->reshape-->cast6-->cast7-->sinh8 - /// \ + /// \. /// -->sinh7 /// after optimized graph @@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) /// / /-->transpose3-->sinh2 /// -->Cast1 /// / \-->sinh7 - /// / \ + /// / \. /// / -->sinh9 /// Node4D /// \ -->sinh4 /// \ / /// -->Cast5-->sinh5 - /// \ \ + /// \ \. /// \ -->sinh6 - /// \ + /// \. /// -->Cast7-->sinh8 ge::ComputeGraphPtr graph = std::make_shared("test"); diff --git a/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc b/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc index 1220b35e..9c6d8276 100644 --- a/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc @@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran /// TransData TransData ... MatMul ... /// \ | / / / /// HcomAllReduce - /// / | \ \ \ + /// / | \ \ \. /// TransData TransData ... RealDiv ... ComputeGraphPtr graph = std::make_shared("test"); NodePtr allreduce = @@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans /// TransData TransData ... MatMul ... /// \ | / / / /// HcomAllReduce - /// / | \ \ \ + /// / | \ \ \. /// TransData TransData ... RealDiv ... size_t symmetric_transdata_num = 20; size_t asymmetric_transdata_num = 20; diff --git a/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc b/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc index f1ea7a27..655867a7 100644 --- a/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc @@ -66,7 +66,7 @@ namespace { /// transdata2 /// | /// assign1 -/// / \ +/// / \. 
 /// transdata1  |
 ///     |       |
 ///    var1  const1
diff --git a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
index ebd0ab25..b1c07d81 100644
--- a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
+++ b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc
@@ -23,6 +23,7 @@
 #include "graph/passes/graph_builder_utils.h"
 #include "graph/utils/attr_utils.h"
 #include "graph/debug/ge_attr_define.h"
+#include "graph/manager/graph_var_manager.h"
 
 #define private public
 #define protected public
@@ -179,6 +180,21 @@ TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
   EXPECT_EQ(intput2_result_shape_range.size(), 0);
 }
 
+TEST_F(UtestGraphPreproces, test_update_input_fail) {
+  ge::GraphPrepare graph_prepare;
+  graph_prepare.compute_graph_ = BuildGraph1();
+
+  ge::GeTensorDesc tensor1;
+  tensor1.SetFormat(ge::FORMAT_NCHW);
+  tensor1.SetShape(ge::GeShape({3, 12, 5, 5}));
+  tensor1.SetDataType(ge::DT_UNDEFINED);
+  GeTensor input1(tensor1);
+  std::vector<GeTensor> user_input = {input1};
+  std::map<std::string, std::string> graph_option;
+  auto ret = graph_prepare.UpdateInput(user_input, graph_option);
+  EXPECT_EQ(ret, ge::FAILED);
+}
+
 TEST_F(UtestGraphPreproces, test_check_user_input) {
   ge::GraphPrepare graph_prepare;
   graph_prepare.compute_graph_ = BuildGraph1();
@@ -270,4 +286,26 @@ TEST_F(UtestGraphPreproces, test_prepare_dyn_shape) {
   GraphPrepare graph_prepare;
   EXPECT_EQ(graph_prepare.PrepareDynShape(graph_node, user_input, compute_graph, 0), SUCCESS);
 }
+
+TEST_F(UtestGraphPreproces, test_update_variable_formats) {
+  auto builder = ut::GraphBuilder("g1");
+  auto var = builder.AddNode("var", VARIABLE, 1, 1);
+  auto g1 = builder.GetGraph();
+  g1->SetSessionID(0);
+  TransNodeInfo trans_node_info;
+  VarTransRoad fusion_road;
+  fusion_road.emplace_back(trans_node_info);
+  VarManager::Instance(g1->GetSessionID())->SetTransRoad(var->GetName(), fusion_road);
+  GraphPrepare graph_prepare;
+  EXPECT_EQ(graph_prepare.UpdateVariableFormats(g1), INTERNAL_ERROR);
+
+  auto builder1 = ut::GraphBuilder("g2");
+  auto var1 = builder1.AddNode("var1", VARIABLE, 1, 1);
+  auto g2 = builder1.GetGraph();
+  g2->SetSessionID(0);
+  VarTransRoad fusion_road1;
+  VarManager::Instance(g2->GetSessionID())->SetTransRoad(var1->GetName(), fusion_road1);
+  AttrUtils::SetStr(var1->GetOpDesc(), REF_VAR_SRC_VAR_NAME, "var1");
+  EXPECT_EQ(graph_prepare.UpdateVariableFormats(g2), SUCCESS);
+}
 }
\ No newline at end of file
diff --git a/tests/ut/ge/graph/transop_util_unittest.cc b/tests/ut/ge/graph/transop_util_unittest.cc
index 9f645c22..02aa97bf 100644
--- a/tests/ut/ge/graph/transop_util_unittest.cc
+++ b/tests/ut/ge/graph/transop_util_unittest.cc
@@ -16,7 +16,7 @@
 #include
 
-#include "graph/common/transop_util.h"
+#include "common/transop_util.h"
 
 #include "common/debug/log.h"
 #include "common/types.h"
diff --git a/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc b/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc
index 37b4bda7..bf350b6c 100644
--- a/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc
+++ b/tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc
@@ -35,8 +35,8 @@ namespace {
 ///     shapeNo1
 ///        |
 ///     addnYes1
-///     /     \
-///    /       \
+///     /     \.
+///    /       \.
 ///  const1  const2
 
 ComputeGraphPtr BuildGraph1() {
@@ -57,9 +57,9 @@
 
 ///
 ///          netoutput1
-///         /    \      \
-///      add1   assign1  \
-///      /  \   /    \    \
+///         /    \      \.
+///      add1   assign1  \.
+///      /  \   /    \    \.
/// var1 var2 const1 var3 ComputeGraphPtr BuildGraph2() { diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc index e14178d8..500dbc2a 100644 --- a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#include #include #include "ir_build/option_utils.h" #include "graph/testcase/ge_graph/graph_builder_utils.h" @@ -21,7 +21,7 @@ #include "graph/utils/graph_utils.h" #include "ge/ge_ir_build.h" #include "graph/ops_stub.h" - +#include "ge/ir_build/attr_options/attr_options.h" #define protected public #define private public @@ -70,6 +70,22 @@ static ComputeGraphPtr BuildComputeGraph() { return builder.GetGraph(); } +static ComputeGraphPtr BuildComputeGraph1() { + auto builder = ut::GraphBuilder("test"); + auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); + auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); + auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1); + auto node1 = builder.AddNode("addd", "Mul", 2, 1); + auto node2 = builder.AddNode("ffm", "FrameworkOp", 2, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + builder.AddDataEdge(data1, 0, addn1, 0); + builder.AddDataEdge(data2, 0, addn1, 1); + builder.AddDataEdge(addn1, 0,netoutput, 0); + + return builder.GetGraph(); +} + // data not set attr index; // but becasue of op proto, register attr index. so all data index is zero; static Graph BuildIrGraph() { @@ -89,10 +105,12 @@ static Graph BuildIrGraph1() { auto data1 = op::Data("data1").set_attr_index(0); auto data2 = op::Data("data2").set_attr_index(1); auto data3 = op::Data("data3"); - std::vector inputs {data1, data2, data3}; + auto data4 = op::Data("Test"); + std::vector inputs {data1, data2, data3, data4}; std::vector outputs; Graph graph("test_graph"); + graph.AddNodeByOp(Operator("gg", "Mul")); graph.SetInputs(inputs).SetOutputs(outputs); return graph; } @@ -349,7 +367,7 @@ TEST(UtestIrBuild, check_data_op_attr_index_valid) { }; ModelBufferData model; graphStatus ret = aclgrphBuildModel(graph, build_options, model); - EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); + EXPECT_EQ(ret, ge::FAILED); } // set attr index invalid, when not set input shape range @@ -359,7 +377,7 @@ TEST(UtestIrBuild, check_data_attr_index_succ_no_input_range) { const map build_options; ModelBufferData model; graphStatus ret = aclgrphBuildModel(graph, build_options, model); - EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); + EXPECT_EQ(ret, ge::FAILED); } TEST(UtestIrBuild, check_modify_mixlist_param) { @@ -368,7 +386,58 @@ TEST(UtestIrBuild, check_modify_mixlist_param) { {"ge.exec.modify_mixlist", "/modify.json"} }; ModelBufferData model; - + + auto ret = aclgrphBuildModel(graph, build_options, model); + EXPECT_EQ(ret, GRAPH_PARAM_INVALID); +} + +TEST(UtestIrBuild, check_op_precision_mode_param) { + Graph graph = BuildIrGraph1(); + const std::map build_options = { + {"ge.exec.op_precision_mode", "./op_precision_mode.ini"} + }; + ModelBufferData model; + auto ret = aclgrphBuildModel(graph, build_options, model); EXPECT_EQ(ret, GRAPH_PARAM_INVALID); +} + +TEST(UtestIrBuild, check_build_model_and_build_step) { + Graph graph_1 = BuildIrGraph1(); + const std::map build_options_1 = { + {"ge.buildMode", "xxx"} + }; + ModelBufferData model_1; + auto ret_1 = 
aclgrphBuildModel(graph_1, build_options_1, model_1); + EXPECT_NE(ret_1, GRAPH_SUCCESS); + + Graph graph_2 = BuildIrGraph1(); + const std::map build_options_2 = { + {"ge.buildStep", "xxx"} + }; + ModelBufferData model_2; + auto ret_2 = aclgrphBuildModel(graph_2, build_options_2, model_2); + EXPECT_NE(ret_2, GRAPH_SUCCESS); + + Graph graph_3 = BuildIrGraph1(); + const std::map build_options_3 = { + {"ge.buildMode", "tuning"} + }; + ModelBufferData model_3; + auto ret_3 = aclgrphBuildModel(graph_3, build_options_3, model_3); + EXPECT_NE(ret_3, GRAPH_SUCCESS); +} + +TEST(UtestIrBuild, atc_cfg_optype_param) { + ComputeGraphPtr graph = BuildComputeGraph1(); + FILE *fp = fopen("./keep.txt", "w+"); + if (fp) { + fprintf(fp, "Test\n"); + fprintf(fp, "OpType::Mul\n"); + fprintf(fp, "Optype::Sub\n"); + fclose(fp); + } + auto ret = KeepDtypeFunc(graph, "./keep.txt"); + (void)remove("./keep.txt"); + EXPECT_EQ(ret, GRAPH_PARAM_INVALID); } \ No newline at end of file diff --git a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc index d2679439..c053885f 100644 --- a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc @@ -82,26 +82,53 @@ TEST_F(UtestHybridModelAsyncExecutor, BuildDeviceTensor) { GeTensorDesc ge_tensor_desc; int64_t output_size = 100; std::vector outputs; - executor.BuildDeviceTensor(tensor, ge_tensor_desc, output_size, outputs); + auto ret = executor.BuildDeviceTensor(tensor, ge_tensor_desc, output_size, outputs); auto size = tensor.GetSize(); ASSERT_EQ(size, 100); } -TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { +TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { ComputeGraphPtr graph = std::make_shared("test"); GeRootModelPtr ge_root_model = make_shared(graph); ge_root_model->SetModelName("test_name"); HybridModel hybrid_model(ge_root_model); + hybrid_model.root_graph_item_.reset(new GraphItem); HybridModelExecutor executor(&hybrid_model, 0, nullptr); ASSERT_EQ(executor.Init(), SUCCESS); auto &context = executor.context_; - GraphItem graph_item; - SubgraphExecutor subgraph_executor(&graph_item, &context); HybridModelExecutor::ExecuteArgs args; std::pair> eof_entry; eof_entry.first = nullptr; context.callback_manager->callback_queue_.Push(eof_entry); - ASSERT_EQ(executor.ExecuteGraphInternal(subgraph_executor, args), SUCCESS); + ASSERT_EQ(executor.Execute(args), SUCCESS); +} + +TEST_F(UtestHybridModelAsyncExecutor, test_PrepareInputs) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = make_shared(graph); + ge_root_model->SetModelName("test_name"); + GeModelPtr ge_sub_model = make_shared(); + HybridModel hybrid_model(ge_root_model); + HybridModelAsyncExecutor executor(&hybrid_model); + GeTensorDescPtr tensor_desc = make_shared(GeShape({-1, 16, 16, 3})); + tensor_desc->SetShapeRange({{1, 256}, {16, 16}, {16, 16}, {3, 3}}); + executor.input_tensor_desc_.insert({0, tensor_desc}); + executor.device_id_ = 0; + executor.input_sizes_.insert({0, -1}); + executor.is_input_dynamic_.push_back(true); + + unique_ptr data_buf(new (std::nothrow)uint8_t[3072]); + InputData input_data; + input_data.blobs.push_back(DataBuffer(data_buf.get(), 3072, false)); + input_data.shapes.push_back({1, 16, 16, 3}); + HybridModelExecutor::ExecuteArgs args; + + auto ret = executor.PrepareInputs(input_data, args); + ASSERT_EQ(ret, SUCCESS); + 
ASSERT_EQ(args.input_desc[0]->GetShape().ToString(), GeShape({1, 16, 16, 3}).ToString()); + int64_t tensor_size = 0; + TensorUtils::GetSize(*(args.input_desc[0]), tensor_size); + ASSERT_EQ(tensor_size, 3104); } } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index 2dc3b639..827705ae 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -249,6 +249,9 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { graph_context.callback_manager = std::unique_ptr(new CallbackManager()); ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); + auto root_graph = hybrid_model.root_graph_; + switch_t = root_graph->FindNode("switch_t"); + switch_f = root_graph->FindNode("switch_f"); const auto node_it_t = hybrid_model.node_items_.find(switch_t); const auto node_it_f = hybrid_model.node_items_.find(switch_f); ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); diff --git a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc index 07022230..96641c59 100644 --- a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc +++ b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc @@ -27,6 +27,7 @@ #include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/executor/worker/execution_engine.h" #include "hybrid/executor/subgraph_executor.h" +#include "hybrid/executor/worker/task_compile_engine.h" #undef private #undef protected @@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test { }; namespace { const int kIntBase = 10; +class CompileNodeExecutor : public NodeExecutor { + public: + Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const override { + return SUCCESS; + } +}; } + static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { auto op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); @@ -83,18 +91,14 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { execution_context.profiling_level = 1; SubgraphContext subgraph_context(nullptr, &execution_context); - NodeState node_state(*node_item, &subgraph_context); - auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); - auto shared_task_context = std::shared_ptr(task_context.release()); - node_state.SetTaskContext(shared_task_context); - - ExecutionEngine execution_engine; - ASSERT_TRUE(node_state.GetTaskContext() != nullptr); + auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); + ASSERT_TRUE(node_state->GetTaskContext() != nullptr); std::function callback; SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); - executor.InitCallback(&node_state, callback); - EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); + executor.InitCallback(node_state.get(), callback); + ExecutionEngine execution_engine; + EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); } TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { @@ -118,21 +122,22 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { execution_context.model = &hybrid_model; SubgraphContext subgraph_context(nullptr, 
&execution_context); - NodeState node_state(*node_item, &subgraph_context); - auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); + auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); uint32_t task_id = 0; uint32_t stream_id = 1; std::string task_type = "rts"; uint32_t block_dim = 0; - task_context->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); - auto shared_task_context = std::shared_ptr(task_context.release()); - node_state.SetTaskContext(shared_task_context); + node_state->GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim, op_desc->GetType()); - ExecutionEngine execution_engine; - ASSERT_TRUE(node_state.GetTaskContext() != nullptr); + ASSERT_TRUE(node_state->GetTaskContext() != nullptr); std::function callback; SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); - executor.InitCallback(&node_state, callback); - EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); + executor.InitCallback(node_state.get(), callback); + ExecutionEngine execution_engine; + EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); + + CompileNodeExecutor node_executor; + node_item->node_executor = &node_executor; + EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS); } diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 7a2a5dfe..782a06d6 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -25,8 +25,8 @@ #include "hybrid/model/hybrid_model_builder.h" #include "hybrid/model/hybrid_model.h" #include "hybrid/node_executor/node_executor.h" -#include "model/ge_model.h" -#include "model/ge_root_model.h" +#include "common/model/ge_model.h" +#include "common/model/ge_root_model.h" #include "hybrid/node_executor/aicore/aicore_op_task.h" #include "framework/common/taskdown_common.h" #include "framework/common/debug/log.h" @@ -34,20 +34,20 @@ #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/node_executor/aicore/aicore_task_builder.h" +#include "hybrid/node_executor/aicore/aicore_node_executor.h" #include "graph/load/model_manager/tbe_handle_store.h" #include "graph/manager/graph_mem_allocator.h" #include "hybrid/common/npu_memory_allocator.h" #include "graph/types.h" #include "graph/utils/tensor_utils.h" #include "graph/testcase/ge_graph/graph_builder_utils.h" -#undef private -#undef protected +#include "single_op/task/build_task_utils.h" +#include "graph/op_desc_impl.h" using namespace std; -using namespace testing; -using namespace ge; -using namespace hybrid; +namespace ge { +using namespace hybrid; class UtestGeHybrid : public testing::Test { protected: @@ -58,16 +58,30 @@ class UtestGeHybrid : public testing::Test { } }; -static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { +static ge::OpDescPtr CreateOpDesc(string name = "", string type = "", int in_num = 0, int out_num = 0) { auto op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); - op_desc->SetId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; ++i) { + op_desc->AddInputDesc(tensor); + 
input_offset.emplace_back(index * 64 + i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; ++i) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); op_desc->SetWorkspace({}); - ; op_desc->SetWorkspaceBytes({}); - op_desc->SetInputOffset({}); - op_desc->SetOutputOffset({}); ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF_AIVEC"); bool support_dynamic = true; @@ -99,7 +113,7 @@ TEST_F(UtestGeHybrid, aicore_op_task_init_success) { op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); std::string kernel_name("kernel/Add"); AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); - ASSERT_EQ(aicore_task->InitWithTaskDef(*op_desc.get(), task_def), SUCCESS); + ASSERT_EQ(aicore_task->Init(*op_desc.get(), task_def), SUCCESS); rtStream_t stream = nullptr; rtStreamCreate(&stream, 0); ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); @@ -159,11 +173,9 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { GraphExecutionContext execution_context; SubgraphContext subgraph_context(nullptr, &execution_context); - NodeState node_state(*node_item, &subgraph_context); - auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); - ASSERT_TRUE(task_context != nullptr); + auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); - ASSERT_EQ(aicore_task->UpdateTilingInfo(*task_context), SUCCESS); + ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS); } TEST_F(UtestGeHybrid, index_taskdefs_failed) { @@ -330,6 +342,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { ComputeGraphPtr compute_graph = MakeShared("abc"); GeRootModelPtr root_model = MakeShared(compute_graph); HybridModel model(root_model); + model.root_graph_item_.reset(new GraphItem); HybridModel *model_ptr = &model; uint32_t device_id = 0; @@ -412,49 +425,84 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) { } TEST_F(UtestGeHybrid, unfold_subgraphs_success) { - ComputeGraphPtr merged_graph = nullptr; + ComputeGraphPtr root_graph = std::make_shared("root_graph"); + auto partitioned_call_op_desc = CreateOpDesc("partitioned_call", PARTITIONEDCALL, 3, 1); + auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); + partitioned_call_op_desc->AddSubgraphName("f"); + partitioned_call_op_desc->SetSubgraphInstanceName(0, "sub_graph"); ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); - OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); - NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); + { + OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); + NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); + sub_sub_graph1->SetParentGraph(root_graph); + root_graph->AddSubGraph(sub_sub_graph1); + } ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); - /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); - NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ - OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); - NodePtr sub_sub_graph_while_body_data_node = 
sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); - sub_sub_graph2->SetGraphUnknownFlag(true); - /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); - NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); - sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); - sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ + { + OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); + NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); + sub_sub_graph2->SetGraphUnknownFlag(true); + sub_sub_graph2->SetParentGraph(root_graph); + root_graph->AddSubGraph(sub_sub_graph2); + } + // Will unfold to merged_graph. ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); - OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); - NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_graph->SetGraphUnknownFlag(true); - sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); - sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); - sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); - sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); + { + OpDescPtr sub_graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1); + OpDescPtr sub_graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1); + OpDescPtr sub_graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1); + NodePtr sub_graph_data1_node = sub_graph->AddNode(sub_graph_data1_op_desc); + NodePtr sub_graph_data2_node = sub_graph->AddNode(sub_graph_data2_op_desc); + NodePtr sub_graph_data3_node = sub_graph->AddNode(sub_graph_data3_op_desc); + + AttrUtils::SetInt(sub_graph_data1_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 0); + AttrUtils::SetInt(sub_graph_data2_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 1); + AttrUtils::SetInt(sub_graph_data3_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 2); + + OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE, 2, 2); + NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); + sub_sub_graph1->SetParentNode(sub_graph_while_node); + sub_sub_graph2->SetParentNode(sub_graph_while_node); + sub_graph_while_op_desc->AddSubgraphName("while_cond"); + sub_graph_while_op_desc->SetSubgraphInstanceName(0, "while_cond"); + sub_graph_while_op_desc->AddSubgraphName("while_body"); + sub_graph_while_op_desc->SetSubgraphInstanceName(1, "while_body"); + + OpDescPtr sub_graph_matmul_op_desc = CreateOpDesc("matmul", MATMUL, 2, 1); + NodePtr sub_graph_matmul_node = sub_graph->AddNode(sub_graph_matmul_op_desc); + + OpDescPtr sub_graph_output_op_desc = CreateOpDesc("output", NETOUTPUT, 1, 1); + NodePtr sub_graph_output_node = sub_graph->AddNode(sub_graph_output_op_desc); + + GraphUtils::AddEdge(sub_graph_data1_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(sub_graph_data2_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(1)); + GraphUtils::AddEdge(sub_graph_data3_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(sub_graph_while_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(1)); + GraphUtils::AddEdge(sub_graph_matmul_node->GetOutDataAnchor(0), sub_graph_output_node->GetInDataAnchor(0)); + + sub_graph->SetGraphUnknownFlag(true); + 
sub_graph->SetParentNode(partitioned_call_node); + sub_graph->SetParentGraph(root_graph); + root_graph->AddSubGraph(sub_graph); + } - ComputeGraphPtr root_graph = std::make_shared("root_graph"); - auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); - auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); - partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); - partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); - - root_graph->AddSubGraph(sub_sub_graph1); - root_graph->AddSubGraph(sub_sub_graph2); - sub_sub_graph1->SetParentGraph(root_graph); - sub_sub_graph2->SetParentGraph(root_graph); - sub_sub_graph1->SetParentNode(sub_graph_while_node); - sub_sub_graph2->SetParentNode(sub_graph_while_node); - - root_graph->AddSubGraph(sub_graph); - sub_graph->SetParentNode(partitioned_call_node); - sub_graph->SetParentGraph(root_graph); + OpDescPtr graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1); + OpDescPtr graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1); + OpDescPtr graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1); + NodePtr graph_data1_node = root_graph->AddNode(graph_data1_op_desc); + NodePtr graph_data2_node = root_graph->AddNode(graph_data2_op_desc); + NodePtr graph_data3_node = root_graph->AddNode(graph_data3_op_desc); + AttrUtils::SetInt(graph_data1_op_desc, ATTR_NAME_INDEX, 0); + AttrUtils::SetInt(graph_data2_op_desc, ATTR_NAME_INDEX, 1); + AttrUtils::SetInt(graph_data3_op_desc, ATTR_NAME_INDEX, 2); + GraphUtils::AddEdge(graph_data1_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(graph_data2_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(1)); + GraphUtils::AddEdge(graph_data3_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(2)); + ComputeGraphPtr merged_graph = nullptr; GeRootModelPtr root_model = MakeShared(root_graph); HybridModel hybrid_model(root_model); HybridModelBuilder hybrid_model_builder(hybrid_model); @@ -476,12 +524,14 @@ TEST_F(UtestGeHybrid, TestTaskContext) { node_item->output_start = 0; GraphExecutionContext execution_context; - SubgraphContext subgraph_context(nullptr, &execution_context); + GraphItem graph_item; + SubgraphContext subgraph_context(&graph_item, &execution_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); subgraph_context.all_inputs_.resize(2); subgraph_context.all_outputs_.resize(1); - NodeState node_state(*node_item, &subgraph_context); - auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); + auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); + auto task_context = node_state->GetTaskContext(); ASSERT_TRUE(task_context != nullptr); auto desc = task_context->MutableInputDesc(2); ASSERT_TRUE(desc == nullptr); @@ -521,12 +571,14 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) { node_item->output_start = 0; GraphExecutionContext execution_context; - SubgraphContext subgraph_context(nullptr, &execution_context); + GraphItem graph_item; + SubgraphContext subgraph_context(&graph_item, &execution_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); subgraph_context.all_inputs_.resize(2); subgraph_context.all_outputs_.resize(1); - NodeState node_state(*node_item, &subgraph_context); - auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); + auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); + auto task_context = 
node_state->GetTaskContext(); int32_t buffer[1]; aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); @@ -672,6 +724,15 @@ TEST_F(UtestGeHybrid, test_key_for_kernel_bin) { EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname"); } +TEST_F(UtestGeHybrid, test_op_type) { + auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); + aicore_task->op_type_ = "Add"; + EXPECT_EQ(aicore_task->GetOpType(), "Add"); + + auto atomic_task = std::unique_ptr(new(std::nothrow)hybrid::AtomicAddrCleanOpTask()); + EXPECT_EQ(atomic_task->GetOpType(), "DynamicAtomicAddrClean"); +} + TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) { NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl", NodeExecutorManager::ExecutorType::HCCL); @@ -736,11 +797,41 @@ TEST_F(UtestGeHybrid, TestParseDependencies) { std::vector deps; deps.push_back("Data"); auto op_desc = netoutput->GetOpDesc(); - op_desc->input_name_idx_["Data"] = 0; + op_desc->impl_->input_name_idx_["Data"] = 0; auto data_desc = data->GetOpDesc(); auto tensor = std::make_shared(); auto tensor_desc = data_desc->MutableInputDesc(0); AttrUtils::SetTensor(tensor_desc, "_value", tensor); std::set dependent_for_shape_inference; ASSERT_EQ(builder.ParseDependencies(*node_item, deps, dependent_for_shape_inference), SUCCESS); -} \ No newline at end of file +} + +TEST_F(UtestGeHybrid, TestTaskExecuteAsync) { + auto graph = make_shared("graph"); + OpDescPtr op_desc = CreateOpDesc("Add", "Add"); + GeShape shape({2, 16}); + GeTensorDesc tensor_desc(shape); + op_desc->AddInputDesc(tensor_desc); + op_desc->AddInputDesc(tensor_desc); + op_desc->AddOutputDesc(tensor_desc); + auto node = graph->AddNode(op_desc); + std::unique_ptr node_item; + NodeItem::Create(node, node_item); + node_item->input_start = 0; + node_item->output_start = 0; + + GraphExecutionContext execution_context; + GraphItem graph_item; + SubgraphContext subgraph_context(&graph_item, &execution_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + subgraph_context.all_inputs_.resize(2); + subgraph_context.all_outputs_.resize(1); + auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); + auto task_context = *node_state->GetTaskContext(); + ASSERT_NE(BuildTaskUtils::GetTaskInfo(task_context), ""); + std::unique_ptr task1(new AiCoreOpTask()); + std::vector> tasks; + AiCoreNodeTask node_task(std::move(tasks)); + ASSERT_EQ(node_task.ExecuteAsync(task_context, nullptr), SUCCESS); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/hybrid/known_node_executor_unittest.cc b/tests/ut/ge/hybrid/known_node_executor_unittest.cc index 98e985f7..b6d06f5d 100644 --- a/tests/ut/ge/hybrid/known_node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/known_node_executor_unittest.cc @@ -27,6 +27,7 @@ #undef protected #include "graph/manager/graph_mem_allocator.h" #include "../graph/passes/graph_builder_utils.h" +#include "../inc/graph/utils/graph_utils.h" using namespace std; using namespace testing; @@ -48,6 +49,34 @@ class KnownNodeTaskMock : public KnownNodeTask { }; } +static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { + auto op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + op_desc->SetId(0); + + op_desc->SetWorkspace({}); + ; + op_desc->SetWorkspaceBytes({}); + op_desc->SetInputOffset({}); + op_desc->SetOutputOffset({}); + + ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF_AIVEC"); + bool support_dynamic = true; + 
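// Note on the migration pattern applied throughout these hunks: tests no
// longer build a NodeState by hand and call TaskContext::Create(). Instead
// the SubgraphContext owns both, and must be Init()-ed before use. A minimal
// sketch of the new bootstrap sequence (names taken from the tests above):
//
//   GraphExecutionContext execution_context;
//   GraphItem graph_item;
//   SubgraphContext subgraph_context(&graph_item, &execution_context);
//   ASSERT_EQ(subgraph_context.Init(), SUCCESS);  // sets up the input/output slots
//   subgraph_context.all_inputs_.resize(2);
//   subgraph_context.all_outputs_.resize(1);
//   auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
//   auto task_context = node_state->GetTaskContext();  // lifetime tied to node_state
//   ASSERT_NE(task_context, nullptr);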
ge::AttrUtils::GetBool(op_desc, "support_dynamicshape", support_dynamic); + return op_desc; +} + +static ComputeGraphPtr BuildDataDirectConnectGraph() { + const char *kRefIndex = "_parent_node_index"; + ge::ut::GraphBuilder builder("subgraph"); + auto data = builder.AddNode("Data", "Data", 1, 1); + auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 1); + (void)AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), kRefIndex, 0); + + builder.AddDataEdge(data, 0, netoutput, 0); + return builder.GetGraph(); +} + TEST_F(UnknownNodeExecutorTest, test_init_davinci_model) { auto davinci_model = std::make_shared(0, nullptr); davinci_model->SetDeviceId(0); @@ -88,4 +117,29 @@ TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) { ASSERT_EQ(node_item.ref_outputs[1], const_node); ASSERT_EQ(node_item.reuse_inputs.size(), 1); ASSERT_EQ(node_item.reuse_inputs[0], 0); -} \ No newline at end of file +} + +TEST_F(UnknownNodeExecutorTest, TestSetGlobalStep) { + OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall"); + auto root_graph = make_shared("root_graph"); + auto node = root_graph->AddNode(op_desc); + node->SetOwnerComputeGraph(root_graph); + auto sub_graph = BuildDataDirectConnectGraph(); + sub_graph->SetParentGraph(root_graph); + sub_graph->SetParentNode(node); + node->GetOpDesc()->AddSubgraphName("subgraph"); + node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph"); + root_graph->AddSubgraph("subgraph", sub_graph); + + GeRootModelPtr ge_root_model = make_shared(root_graph); + HybridModel hybrid_model(ge_root_model); + auto *step_id = new int64_t[1]; + step_id[0] = 520; + std::unique_ptr tensor_buf; + tensor_buf = tensor_buf->Create((void *)step_id, sizeof(int64_t)); + hybrid_model.global_step_ = std::move(tensor_buf); + KnownNodeExecutor known_node_executor; + std::shared_ptr davinci_model = MakeShared(0, nullptr); + known_node_executor.SetDaviciModel(hybrid_model, node, davinci_model); + EXPECT_EQ(*(static_cast(davinci_model->global_step_addr_)), 520); +} diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index 2ab82350..eb6030dc 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -29,7 +29,7 @@ #include "graph/utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_local_context.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" using namespace std; using namespace testing; @@ -214,11 +214,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { ASSERT_EQ(it->second->frame_index_, index); ASSERT_EQ(it->second->parent_frame_, -1); }; - TestFrameGroup(enter1, control_group_index); - TestFrameGroup(active1, control_group_index); - TestFrameGroup(active2, control_group_index); - TestFrameGroup(active3, control_group_index); - TestFrameGroup(output1, -1); + auto root_graph = hybrid_model.root_graph_; + auto enter1_node = root_graph->FindNode("enter"); + auto active1_node = root_graph->FindNode("active1"); + auto active2_node = root_graph->FindNode("active2"); + auto active3_node = root_graph->FindNode("active3"); + auto output1_node = root_graph->FindNode("net_output"); + TestFrameGroup(enter1_node, control_group_index); + TestFrameGroup(active1_node, control_group_index); + TestFrameGroup(active2_node, control_group_index); + TestFrameGroup(active3_node, control_group_index); + TestFrameGroup(output1_node, -1); engine_mapping.clear(); 
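// The frame-group assertions above switch from the NodePtrs that built the
// input graph to nodes looked up in hybrid_model.root_graph_, because the
// builder now works on its own copy of the root graph (exercised by the
// copy_graph_success test below). Lookup sketch:
//
//   auto root_graph = hybrid_model.root_graph_;       // the builder's copy
//   auto enter_node = root_graph->FindNode("enter");  // returns nullptr if the name is absent
//   ASSERT_NE(enter_node, nullptr);
//   TestFrameGroup(enter_node, control_group_index);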
task_executor.clear(); @@ -346,4 +352,41 @@ EXPECT_EQ(hybrid_model_builder.InitVariableTensors(), SUCCESS); EXPECT_EQ(hybrid_model_builder.hybrid_model_.variable_tensors_.size(), 1); HostMemManager::Instance().var_memory_base_map_.clear(); } + +TEST_F(UtestHybridModelBuilder, TestInitHcclExecutorOnDemand) { + NodeExecutorManager::GetInstance().builders_.erase(NodeExecutorManager::ExecutorType::HCCL); + // build aicore task + domi::ModelTaskDef model_task_def; + std::shared_ptr model_task_def_ptr = make_shared(model_task_def); + GeModelPtr ge_model = make_shared(); + ge_model->SetModelTaskDef(model_task_def_ptr); + + // No hccl task + domi::TaskDef *task_def = model_task_def_ptr->add_task(); + task_def->set_type(RT_MODEL_TASK_MEMCPY_ASYNC); + ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS); + + // get executor failed due to no builder + task_def = model_task_def_ptr->add_task(); + task_def->set_type(RT_MODEL_TASK_HCCL); + ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), INTERNAL_ERROR); + + // get executor success + REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HCCL, NodeExecutor); + ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS); + + // repeat get, do not access builder + NodeExecutorManager::GetInstance().builders_.erase(NodeExecutorManager::ExecutorType::HCCL); + ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS); +} + +TEST_F(UtestHybridModelBuilder, copy_graph_success) { +ComputeGraphPtr graph = std::make_shared("test"); +GeRootModelPtr ge_root_model = make_shared(graph); +HybridModel hybrid_model(ge_root_model); +HybridModelBuilder hybrid_model_builder(hybrid_model); + +Status st = hybrid_model_builder.CopyGraph(); +EXPECT_EQ(st, SUCCESS); +} } // namespace ge diff --git a/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc new file mode 100644 index 00000000..034b3f47 --- /dev/null +++ b/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc @@ -0,0 +1,389 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include + +#define private public +#define protected public +#include "graph/runtime_inference_context.h" +#include "aicpu/common/aicpu_task_struct.h" +#include "hybrid/executor/subgraph_context.h" +#include "hybrid/node_executor/aicpu/aicpu_node_executor.h" +#undef protected +#undef private +#include "tests/depends/runtime/src/runtime_stub.h" +using namespace std; +using namespace testing; + +namespace { +struct AicpuTaskStruct { + aicpu::AicpuParamHead head; + uint64_t io_addrp[6]; +}__attribute__((packed)); +} // namespace + +namespace ge { +using namespace hybrid; + +class UtestAicpuNodeExecutor : public testing::Test { + protected: + void SetUp() { + RTS_STUB_SETUP(); + } + void TearDown() { + RTS_STUB_TEARDOWN(); + } +}; + +static NodePtr CreateNode(ComputeGraphPtr graph, const string &name, const string &type, int in_num, int out_num) { + OpDescPtr op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; i++) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; i++) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); + + return graph->AddNode(op_desc); +} + +TEST_F(UtestAicpuNodeExecutor, aicpu_tf_node_task) { + ComputeGraphPtr graph = std::make_shared("test"); + GeModelPtr ge_sub_model = std::make_shared(); + GeRootModelPtr ge_root_model = std::make_shared(graph); + ge_root_model->SetModelName("test_name"); + ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); + HybridModel hybrid_model(ge_root_model); + + NodePtr node = CreateNode(graph, "frameworkop", FRAMEWORK_OP_TYPE, 4, 2); + + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + hybrid_model.node_items_[node] = std::move(new_node); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_COMPUTE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 4; + graph_item.total_outputs_ = 2; + + GraphExecutionContext graph_context; + SubgraphContext subgraph_context(&graph_item, &graph_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_context.callback_manager = std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + + for (int i=0; i<4; ++i) { + uint64_t value_0 = 512; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetInput(*node_item, 0, in_tensor0); + } + + uint64_t value_0 = 512; + TensorValue out_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetOutput(*node_item, 0, out_tensor0); + + uint64_t value_1 = 512; + TensorValue out_tensor1(&value_1, sizeof(value_1)); + subgraph_context.SetOutput(*node_item, 1, out_tensor1); + + // task + domi::TaskDef task_def; + domi::KernelExDef *kernel_ex_def = task_def.mutable_kernel_ex(); + kernel_ex_def->set_kernel_ext_info_size(12); + + AicpuExtInfo aicpu_ext_info; + aicpu_ext_info.infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_SHAPE_TYPE; + aicpu_ext_info.infoLen = sizeof(int32_t); + int32_t type = 
node_item->shape_inference_type; + memcpy_s(aicpu_ext_info.infoMsg, sizeof(int32_t), &type, sizeof(int32_t)); + char *ext_mem = (char*)malloc(sizeof(AicpuExtInfo) + sizeof(int32_t)); + memcpy_s(ext_mem, sizeof(AicpuExtInfo) + sizeof(int32_t), &aicpu_ext_info, sizeof(AicpuExtInfo) + sizeof(int32_t)); + std::string ext_info(ext_mem, sizeof(AicpuExtInfo) + sizeof(int32_t)); + + std::string *mutable_ext_info = kernel_ex_def->mutable_kernel_ext_info(); + (*mutable_ext_info) = ext_info; + + hybrid_model.task_defs_[node] = std::vector({task_def, task_def}); + + AicpuTfNodeTask aicpu_tf_node_task(node_item, task_def); + + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + AicpuTaskStruct args; + args.head.length = sizeof(args); + args.head.ioAddrNum = 6; + + domi::TaskDef task_def2; + task_def2.set_type(RT_MODEL_TASK_ALL_KERNEL); + task_def2.mutable_kernel()->set_args(reinterpret_cast(&args), args.head.length); + task_def2.mutable_kernel()->set_args_size(args.head.length); + + hybrid_model.task_defs_[node] = std::vector({task_def2}); + + AicpuNodeTask aicpu_node_task(node_item, task_def); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + + //kernel_ex_def->set_allocated_kernel_ext_info(nullptr); + + free(ext_mem); + +} + +TEST_F(UtestAicpuNodeExecutor, aicpu_blocking_node_task) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = std::make_shared(graph); + ge_root_model->SetModelName("test_name"); + HybridModel hybrid_model(ge_root_model); + + NodePtr node = CreateNode(graph, "deque", FRAMEWORK_OP_TYPE, 1, 1); + ge::AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_IS_BLOCKING_OP, true); + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_SHAPE_RANGE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 1; + graph_item.total_outputs_ = 1; + + GraphExecutionContext graph_execution_context; + SubgraphContext subgraph_context(&graph_item, &graph_execution_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_execution_context.callback_manager = std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + + uint64_t value_0 = 512; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetInput(*node_item, 0, in_tensor0); + + TensorValue out_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetOutput(*node_item, 0, out_tensor0); + + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + 
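// The ext-info blob consumed by the aicpu tasks is a packed sequence of
// (AicpuExtInfo header, payload) records. Serialising the header and the
// payload separately avoids the over-read above, where memcpy_s copies
// sizeof(AicpuExtInfo) + sizeof(int32_t) bytes out of a bare AicpuExtInfo.
// Hypothetical helper, assuming the field names used in this test:
//
//   static std::string BuildExtInfo(int32_t info_type, const void *payload, size_t payload_len) {
//     AicpuExtInfo header{};
//     header.infoType = info_type;
//     header.infoLen = static_cast<uint32_t>(payload_len);
//     std::string blob(reinterpret_cast<const char *>(&header), sizeof(header));
//     blob.append(reinterpret_cast<const char *>(payload), payload_len);
//     return blob;  // hand to kernel_ex_def->mutable_kernel_ext_info()
//   }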
kernel_def.set_kernel_ext_info_size(len); + domi::TaskDef task_def; + + AicpuTaskStruct args; + args.head.length = sizeof(args); + args.head.ioAddrNum = 2; + + kernel_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_def.set_args_size(args.head.length); + domi::KernelDef *kernel_def_tmp = task_def.mutable_kernel(); + *kernel_def_tmp = kernel_def; + + AicpuNodeTask aicpu_node_task(node_item, task_def); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + node_item->shape_inference_type = DEPEND_COMPUTE; + domi::KernelExDef kernel_ex_def; + kernel_ex_def.set_kernel_ext_info(buf, len); + kernel_ex_def.set_kernel_ext_info_size(len); + kernel_ex_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_ex_def.set_args_size(args.head.length); + domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex(); + *kernel_ex_def_tmp = kernel_ex_def; + hybrid_model.task_defs_[node] = std::vector({task_def, task_def}); + + AicpuTfNodeTask aicpu_tf_node_task(node_item, task_def); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); +} + +TEST_F(UtestAicpuNodeExecutor, aicpu_blocking_node_task_fail) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = std::make_shared(graph); + ge_root_model->SetModelName("test_name"); + HybridModel hybrid_model(ge_root_model); + + NodePtr node = CreateNode(graph, "deque", FRAMEWORK_OP_TYPE, 1, 1); + ge::AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_IS_BLOCKING_OP, true); + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_SHAPE_RANGE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 1; + graph_item.total_outputs_ = 1; + + GraphExecutionContext graph_execution_context; + SubgraphContext subgraph_context(&graph_item, &graph_execution_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_execution_context.callback_manager = std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + + uint64_t value_0 = 512; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetInput(*node_item, 0, in_tensor0); + + TensorValue out_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetOutput(*node_item, 0, out_tensor0); + + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + domi::TaskDef task_def; + + AicpuTaskStruct args; + args.head.length = sizeof(args); + args.head.ioAddrNum = 2; + + 
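// Both task variants here share one args-blob layout rule: the buffer starts
// with an aicpu::AicpuParamHead whose length covers the whole struct and
// whose ioAddrNum equals inputs + outputs (1 + 1 = 2 in this test, 4 + 2 = 6
// for the AicpuTaskStruct used by the TF-kernel test above). Sketch of that
// invariant:
//
//   struct Args2 {
//     aicpu::AicpuParamHead head;
//     uint64_t io_addrs[2];  // one slot per input/output
//   } __attribute__((packed));
//
//   Args2 args{};
//   args.head.length = sizeof(args);
//   args.head.ioAddrNum = sizeof(args.io_addrs) / sizeof(args.io_addrs[0]);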
kernel_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_def.set_args_size(args.head.length); + domi::KernelDef *kernel_def_tmp = task_def.mutable_kernel(); + *kernel_def_tmp = kernel_def; + + AicpuNodeTask aicpu_node_task(node_item, task_def); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + node_item->shape_inference_type = DEPEND_COMPUTE; + domi::KernelExDef kernel_ex_def; + kernel_ex_def.set_kernel_ext_info(buf, len); + kernel_ex_def.set_kernel_ext_info_size(len); + kernel_ex_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_ex_def.set_args_size(args.head.length); + domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex(); + *kernel_ex_def_tmp = kernel_ex_def; + hybrid_model.task_defs_[node] = std::vector({task_def, task_def}); + + AicpuTfNodeTask aicpu_tf_node_task(node_item, task_def); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), FAILED); + + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); 
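// The failure cases below lean on what appears to be one-shot stub semantics:
// each RTS_STUB_RETURN_VALUE / RTS_STUB_OUTBOUND_VALUE arms only the next
// matching runtime call, after which the stub falls back to its default
// success path. That is why every negative assertion re-arms the stub right
// before the call under test, e.g.:
//
//   RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001);  // next rtGetDevice fails
//   ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED);      // consumes the injection
//   ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS);     // default behaviour again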
+ EXPECT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); +} +} // namespace ge + diff --git a/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc index a7a407a4..53b28762 100644 --- a/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc @@ -22,7 +22,7 @@ #define protected public #include "hybrid/executor/subgraph_context.h" #include "hybrid/node_executor/ge_local/ge_local_node_executor.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" #undef protected #undef private @@ -97,11 +97,6 @@ TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - NodeTaskPtr task = nullptr; GeLocalNodeExecutor node_executor; ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); diff --git a/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc index afaf067e..8e6630f6 100644 --- a/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc @@ -94,18 +94,17 @@ TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { tensor.SetData(data); ctx->SetTensor(1, 0, tensor.Clone()); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); vector addr_infos; shared_ptr task = MakeShared(); task->remote_index_ = {1, 0}; - ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); + ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID); Shape s2({1}); TensorDesc tensor_desc2(s2); Tensor tensor2(tensor_desc2); ctx->SetTensor(1, 0, tensor2.Clone()); - task->ExtractTensor(*unique_task_context, addr_infos); - ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); + task->ExtractTensor(*node_state->GetTaskContext(), addr_infos); + ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID); RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); } @@ -140,11 +139,6 @@ TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - for (int i=0; i<4; ++i) { uint64_t value_0 = 512; TensorValue in_tensor0(&value_0, sizeof(value_0)); @@ -206,11 +200,6 @@ TEST_F(UtestHcclNodeExecutor, alltoallv_execute) { auto node_state = 
subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - for (int i=0; i<5; ++i) { uint64_t value_0 = 512; TensorValue in_tensor0(&value_0, sizeof(value_0)); diff --git a/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc index b113fa9b..bb134175 100644 --- a/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc @@ -22,7 +22,7 @@ #define protected public #include "hybrid/executor/subgraph_context.h" #include "hybrid/node_executor/host_cpu/host_cpu_node_executor.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" #include "graph/passes/graph_builder_utils.h" #include "aicpu/common/aicpu_task_struct.h" #include "graph/manager/graph_mem_manager.h" diff --git a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc new file mode 100644 index 00000000..1d5bbb3d --- /dev/null +++ b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#define private public +#define protected public +#include "hybrid/node_executor/node_executor.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; + +namespace ge { +using namespace hybrid; + +namespace { + bool finalized = false; +} + +class NodeExecutorTest : public testing::Test { + protected: + void SetUp() {} + void TearDown() { } +}; + +class FailureNodeExecutor : public NodeExecutor { + public: + Status Initialize() override { + return INTERNAL_ERROR; + } +}; + +class SuccessNodeExecutor : public NodeExecutor { + public: + Status Initialize() override { + initialized = true; + finalized = false; + return SUCCESS; + } + + Status Finalize() override { + finalized = true; + } + + bool initialized = false; +}; + +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, FailureNodeExecutor); +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, SuccessNodeExecutor); + +TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) { + auto &manager = NodeExecutorManager::GetInstance(); + const NodeExecutor *executor = nullptr; + Status ret = SUCCESS; + // no builder + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::RESERVED, &executor); + ASSERT_EQ(ret, INTERNAL_ERROR); + // initialize failure + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICORE, &executor); + ASSERT_EQ(ret, INTERNAL_ERROR); + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); + ASSERT_EQ(ret, SUCCESS); + ASSERT_TRUE(executor != nullptr); + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); + ASSERT_EQ(ret, SUCCESS); + ASSERT_TRUE(executor != nullptr); + ASSERT_TRUE(((SuccessNodeExecutor*)executor)->initialized); +} + +TEST_F(NodeExecutorTest, TestInitAndFinalize) { + auto &manager = NodeExecutorManager::GetInstance(); + manager.FinalizeExecutors(); + manager.FinalizeExecutors(); + manager.EnsureInitialized(); + manager.EnsureInitialized(); + const NodeExecutor *executor = nullptr; + auto ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); + ASSERT_EQ(ret, SUCCESS); + ASSERT_TRUE(executor != nullptr); + ASSERT_TRUE(((SuccessNodeExecutor*)executor)->initialized); + manager.FinalizeExecutors(); + ASSERT_FALSE(manager.executors_.empty()); + manager.FinalizeExecutors(); + ASSERT_TRUE(manager.executors_.empty()); + ASSERT_TRUE(finalized); +} +} // namespace ge diff --git a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc index 44b2f37f..d21ae1e0 100644 --- a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc @@ -22,7 +22,7 @@ #define protected public #include "hybrid/executor/subgraph_context.h" #include "hybrid/node_executor/rts/rts_node_executor.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" using namespace std; using namespace testing; @@ -96,11 +96,6 @@ TEST_F(UtestRtsNodeTask, test_stream_switch_task) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - 
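// REGISTER_NODE_EXECUTOR_BUILDER (used in node_executor_unittest.cc above)
// binds an ExecutorType to a factory; NodeExecutorManager then calls
// Initialize() lazily on the first GetOrCreateExecutor() and Finalize() from
// FinalizeExecutors(). One wrinkle worth noting: Finalize() is declared to
// return Status, so an override must actually return a value --
// SuccessNodeExecutor::Finalize() above falls off the end of a non-void
// function. Minimal well-formed executor (hypothetical type name, GE_LOCAL
// chosen only for illustration):
//
//   class NoopNodeExecutor : public hybrid::NodeExecutor {
//    public:
//     Status Initialize() override { return SUCCESS; }
//     Status Finalize() override { return SUCCESS; }  // explicit return
//   };
//   REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::GE_LOCAL, NoopNodeExecutor);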
uint64_t value_0 = 110; uint64_t value_1 = 120; TensorValue in_tensor0(&value_0, sizeof(value_0)); @@ -153,11 +148,6 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - NodeTaskPtr task = nullptr; RtsNodeExecutor node_executor; ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); @@ -203,11 +193,6 @@ TEST_F(UtestRtsNodeTask, test_stream_merge_task) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - uint64_t value_0 = 110; TensorValue in_tensor0(&value_0, sizeof(value_0)); subgraph_context.SetInput(*node_item, 0, in_tensor0); @@ -271,11 +256,6 @@ TEST_F(UtestRtsNodeTask, test_memcpy_async_task) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - uint64_t value_0 = 110; TensorValue in_tensor0(&value_0, sizeof(value_0)); subgraph_context.SetInput(*node_item, 0, in_tensor0); @@ -328,11 +308,6 @@ TEST_F(UtestRtsNodeTask, test_pass_through_task) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - uint64_t value_0 = 110; TensorValue in_tensor0(&value_0, sizeof(value_0)); subgraph_context.SetInput(*node_item, 0, in_tensor0); @@ -384,11 +359,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_set) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - NodeTaskPtr task = nullptr; RtsNodeExecutor node_executor; ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); @@ -428,11 +398,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - NodeTaskPtr task = nullptr; RtsNodeExecutor node_executor; ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, 
task), SUCCESS); @@ -472,11 +437,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) { auto node_state = subgraph_context.GetOrCreateNodeState(node_item); ASSERT_NE(node_state, nullptr); - auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); - ASSERT_NE(unique_task_context, nullptr); - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); - NodeTaskPtr task = nullptr; RtsNodeExecutor node_executor; ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); diff --git a/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc b/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc index 9c615317..35879df8 100644 --- a/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc +++ b/tests/ut/ge/profiling/ge_profiling_manager_unittest.cc @@ -21,10 +21,16 @@ #include #include +#include "graph/load/model_manager/davinci_model.h" + #define protected public #define private public #include "common/profiling/profiling_manager.h" #include "graph/ge_local_context.h" +#include "inc/framework/common/profiling/ge_profiling.h" +#include "graph/manager/graph_manager.h" +#include "graph/ops_stub.h" +#include "inc/framework/omg/omg_inner_types.h" #undef protected #undef private @@ -42,6 +48,23 @@ int32_t ReporterCallback(uint32_t moduleId, uint32_t type, void *data, uint32_t return -1; } +void CreateGraph(Graph &graph) { + TensorDesc desc(ge::Shape({1, 3, 224, 224})); + uint32_t size = desc.GetShape().GetShapeSize(); + desc.SetSize(size); + auto data = op::Data("Data").set_attr_index(0); + data.update_input_desc_data(desc); + data.update_output_desc_out(desc); + + auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out()); + + std::vector inputs{data}; + std::vector outputs{flatten}; + std::vector targets{flatten}; + // Graph graph("test_graph"); + graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets); +} + TEST_F(UtestGeProfilinganager, init_success) { setenv("PROFILING_MODE", "true", true); Options options; @@ -115,4 +138,111 @@ TEST_F(UtestGeProfilinganager, get_fp_bp_point_empty) { ProfilingManager::Instance().GetFpBpPoint(fp_point, bp_point); EXPECT_EQ(fp_point, ""); EXPECT_EQ(bp_point, ""); -} \ No newline at end of file +} + +TEST_F(UtestGeProfilinganager, set_step_info_success) { + uint64_t index_id = 0; + auto stream = (rtStream_t)0x1; + Status ret = ProfSetStepInfo(index_id, 0, stream); + EXPECT_EQ(ret, ge::SUCCESS); + ret = ProfSetStepInfo(index_id, 1, stream); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST_F(UtestGeProfilinganager, set_step_info_failed) { + uint64_t index_id = 0; + auto stream = (rtStream_t)0x1; + Status ret = ProfSetStepInfo(index_id, 1, stream); + EXPECT_EQ(ret, ge::FAILED); +} + +TEST_F(UtestGeProfilinganager, get_device_from_graph) { + GraphId graph_id = 1; + uint32_t device_id = 0; + GraphManager graph_manager; + GraphNodePtr graph_node = MakeShared(graph_id); + graph_manager.AddGraphNode(graph_id, graph_node); + graph_manager.SetAddGraphCondition(graph_id, 2); + Graph graph("test_graph"); + CreateGraph(graph); + std::map options; + OmgContext context; + Status ret = graph_manager.AddGraph(graph_id, graph, options, context); + EXPECT_EQ(ret, ge::SUCCESS); + ret = ProfGetDeviceFormGraphId(graph_id, device_id); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST_F(UtestGeProfilinganager, handle_subscribe_info) { + ProfCommandHandleType prof_type = kProfCommandhandleModelSubscribe; + ProfCommandHandleData prof_data; + 
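// The two set_step_info cases above pin down the ProfSetStepInfo protocol:
// tag 0 opens a step and tag 1 closes it, so sending the end tag without a
// preceding begin is rejected. Usage sketch:
//
//   uint64_t index_id = 0;
//   auto stream = (rtStream_t)0x1;
//   EXPECT_EQ(ProfSetStepInfo(index_id, 0, stream), SUCCESS);  // step begin
//   EXPECT_EQ(ProfSetStepInfo(index_id, 1, stream), SUCCESS);  // matching step end
//   // issuing tag 1 first fails, as set_step_info_failed checks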
prof_data.profSwitch = 0; + prof_data.modelId = 1; + domi::GetContext().train_flag = true; + auto prof_ptr = std::make_shared(prof_data); + Status ret = ProfCommandHandle(prof_type, static_cast(prof_ptr.get()), sizeof(prof_data)); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST_F(UtestGeProfilinganager, handle_unsubscribe_info) { + ProfCommandHandleType prof_type = kProfCommandhandleModelUnsubscribe; + ProfCommandHandleData prof_data; + prof_data.profSwitch = 0; + prof_data.modelId = 1; + domi::GetContext().train_flag = true; + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.SetSubscribeInfo(0, 1, true); + auto prof_ptr = std::make_shared(prof_data); + Status ret = ProfCommandHandle(prof_type, static_cast(prof_ptr.get()), sizeof(prof_data)); + profiling_manager.CleanSubscribeInfo(); +} + +TEST_F(UtestGeProfilinganager, set_subscribe_info) { + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.SetSubscribeInfo(0, 1, true); + const auto &subInfo = profiling_manager.GetSubscribeInfo(); + EXPECT_EQ(subInfo.prof_switch, 0); + EXPECT_EQ(subInfo.graph_id, 1); + EXPECT_EQ(subInfo.is_subscribe, true); +} + +TEST_F(UtestGeProfilinganager, clean_subscribe_info) { + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.CleanSubscribeInfo(); + const auto &subInfo = profiling_manager.GetSubscribeInfo(); + EXPECT_EQ(subInfo.prof_switch, 0); + EXPECT_EQ(subInfo.graph_id, 0); + EXPECT_EQ(subInfo.is_subscribe, false); +} + +TEST_F(UtestGeProfilinganager, get_model_id_success) { + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.SetGraphIdToModelMap(0, 1); + uint32_t model_id = 0; + Status ret = profiling_manager.GetModelIdFromGraph(0, model_id); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST_F(UtestGeProfilinganager, get_model_id_failed) { + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.SetGraphIdToModelMap(0, 1); + uint32_t model_id = 0; + Status ret = profiling_manager.GetModelIdFromGraph(10, model_id); + EXPECT_EQ(ret, ge::FAILED); +} + +TEST_F(UtestGeProfilinganager, get_device_id_success) { + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.SetGraphIdToDeviceMap(0, 1); + uint32_t device_id = 0; + Status ret = profiling_manager.GetDeviceIdFromGraph(0, device_id); + EXPECT_EQ(ret, ge::SUCCESS); +} + +TEST_F(UtestGeProfilinganager, get_device_id_failed) { + auto &profiling_manager = ge::ProfilingManager::Instance(); + profiling_manager.SetGraphIdToDeviceMap(0, 1); + uint32_t device_id = 0; + Status ret = profiling_manager.GetDeviceIdFromGraph(10, device_id); + EXPECT_EQ(ret, ge::FAILED); +} diff --git a/tests/ut/ge/session/ge_api_unittest.cc b/tests/ut/ge/session/ge_api_unittest.cc index 2cabc4a3..93e6a52c 100644 --- a/tests/ut/ge/session/ge_api_unittest.cc +++ b/tests/ut/ge/session/ge_api_unittest.cc @@ -26,8 +26,6 @@ #include "proto/ge_ir.pb.h" #include "inc/external/ge/ge_api.h" #include "session/session_manager.h" -#undef protected -#undef private using namespace std; @@ -64,11 +62,121 @@ TEST_F(UtestGeApi, build_graph_success) { ASSERT_NE(ret, SUCCESS); } -TEST_F(UtestGeApi, ge_initialize) { +TEST_F(UtestGeApi, ge_initialize_modify_mixlist) { std::map options = { {ge::MODIFY_MIXLIST, "/mixlist.json"} }; auto ret = GEInitialize(options); ASSERT_NE(ret, SUCCESS); } + +TEST_F(UtestGeApi, ge_not_initialized) { + EXPECT_EQ(GEFinalize(), SUCCESS); + + std::map options; + std::map ascend_options; + Session session(options); + + GraphId 
graph_id = 1; + const auto compute_graph = MakeShared("test_graph"); + Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + EXPECT_EQ(session.AddGraph(graph_id, graph), FAILED); + EXPECT_EQ(session.AddGraph(graph_id, graph, ascend_options), FAILED); + + EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph), FAILED); + EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph, ascend_options), FAILED); + + vector inputs; + vector tensors; + EXPECT_EQ(session.BuildGraph(graph_id, inputs), FAILED); + EXPECT_EQ(session.BuildGraph(graph_id, tensors), FAILED); + + vector outputs; + EXPECT_EQ(session.RunGraph(graph_id, inputs, outputs), FAILED); + EXPECT_EQ(session.RunGraphWithStreamAsync(graph_id, nullptr, inputs, outputs), FAILED); + EXPECT_EQ(session.RunGraphAsync(graph_id, inputs, nullptr), FAILED); + + vector var_inputs; + EXPECT_EQ(session.GetVariables(var_inputs, outputs), FAILED); + + vector var_names; + EXPECT_EQ(session.GetVariables(var_names, outputs), FAILED); + + std::string key; + pCallBackFunc ge_callback; + EXPECT_EQ(session.RegisterCallBackFunc(key, ge_callback), FAILED); + + session::pCallBackFunc session_callback; + EXPECT_EQ(session.RegisterCallBackFunc(key.c_str(), session_callback), FAILED); + + EXPECT_FALSE(session.IsGraphNeedRebuild(graph_id)); + + EXPECT_EQ(session.RemoveGraph(graph_id), FAILED); + EXPECT_EQ(GEFinalize(), SUCCESS); +} + +TEST_F(UtestGeApi, ge_session_ascend_string) { + std::map options; + EXPECT_EQ(GEInitialize(options), SUCCESS); + + Session session(options); + + GraphId graph_id = 1; + const auto compute_graph = MakeShared("test_graph"); + EXPECT_EQ(session.AddGraph(graph_id, GraphUtils::CreateGraphFromComputeGraph(compute_graph)), SUCCESS); + + EXPECT_TRUE(session.IsGraphNeedRebuild(graph_id)); + + EXPECT_EQ(session.RemoveGraph(graph_id), SUCCESS); + + EXPECT_EQ(GEFinalize(), SUCCESS); +} + +TEST_F(UtestGeApi, ge_session_test) { + std::map options; + EXPECT_EQ(GEInitialize(options), SUCCESS); + + std::map ascend_options; + Session session(options); + + GraphId graph_id = 1; + const auto compute_graph = MakeShared("test_graph"); + Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + EXPECT_EQ(session.AddGraph(graph_id, graph), SUCCESS); + EXPECT_EQ(session.AddGraph(graph_id, graph, ascend_options), SUCCESS); + + EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph), FAILED); + EXPECT_EQ(session.AddGraphWithCopy(graph_id, graph, ascend_options), FAILED); + + vector inputs; + vector tensors; + EXPECT_EQ(session.BuildGraph(graph_id, inputs), FAILED); + EXPECT_EQ(session.BuildGraph(graph_id, tensors), FAILED); + + vector outputs; + EXPECT_EQ(session.RunGraph(graph_id, inputs, outputs), FAILED); + EXPECT_EQ(session.RunGraphWithStreamAsync(graph_id, nullptr, inputs, outputs), FAILED); + EXPECT_EQ(session.RunGraphAsync(graph_id, inputs, nullptr), SUCCESS); // Push to queue. 
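// Taken together, ge_not_initialized and ge_session_test document the
// GEInitialize guard: every Session entry point returns FAILED until
// GEInitialize has run, while GEFinalize is safe to call even when nothing
// was initialised. Lifecycle sketch:
//
//   std::map<std::string, std::string> options;
//   EXPECT_EQ(GEFinalize(), SUCCESS);          // no-op before initialisation
//   EXPECT_EQ(GEInitialize(options), SUCCESS);
//   {
//     Session session(options);                // scope ends before GEFinalize
//     // AddGraph / BuildGraph / RunGraph ...
//   }
//   EXPECT_EQ(GEFinalize(), SUCCESS);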
+ + vector var_inputs; + EXPECT_EQ(session.GetVariables(var_inputs, outputs), FAILED); + + vector var_names; + EXPECT_EQ(session.GetVariables(var_names, outputs), FAILED); + + std::string key; + pCallBackFunc ge_callback; + EXPECT_EQ(session.RegisterCallBackFunc(key, ge_callback), SUCCESS); + + session::pCallBackFunc session_callback; + EXPECT_EQ(session.RegisterCallBackFunc(key.c_str(), session_callback), SUCCESS); + + EXPECT_TRUE(session.IsGraphNeedRebuild(graph_id)); + + EXPECT_EQ(session.RemoveGraph(graph_id), SUCCESS); + EXPECT_EQ(GEFinalize(), SUCCESS); +} + } // namespace ge diff --git a/tests/ut/ge/session/inner_session_unittest.cc b/tests/ut/ge/session/inner_session_unittest.cc index ecad56d6..80cc2834 100644 --- a/tests/ut/ge/session/inner_session_unittest.cc +++ b/tests/ut/ge/session/inner_session_unittest.cc @@ -19,21 +19,18 @@ #define private public #define protected public #include "session/inner_session.h" -#undef private -#undef protected - using namespace std; namespace ge { -class Utest_Inner_session : public testing::Test { +class UtestInnerSession : public testing::Test { protected: void SetUp() override {} void TearDown() override {} }; -TEST_F(Utest_Inner_session, build_graph_success) { +TEST_F(UtestInnerSession, build_graph_success) { std::map options; uint64_t session_id = 1; InnerSession inner_seesion(session_id, options); @@ -44,9 +41,17 @@ TEST_F(Utest_Inner_session, build_graph_success) { EXPECT_NE(ret, ge::SUCCESS); } -TEST_F(Utest_Inner_session, initialize) { +TEST_F(UtestInnerSession, initialize) { + std::map options = {}; + uint64_t session_id = 1; + InnerSession inner_session(session_id, options); + EXPECT_EQ(inner_session.Initialize(), SUCCESS); + EXPECT_EQ(inner_session.Finalize(), SUCCESS); +} + +TEST_F(UtestInnerSession, check_op_precision_mode) { std::map options = { - {ge::MODIFY_MIXLIST, "/modify.json"} + {ge::OP_PRECISION_MODE, "./op_precision_mode.ini"} }; uint64_t session_id = 1; InnerSession inner_session(session_id, options); diff --git a/tests/ut/ge/single_op/single_op_model_unittest.cc b/tests/ut/ge/single_op/single_op_model_unittest.cc index a2c1cb02..7b7a05d8 100644 --- a/tests/ut/ge/single_op/single_op_model_unittest.cc +++ b/tests/ut/ge/single_op/single_op_model_unittest.cc @@ -17,12 +17,11 @@ #include #include +#define protected public +#define private public #include "graph/load/model_manager/model_utils.h" #include "graph/utils/graph_utils.h" #include "runtime/rt.h" - -#define protected public -#define private public #include "single_op/single_op_model.h" #include "single_op/task/tbe_task_builder.h" #include "single_op/task/rts_kernel_task_builder.h" @@ -30,14 +29,22 @@ #include "framework/common/helper/model_helper.h" #include "single_op/single_op.h" #include "single_op/stream_resource.h" +#include "graph/passes/graph_builder_utils.h" +#include "graph/op_desc_impl.h" #undef private #undef protected -#include "graph/passes/graph_builder_utils.h" using namespace std; using namespace testing; using namespace ge; +namespace { +constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; +const char *const kEngineNameAiCore = "AIcoreEngine"; +const char *const kEngineNameAiCpu = "aicpu_ascend_kernel"; +const char *const kEngineNameAiCpuTf = "aicpu_tf_kernel"; +} // namespace + class UtestSingleOpModel : public testing::Test { protected: void SetUp() {} @@ -208,11 +215,22 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { model.model_helper_.model_ = ge::MakeShared(); // make graph - auto compute_graph = make_shared("graph"); 
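// The reworked inner-session cases above cover the bare lifecycle and an
// option-validation path: Initialize() succeeds with empty options and
// Finalize() can follow immediately, while OP_PRECISION_MODE pointing at a
// missing ini file is exercised separately. Lifecycle sketch:
//
//   std::map<string, string> options;
//   InnerSession inner_session(1, options);   // session_id = 1
//   EXPECT_EQ(inner_session.Initialize(), SUCCESS);
//   EXPECT_EQ(inner_session.Finalize(), SUCCESS);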
- auto data_op = make_shared("Data", DATA); - auto data_node = compute_graph->AddNode(data_op); + ut::GraphBuilder builder = ut::GraphBuilder("graph"); + auto data = builder.AddNode("Data", "Data", 1, 1); + auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); + auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); + builder.AddDataEdge(data, 0, transdata, 0); + builder.AddDataEdge(transdata, 0, netoutput, 0); + auto compute_graph = builder.GetGraph(); + auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); model.model_helper_.model_->SetGraph(graph); + model.op_list_[0] = transdata; + + auto op_desc = transdata->GetOpDesc(); + const vector depend_names = { "Data" }; + op_desc->SetOpInferDepends(depend_names); + (void)AttrUtils::SetBool(op_desc, kAttrSupportDynamicShape, true); // set task_def auto model_task_def = make_shared(); @@ -227,6 +245,15 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { DynamicSingleOp dynamic_single_op(0, &stream_mu_, nullptr); StreamResource res((uintptr_t)1); model.BuildDynamicOp(res, dynamic_single_op); + + op_desc->impl_->input_name_idx_["Data"] = 0; + model.BuildDynamicOp(res, dynamic_single_op); + + auto tensor = std::make_shared(); + auto data_desc = data->GetOpDesc(); + auto tensor_desc = data_desc->MutableInputDesc(0); + AttrUtils::SetTensor(tensor_desc, "_value", tensor); + model.BuildDynamicOp(res, dynamic_single_op); } TEST_F(UtestSingleOpModel, test_host_mem) { @@ -287,3 +314,55 @@ TEST_F(UtestSingleOpModel, BuildTaskList) { MemcpyAsyncTask mem_task; ASSERT_EQ(mem_task.LaunchKernel(0), SUCCESS); } + +TEST_F(UtestSingleOpModel, build_dynamic_task) { + ComputeGraphPtr graph = make_shared("single_op"); + GeModelPtr ge_model = make_shared(); + ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); + shared_ptr model_task_def = make_shared(); + ge_model->SetModelTaskDef(model_task_def); + + domi::TaskDef *task_def = model_task_def->add_task(); + task_def->set_type(RT_MODEL_TASK_KERNEL_EX); + + domi::TaskDef *task_def2 = model_task_def->add_task(); + task_def2->set_type(RT_MODEL_TASK_KERNEL); + domi::KernelDef *kernel_def = task_def2->mutable_kernel(); + domi::KernelContext *context = kernel_def->mutable_context(); + context->set_kernel_type(6); // ccKernelType::AI_CPU + + domi::TaskDef *task_def3 = model_task_def->add_task(); + task_def3->set_type(RT_MODEL_TASK_ALL_KERNEL); + + domi::TaskDef *task_def4 = model_task_def->add_task(); + task_def4->set_type(RT_MODEL_TASK_KERNEL); + + string model_data_str = "dynamic_model"; + SingleOpModel model("model", model_data_str.c_str(), model_data_str.size()); + std::mutex stream_mu; + rtStream_t stream = nullptr; + rtStreamCreate(&stream, 0); + DynamicSingleOp single_op(0, &stream_mu, stream); + model.model_helper_.model_ = ge_model; + auto op_desc = std::make_shared("add", "Add"); + AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); + std::vector kernelBin; + TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); + op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); + NodePtr node = graph->AddNode(op_desc); + model.op_list_[0] = node; + StreamResource *res = new (std::nothrow) StreamResource(1); + + ASSERT_EQ(model.ParseTasks(), SUCCESS); + model.node_tasks_[node] = { *task_def3, *task_def4 }; + op_desc->SetOpKernelLibName(kEngineNameAiCore); + model.BuildTaskListForDynamicOp(res, single_op); + + model.node_tasks_[node] = { *task_def }; + op_desc->SetOpKernelLibName(kEngineNameAiCpuTf); + 
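// test_build_dynamic_op above rebuilds its graph with ut::GraphBuilder and
// then drives the value-dependent shape-inference path: SetOpInferDepends
// lists the inputs whose *values* inference needs, and the builder resolves
// each name through the op's input_name_idx_ map (poked directly via
// op_desc->impl_ here, since the map moved behind OpDescImpl). Setup sketch:
//
//   ut::GraphBuilder builder("graph");
//   auto data = builder.AddNode("Data", "Data", 1, 1);
//   auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1);
//   builder.AddDataEdge(data, 0, transdata, 0);
//   auto op_desc = transdata->GetOpDesc();
//   op_desc->SetOpInferDepends({"Data"});         // depend on the value of input "Data"
//   op_desc->impl_->input_name_idx_["Data"] = 0;  // name -> input index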
ASSERT_EQ(model.BuildTaskListForDynamicOp(res, single_op), SUCCESS); + + model.node_tasks_[node] = { *task_def2 }; + op_desc->SetOpKernelLibName(kEngineNameAiCpu); + model.BuildTaskListForDynamicOp(res, single_op); +} diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index a17c9012..52091856 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -19,6 +19,7 @@ #include "graph/load/model_manager/model_utils.h" #include "graph/utils/graph_utils.h" +#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" #include "runtime/rt.h" #define protected public @@ -30,6 +31,7 @@ #include "external/register/op_tiling_registry.h" #undef private #undef protected +#include "tests/depends/runtime/src/runtime_stub.h" using namespace std; using namespace testing; @@ -38,9 +40,13 @@ using namespace optiling; class UtestSingleOpTask : public testing::Test { protected: - void SetUp() {} + void SetUp() { + RTS_STUB_SETUP(); + } - void TearDown() {} + void TearDown() { + RTS_STUB_TEARDOWN(); + } }; TEST_F(UtestSingleOpTask, test_build_kernel_task) { @@ -54,6 +60,7 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { auto graph = make_shared("graph"); auto op_desc = make_shared("Add", "Add"); + AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF"); std::vector kernelBin; TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); @@ -91,10 +98,11 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { TbeOpTask task_tmp; TbeOpTask *task = &task_tmp; ASSERT_EQ(model.BuildKernelTask(task_def, &task), SUCCESS); + ge::DataBuffer data_buffer; vector input_desc; - vector input_buffers; + vector input_buffers = { data_buffer }; vector output_desc; - vector output_buffers; + vector output_buffers = { data_buffer }; task->node_ = node; OpTilingFunc op_tiling_func = [](const TeOpParas &, const OpCompileInfo &, OpRunInfo &) -> bool {return true;}; OpTilingRegistryInterf("Add", op_tiling_func); @@ -106,12 +114,253 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { task->max_tiling_size_ = 64; task->tiling_data_ = "tiling_data"; task->arg_size_ = 64; - uint8_t task_args{0}; - task->args_.reset(&task_args); + task->args_.reset(new (std::nothrow) uint8_t[sizeof(void *) * 3]); ASSERT_EQ(task->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_), SUCCESS); - char handle_tmp = '0'; - char *handle = &handle_tmp; + char *handle = "00"; task->SetHandle(handle); ASSERT_EQ(task->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_), SUCCESS); -} \ No newline at end of file +} + +TEST_F(UtestSingleOpTask, test_update_ioaddr) { + auto graph = make_shared("graph"); + auto op_desc = make_shared("Add", "Add"); + + GeTensorDesc desc; + op_desc->AddInputDesc(desc); + op_desc->AddInputDesc(desc); + op_desc->AddOutputDesc(desc); + vector is_input_const = { true, false }; + op_desc->SetIsInputConst(is_input_const); + auto node = graph->AddNode(op_desc); + + TbeOpTask task; + task.op_desc_ = op_desc; + task.node_ = node; + ASSERT_EQ(task.SetArgIndex(), SUCCESS); + task.arg_size_ = sizeof(void *) * 4; + task.args_.reset(new (std::nothrow) uint8_t[task.arg_size_]); + task.arg_index_ = {0}; + task.input_num_ = 2; + task.output_num_ = 1; + + vector args; + vector inputs; + vector outputs; + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); + + 
ge::DataBuffer data_buffer; + inputs = { data_buffer }; + outputs = { data_buffer }; + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), SUCCESS); + + task.tiling_buffer_ = (void *)0x0001; + task.workspaces_ = { (void *)0x0002 }; + ASSERT_EQ(task.UpdateTilingArgs(nullptr), SUCCESS); + task.tiling_buffer_ = nullptr; +} + +TEST_F(UtestSingleOpTask, test_atomic_exec) { + auto graph = make_shared("graph"); + auto op_desc = make_shared("Add", "Add"); + GeTensorDesc desc; + op_desc->AddInputDesc(desc); + op_desc->AddOutputDesc(desc); + auto node = graph->AddNode(op_desc); + AtomicAddrCleanOpTask task; + task.op_desc_ = op_desc; + task.node_ = node; + + vector inputs; + vector outputs; + std::vector atomic_output_indices; + ge::AttrUtils::SetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); + ASSERT_EQ(task.InitAtomicAddrCleanIndices(), INTERNAL_ERROR); + atomic_output_indices = { 0 }; + ge::AttrUtils::SetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); + ASSERT_EQ(task.InitAtomicAddrCleanIndices(), INTERNAL_ERROR); + task.arg_size_ = sizeof(void *) * 2; + task.args_.reset(new (std::nothrow) uint8_t[task.arg_size_]); + ASSERT_EQ(task.InitAtomicAddrCleanIndices(), SUCCESS); + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); + + ge::DataBuffer data_buffer; + outputs = { data_buffer }; + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), SUCCESS); + + task.tiling_buffer_ = (void *)0x0001; + ASSERT_EQ(task.UpdateTilingArgs(nullptr), SUCCESS); + task.tiling_buffer_ = nullptr; + + optiling::utils::OpRunInfo run_info(0, true, 0); + task.CalcTilingInfo(run_info); +} + +TEST_F(UtestSingleOpTask, test_aicpu_task_update_io_addr) { + AiCpuCCTask task; + task.num_inputs_ = 2; + task.num_outputs_ = 1; + task.input_is_const_ = {true, false}; + int total_addr = 3; + uint32_t* addrs[total_addr] = {nullptr, nullptr, nullptr}; + task.io_addr_ = reinterpret_cast(addrs); + task.io_addr_num_ = total_addr; + + { + vector inputs(1, DataBuffer()); + vector outputs(1, DataBuffer()); + auto ret = task.UpdateIoAddr(inputs, outputs); + ASSERT_EQ(ret, SUCCESS); + ASSERT_EQ(addrs[0], nullptr); + ASSERT_EQ(addrs[1], nullptr); + ASSERT_EQ(addrs[2], nullptr); + } + + { + uint32_t data_buf[2]; + vector inputs{DataBuffer(&data_buf[0], 4, false)}; + vector outputs{DataBuffer(&data_buf[1], 4, false)}; + auto ret = task.UpdateIoAddr(inputs, outputs); + ASSERT_EQ(ret, SUCCESS); + ASSERT_EQ(addrs[0], nullptr); + ASSERT_EQ(addrs[1], &data_buf[0]); + ASSERT_EQ(addrs[2], &data_buf[1]); + } + + { + uint32_t data_buf[2]; + vector inputs{DataBuffer(nullptr, 4, false)}; + vector outputs{DataBuffer(&data_buf[1], 4, false)}; + auto ret = task.UpdateIoAddr(inputs, outputs); + ASSERT_EQ(ret, PARAM_INVALID); + } + + { + uint32_t data_buf[2]; + vector inputs{DataBuffer(&data_buf[0], 4, false)}; + vector outputs{DataBuffer(nullptr, 4, false)}; + auto ret = task.UpdateIoAddr(inputs, outputs); + ASSERT_EQ(ret, PARAM_INVALID); + } +} + +TEST_F(UtestSingleOpTask, test_blocking_aicpu_op_01) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + 
async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + auto op_desc = make_shared("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + AiCpuCCTask aicpu_task; + aicpu_task.SetOpDesc(op_desc); + rtStream_t stream; + ASSERT_EQ(rtStreamCreate(&stream, 0), RT_ERROR_NONE); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); +} + +TEST_F(UtestSingleOpTask, test_blocking_aicpu_op_02) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + auto op_desc = make_shared("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + AiCpuTask aicpu_task; + aicpu_task.SetOpDesc(op_desc); + rtStream_t stream; + ASSERT_EQ(rtStreamCreate(&stream, 0), RT_ERROR_NONE); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); +} + +TEST_F(UtestSingleOpTask, test_blocking_aicpu_op_fail) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + auto op_desc = make_shared("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + AiCpuTask aicpu_task; + aicpu_task.SetOpDesc(op_desc); + rtStream_t stream; + ASSERT_EQ(rtStreamCreate(&stream, 0), RT_ERROR_NONE); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, 
RT_AICPU_BLOCKING_OP_SUPPORT + 1); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), FAILED); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), FAILED); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); +} diff --git a/tests/ut/ge/single_op/single_op_unittest.cc b/tests/ut/ge/single_op/single_op_unittest.cc index db3de7ec..181805ff 100644 --- a/tests/ut/ge/single_op/single_op_unittest.cc +++ b/tests/ut/ge/single_op/single_op_unittest.cc @@ -23,6 +23,7 @@ #define private public #include "single_op/single_op.h" #include "single_op/single_op_manager.h" +#include "single_op/task/build_task_utils.h" #undef private #undef protected @@ -103,7 +104,7 @@ TEST_F(UtestSingleOp, test_dynamic_singleop_execute_async1) { EXPECT_EQ(desc_ptr->AddInputDesc("x", GeTensorDesc(GeShape({2}), FORMAT_NCHW)), GRAPH_SUCCESS); dynamic_single_op.op_task_->op_desc_ = desc_ptr; // UpdateRunInfo failed - EXPECT_EQ(dynamic_single_op.ExecuteAsync(input_desc, input_buffers, output_desc, output_buffers), PARAM_INVALID); + EXPECT_EQ(dynamic_single_op.ExecuteAsync(input_desc, input_buffers, output_desc, output_buffers), ACL_ERROR_GE_PARAM_INVALID); } @@ -126,9 +127,19 @@ TEST_F(UtestSingleOp, test_singleop_execute_async1) { SingleOpModelParam model_params; single_op.running_param_.reset(new (std::nothrow)SingleOpModelParam(model_params)); single_op.args_.resize(1); + + auto *tbe_task = new (std::nothrow) TbeOpTask(); + ge::OpDescPtr op_desc = std::make_shared("Mul", MATMUL); + EXPECT_EQ(op_desc->AddInputDesc("x", GeTensorDesc(GeShape({2}), FORMAT_NCHW)), GRAPH_SUCCESS); + EXPECT_EQ(op_desc->AddOutputDesc("x", GeTensorDesc(GeShape({2}), FORMAT_NCHW)), GRAPH_SUCCESS); + EXPECT_NE(BuildTaskUtils::GetTaskInfo(op_desc), ""); + ge::ComputeGraphPtr graph = std::make_shared("default"); + ge::NodePtr node = graph->AddNode(op_desc); + tbe_task->node_ = node; + tbe_task->op_desc_ = op_desc; + single_op.tasks_.push_back(tbe_task); EXPECT_EQ(single_op.hybrid_model_executor_, nullptr); EXPECT_EQ(single_op.running_param_->mem_base, nullptr); - EXPECT_EQ(single_op.tasks_.size(), 0); EXPECT_EQ(single_op.ExecuteAsync(input_buffers, output_buffers), SUCCESS); } diff --git a/tests/ut/ge/single_op/stream_resource_unittest.cc b/tests/ut/ge/single_op/stream_resource_unittest.cc index e07fc39d..e4ab469e 100644 --- a/tests/ut/ge/single_op/stream_resource_unittest.cc +++ b/tests/ut/ge/single_op/stream_resource_unittest.cc @@ -66,6 +66,9 @@ TEST_F(UtestStreamResource, test_build_op) { res.op_map_[0].reset(single_op); res.dynamic_op_map_[1].reset(dynamic_single_op); + ThreadPool *thread_pool = nullptr; + 
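// The pool pointer starts out null; the call below is only checked for SUCCESS status.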
EXPECT_EQ(res.GetThreadPool(&thread_pool), SUCCESS); + EXPECT_EQ(res.GetOperator(0), nullptr); EXPECT_EQ(res.GetDynamicOperator(1), nullptr); EXPECT_EQ(res.BuildOperator(model_data, &single_op, 0), SUCCESS); diff --git a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h index df57c82e..5733d68f 100644 --- a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h +++ b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h @@ -62,6 +62,7 @@ enum FWKTaskExtInfoType { FWK_ADPT_EXT_SESSION_INFO, FWK_ADPT_EXT_BITMAP, FWK_ADPT_EXT_TOPIC_TYPE, + FWK_ADPT_EXT_ASYNCWAIT, FWK_ADPT_EXT_INVALID }; @@ -80,6 +81,12 @@ enum FWKExtUpdateAddrType { FWK_ADPT_UPDATE_INPUT_OUTPUT }; +enum FWKExtWaitType { + FWK_ADPT_WAIT_TYPE_NULL = 0, + FWK_ADPT_WAIT_TYPE_EVENT, + FWK_ADPT_WAIT_TYPE_INVALID +}; + #pragma pack(push, 1) // API Parameter Structure struct StrFWKKernel { @@ -133,6 +140,15 @@ struct ResultSummary { uint64_t raw_data_size; // size of raw data }; #pragma pack(pop) + +#pragma pack(push, 1) +struct AsyncWait { + uint8_t waitType; // wait type, FWK_ADPT_WAIT_TYPE_EVENT: event wait + uint32_t waitId; // wait id, GE refresh + uint32_t timeOut; // reserved + uint64_t reserved; +}; +#pragma pack(pop) } // end namespace FWKAdapter } // namespace aicpu diff --git a/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h b/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h index 67146dbe..9f216a56 100644 --- a/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h +++ b/third_party/fwkacllib/inc/external/runtime/rt_error_codes.h @@ -38,6 +38,7 @@ static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callba static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type +static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error @@ -50,6 +51,7 @@ static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no eve static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource +static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error @@ -85,9 +87,14 @@ static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error +static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout +static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception +static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception +static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; 
// drv internal error static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error +static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect #ifdef __cplusplus } diff --git a/third_party/fwkacllib/inc/ops/array_ops.h b/third_party/fwkacllib/inc/ops/array_ops.h index fd35b546..450c893e 100644 --- a/third_party/fwkacllib/inc/ops/array_ops.h +++ b/third_party/fwkacllib/inc/ops/array_ops.h @@ -35,7 +35,7 @@ namespace ge { * @li values:A `Tensor`. Must have the same type as `sorted_x`. \n *@par Attributes: -*@li out_type:An optional `DType` from: `int32, int64`. +*out_type:An optional `DType` from: `int32, int64`. Defaults to `int32`. \n *@par Outputs: @@ -504,7 +504,7 @@ REG_OP(Constant) *x: A tensor. \n *@par Outputs: -*y: A tensor. \n +*y: A copy of input tensor. \n *@par Third-party framework compatibility *Compatible with the TensorFlow operator Snapshot. @@ -684,7 +684,9 @@ REG_OP(ExpandDims) *@par Inputs: *@li x: Original tensor. -*@li axis: List of ints. \n + +*@par Attributes: +*@li axes: List of ints indicating the dimensions to be inserted. \n *@par Outputs: *y: Reshape tensor with same data as input. \n @@ -755,10 +757,10 @@ REG_OP(Squeeze) *@brief Returns an integer representing the rank of input tensor. The rank of a tensor is the number of indices required to uniquely select each element of the tensor, that is, the dimension size of the tensor. \n *@par Inputs: -*x: A tensor. \n +*x: A Tensor of type float32, float16, int8, int16, uint16, uint8, int32, int64, uint32, uint64, bool, double. \n *@par Outputs: -*y: A tensor. The rank of input tensor. \n +*y: A tensor. The rank of input tensor. Type is int32. \n *@par Third-party framework compatibility *Compatible with the TensorFlow operator Rank. @@ -848,7 +850,6 @@ REG_OP(PlaceHolder) *x: A tensor. \n *@par Attributes: -*@li dtype: data type of tensor. *@li shape: tensor shape. \n *@par Outputs: @@ -867,13 +868,13 @@ REG_OP(PlaceholderWithDefault) *@brief Reads and returns the value of the input variable tensor. \n *@par Inputs: -*x: A tensor. \n +*x: A tensor must have numeric type. \n *@par Attributes: *dtype: An optional int32 or int64. The output data type. Defaults to int32. \n *@par Outputs: -*y: A tensor. \n +*y: A tensor must have numeric type. \n *@par Third-party framework compatibility *Compatible with the TensorFlow operator ReadVariableOp. @@ -1134,10 +1135,10 @@ This is an M-length vector. This is an R-length vector *@par Attributes: -*@li normalize: boolean (if true, edit distances are normalized by length of truth). \n +*normalize: boolean (if true, edit distances are normalized by length of truth). \n *@par Outputs: -*@li output: A dense float tensor with rank R - 1. \n +*output: A dense float tensor with rank R - 1. \n *@par Third-party framework compatibility * Compatible with TensorFlow EditDistance operator. @@ -1154,18 +1155,17 @@ REG_OP(EditDistance) .OP_END_FACTORY_REG(EditDistance) /** -* @brief sort_v2. +* @brief sort the input tensor without returning the value of index. * @par Inputs: -* @li x: An ND tensor of type float16. +* x: An ND tensor of type float16. * @par Attributes: - * @li axis: An optional int. The dimension to sort along. This value defaults to -1. * @li descending: An optional bool. Controls the sorting order (ascending or descending). This value defaults to False. * @par Outputs: -* @li y: An ND tensor of type float16. +* y: An ND tensor of type float16. * @attention Constraints: * @li Axis should select the last dim. 
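For orientation, a minimal construction sketch for the SortV2 operator documented above, assuming the generated ge::op wrapper class that REG_OP produces (the wrapper usage and tensor names here are illustrative, not part of this patch):

    #include "array_ops.h"               // declares ge::op::SortV2 via REG_OP
    ge::op::SortV2 sort_op("sort");      // op instance named "sort"
    sort_op.set_input_x(x_tensor);       // x_tensor: ND float16 operator output (assumed)
    sort_op.set_attr_descending(true);   // optional; defaults to false (ascending)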
@@ -1206,7 +1206,7 @@ REG_OP(Expand)
 *@Returns a tensor containing the indices of all non-zero elements of input. \n

 *@par Inputs:
-*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int64.
+*x: A Tensor. Must be one of the following types: float16, float32, int32, int64.

 *@par Attributes:
 * transpose: the output tensor will be transposed if true. \n
@@ -1230,15 +1230,15 @@ REG_OP(NonZero)

 * @par Inputs:
 * One inputs, including:
-* @li x: A Tensor. Must be one of the following types:
+* x: A Tensor. Must be one of the following types:
 * float16, float32, int32, int8 ,uint8. \n

 * @par Attributes:
-* @li shape: A required listInt to specify the shape that the input tensor expanded to. \n
+* shape: A required listInt to specify the shape that the input tensor is expanded to. \n

 * @par Outputs:
-* @li y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape \n
+* y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape. \n

 * @par Third-party framework compatibility
 * Compatible with the ONNX operator Expand.
@@ -1249,6 +1249,38 @@ REG_OP(ExpandD)
     .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
     .REQUIRED_ATTR(shape, ListInt)
     .OP_END_FACTORY_REG(ExpandD)
+
+/**
+*@brief Finds unique elements in a 1D tensor. \n
+
+*@par Inputs:
+*x: 1D tensor. Must be one of the following types:
+* float16, float32, double, int64, int32, int16, uint16, int8, uint8. \n
+
+*@par Attributes:
+*@li return_inverse: Whether to also return the indices for where elements in the original
+* input ended up in the returned unique list.
+*@li return_counts: Whether to also return the counts for each unique element.
+
+*@par Outputs:
+*@li y1: The output list of unique scalar elements. Has the same type as "x".
+*@li y2: Representing the indices for where elements in the original input map to in the output.
+*@li y3: Representing the number of occurrences for each unique value or tensor. \n
+
+* @par Third-party framework compatibility
+* Compatible with the torch operator _unique2.
+*/
+
+REG_OP(UniqueWithCountsAndSorting)
+    .INPUT(x, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \
+        DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE }))
+    .OUTPUT(y1, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \
+        DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE }))
+    .OUTPUT(y2, TensorType({ DT_INT32, DT_INT64 }))
+    .OUTPUT(y3, TensorType({ DT_INT32, DT_INT64 }))
+    .ATTR(return_inverse, Bool, false)
+    .ATTR(return_counts, Bool, false)
+    .OP_END_FACTORY_REG(UniqueWithCountsAndSorting)
 } // namespace ge

 #endif // OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_
diff --git a/third_party/fwkacllib/inc/ops/control_flow_ops.h b/third_party/fwkacllib/inc/ops/control_flow_ops.h
index e5bd3534..cd993599 100644
--- a/third_party/fwkacllib/inc/ops/control_flow_ops.h
+++ b/third_party/fwkacllib/inc/ops/control_flow_ops.h
@@ -96,7 +96,7 @@ REG_OP(RefMerge)
 * Otherwise, the data is forwarded to "output_false" . \n

 *@par Inputs:
- *@li data: The tensor to be forwarded. \n
+ *@li data: The tensor to be forwarded.
 * Must be one of the following types: float16, float32, float64,
 * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
 *@li pred: A boolean scalar. The output port that will receive data . \n
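To illustrate the forwarding semantics documented above, a hypothetical wiring sketch in the same generated-wrapper style (the op and tensor names are illustrative; the concrete class is whatever REG_OP emits for this op):

    ge::op::Switch sw("switch");       // Switch-family op: data + pred inputs (assumed wrapper)
    sw.set_input_data(data_tensor);    // any of the listed numeric/bool types
    sw.set_input_pred(pred_scalar);    // boolean scalar selects the live port
    // consumers attach to exactly one of "output_true" / "output_false"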
diff --git a/third_party/fwkacllib/inc/ops/ctc_ops.h b/third_party/fwkacllib/inc/ops/ctc_ops.h
index e907b828..bbc610ff 100644
--- a/third_party/fwkacllib/inc/ops/ctc_ops.h
+++ b/third_party/fwkacllib/inc/ops/ctc_ops.h
@@ -74,7 +74,7 @@ REG_OP(CTCLoss)
 *@li sequence_length: A vector containing sequence lengths, size `(batch_size)`. \n

 *@par Attributes:
-*@li merge_repeated: If True, merge repeated classes in output. \n
+* merge_repeated: If True, merge repeated classes in output. \n

 *@par Outputs:
 *@li decoded_indices: Indices matrix, size `(total_decoded_outputs x 2)`,
@@ -108,6 +108,8 @@ REG_OP(CTCGreedyDecoder)

 *@par Attributes:
 *@li merge_repeated: If True, merge repeated classes in output. \n
+*@li beam_width: A scalar >= 0 (beam search beam width).
+*@li top_paths: A scalar >= 0, <= beam_width (controls output size).

 *@par Outputs:
 *@li decoded_indices: A list (length: top_paths) of indices matrices. Matrix j,
@@ -162,7 +164,7 @@ REG_OP(CTCBeamSearchDecoder)
 * Compatible with Pytorch CTCLoss operator.

 *@par Restrictions:
-*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*The length of Label should be in [4, 1000].
 */
 REG_OP(CTCLossV2)
     .INPUT(log_probs, TensorType({DT_FLOAT, DT_DOUBLE}))
@@ -203,7 +205,7 @@ REG_OP(CTCLossV2)
 * Compatible with Pytorch CTCLoss operator.

 *@par Restrictions:
-*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*The length of Label is limited to 1K (1000).
 */
 REG_OP(CTCLossV2Grad)
     .INPUT(grad_out, TensorType({DT_FLOAT, DT_DOUBLE}))
diff --git a/third_party/fwkacllib/inc/ops/data_flow_ops.h b/third_party/fwkacllib/inc/ops/data_flow_ops.h
index 6021f4e3..32454d27 100644
--- a/third_party/fwkacllib/inc/ops/data_flow_ops.h
+++ b/third_party/fwkacllib/inc/ops/data_flow_ops.h
@@ -1201,6 +1201,8 @@ REG_OP(TensorArraySize)
 *@brief A queue implementation that dequeues elements in a random order. \n

 *@par Attributes:
+*@li component_types: A list of fully-defined TensorType objects with
+the same length as shapes, or None.
 *@li shapes: (Optional.) A list of fully-defined TensorShape objects with
 the same length as dtypes, or None.
 *@li capacity: An integer. The upper bound on the number of elements that may
@@ -1281,6 +1283,7 @@ The length of this attr must be either 0 or the same as the length of
 elements are not constrained, and only one element may be dequeued at a time.
 *@li container: An optional string. Defaults to "". If non-empty, this queue
 is placed in the given container. Otherwise, a default container is used.
+*@li capacity: An integer. The upper bound on the number of elements that may be stored in this queue.
 *@li shared_name: An optional string. Defaults to "". If non-empty, this queue
 will be shared under the given name across multiple sessions. \n
@@ -1435,7 +1438,7 @@ REG_OP(OrderedMapClear)

 *@par Inputs:
 *Including:
-* @li resource: A Tensor of type DT_RESOURCE.
+* resource: A Tensor of type DT_RESOURCE.

 *@par Outputs:
 *handle: A Tensor of type DT_STRING ref. \n
@@ -1526,7 +1529,7 @@ REG_OP(OrderedMapPeek)

 *@par Inputs:
 *Including:
-* @li indices: A Tensor of type DT_INT32. \n
+* indices: A Tensor of type DT_INT32. \n

 *@par Attributes:
 *@li capacity: An optional int that is >= 0. Defaults to "0".
@@ -2332,6 +2335,40 @@ REG_OP(CacheAllIndexToLocal)
     .OP_END_FACTORY_REG(CacheAllIndexToLocal)

 /**
+*@brief LRUCacheV2, AI Core LRU cache.
+*@par Inputs: +*index_list: exchange index list +*data: host data +*cache: gm cache +*tag: cache's tag +*is_last_call: if is last call write all cache to data +*@par Outputs: +*data: output data +*cache: gm cache +*tag: cache's tag +*index_offset_list: index_offset_list +*not_in_cache_index_list: output not in cache's index_list +*not_in_cache_number: scalar +*@par Attributes: +*pre_route_count: types of all outputs +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(LRUCacheV2) + .INPUT(index_list, TensorType::BasicType()) + .INPUT(data, TensorType::BasicType()) + .INPUT(cache, TensorType::BasicType()) + .INPUT(tag, TensorType::BasicType()) + .INPUT(is_last_call, TensorType::BasicType()) + .OUTPUT(data, TensorType::BasicType()) + .OUTPUT(cache, TensorType::BasicType()) + .OUTPUT(tag, TensorType::BasicType()) + .OUTPUT(index_offset_list, TensorType::BasicType()) + .OUTPUT(not_in_cache_index_list, TensorType::BasicType()) + .OUTPUT(not_in_cache_number, TensorType::BasicType()) + .REQUIRED_ATTR(pre_route_count, Int) + .OP_END_FACTORY_REG(LRUCacheV2) + +/** *@brief DynamicGetNext, dynamic get next data *@par Inputs: *x: the iterator, all types are available diff --git a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h index f61e2939..b4299026 100644 --- a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h @@ -624,9 +624,9 @@ REG_OP(Log1p) *@attention Constraints: *@li x2: The input data does not support 0 -*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the +*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the *requirement of double thousandths in the mini form -*@li Due to different architectures, the calculation results of this operator +*@li Due to different architectures, the calculation results of this operator *on NPU and CPU may be inconsistent *@li If shape is expressed as (D1,D2... ,Dn), then D1*D2... *DN<=1000000,n<=8 @@ -2066,9 +2066,9 @@ REG_OP(FloorDiv) *@attention Constraints: *@li x2: The input data does not support 0 -*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the +*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the *requirement of double thousandths in the mini form -*@li Due to different architectures, the calculation results of this operator +*@li Due to different architectures, the calculation results of this operator *on NPU and CPU may be inconsistent *@li If shape is expressed as (D1,D2... ,Dn), then D1*D2... *DN<=1000000,n<=8 @@ -2200,9 +2200,9 @@ REG_OP(Tan) *@attention Constraints: *@li x2: The input data does not support 0 -*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the +*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the *requirement of double thousandths in the mini form -*@li Due to different architectures, the calculation results of this operator +*@li Due to different architectures, the calculation results of this operator *on NPU and CPU may be inconsistent *@li If shape is expressed as (D1,D2... ,Dn), then D1*D2... *DN<=1000000,n<=8 @@ -2467,11 +2467,11 @@ REG_OP(Eltwise) *@par Inputs: *One inputs, including: - * @li input_x: A tensor. Must be one of the following types: + * input_x: A tensor. Must be one of the following types: * float16, float32. \n *@par Outputs: - *y: A Tensor with the same type and shape of input_x's. 
\n + *output_y: A Tensor with the same type and shape of input_x's. \n *@par Third-party framework compatibility *Compatible with the Pytorch operator Erfinv. \n @@ -3154,13 +3154,13 @@ REG_OP(FusedMulAddNL2loss) *@brief Tests whether the input exceeds a threshold. \n *@par Inputs: -*@li x: A Tensor with any format. Must be one of the following types: float16, float32. \n +* x: A Tensor with any format. Must be one of the following types: float16, float32. \n *@par Attributes: -*@li threshold: A required float32. Defaults to "0.0". "x" is compared with "threshold", outputs "1" for inputs above threshold; "0" otherwise. \n +* threshold: A required float32. Defaults to "0.0". "x" is compared with "threshold", outputs "1" for inputs above threshold; "0" otherwise. \n *@par Outputs: -*@li y: A Tensor with any format. Has the same type as the input. Must be one of the following types: float16, float32. +* y: A Tensor with any format. Has the same type as the input. Must be one of the following types: float16, float32. *@par Third-party framework compatibility * Compatible with the Caffe operator Threshold. */ @@ -3175,7 +3175,7 @@ REG_OP(FusedMulAddNL2loss) *@brief Returns the index number corresponding to the maximum value entered. \n *@par Inputs: -*@li x: A tensor. Must be one of the following types: float16, float32. \n +*x: A tensor. Must be one of the following types: float16, float32. \n *@par Attributes: *@li axis: An optional int. Specify the axis to be cut at the input tensor. If this parameter is not provided, find the topk for each batch. Defaults to 10000 @@ -3203,12 +3203,11 @@ REG_OP(ArgMaxWithK) *@brief Multiply tensor with scale. \n *@par Inputs: -*Five inputs, including: -* @li x1: A Tensor. Must be one of the following types:int32,int16, float16, float32. -* @li x2: A scale. Must be float. \n +*One input, including: +*x: A Tensor. Must be one of the following types:int32,int16, float16, float32. *@par Outputs: -*@li y: A Tensor. Has the same type and shape as "x1". \n +*y: A Tensor. Has the same type and shape as "x1". \n *@par Third-party framework compatibility: * Compatible with the Pytorch operator muls. @@ -3223,12 +3222,11 @@ REG_OP(Muls) *@brief Fill tensor with scale. \n *@par Inputs: -*Five inputs, including: -* @li x1: A Tensor. Must be one of the following types:int32,int16, float16, float32. -* @li x2: A scale. Must be float. \n +*One input, including: +*x1: A Tensor. Must be one of the following types:int32,int16, float16, float32. *@par Outputs: -*@li y: A Tensor. Has the same type and shape as "x1". \n +*y: A Tensor. Has the same type and shape as "x1". \n *@par Third-party framework compatibility: * Compatible with the Pytorch operator fills. @@ -3378,7 +3376,7 @@ REG_OP(TensorMove) *@par Inputs: *One inputs, including: -* @li x: A Tensor. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64. \n +*x: A Tensor. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64. \n *@par Outputs: *output_x: A Tensor. Has the same type as "x". \n @@ -3397,7 +3395,7 @@ REG_OP(TensorRedirect) * multiply the result by the scalar value and add it to tensor x1 * @par Inputs: -* Three inputs, including: +* Four inputs, including: * @li input_data: A mutable input Tensor. Must be one of the following types: * float16, float32. * @li x1: A mutable input Tensor of the same type as x1. @@ -3406,7 +3404,7 @@ REG_OP(TensorRedirect) * float16, float32, int32. 
\n * @par Outputs: -* @li y: A mutable Tensor. Has the same type as "x1". \n +* y: A mutable Tensor. Has the same type as "x1". \n * @par Third-party framework compatibility * Compatible with the Pytorch operator Addcdiv. @@ -3420,12 +3418,12 @@ REG_OP(Addcdiv) .OP_END_FACTORY_REG(Addcdiv) /** -* @brief Performs the element-wise multiplication of tensor x2 by tensor x3, -* multiply the result by the scalar value and add it to tensor input_data +* @brief Performs the element-wise multiplication of tensor x2 by tensor x3, +* multiply the result by the scalar value and add it to tensor input_data * @par Inputs: -* Three inputs, including: +* Four inputs, including: * @li input_data: A mutable input Tensor. Must be one of the following types: * float16, float32, int8, int32, uint8. * @li x1: A mutable input Tensor of the same type as x1. @@ -3433,7 +3431,7 @@ REG_OP(Addcdiv) * @li value: A tensor which includes only one element of the same type as x1. \n * @par Outputs: -* @li y: A mutable output Tensor. Has the same type as "x1". \n +* y: A mutable output Tensor. Has the same type as "x1". \n * @par Third-party framework compatibility * Compatible with the Pytorch operator Addcmul. @@ -3455,7 +3453,7 @@ REG_OP(Addcmul) * @li alpha: A scalar tensor of type float16, float32. \n * @par Outputs: -* @li y: An ND tensor tensor with the same shape and type as "x1". \n +* y: An ND tensor tensor with the same shape and type as "x1". \n * @par Third-party framework compatibility * Compatible with the Pytorch operator Axpy. @@ -3468,25 +3466,6 @@ REG_OP(AxpyV2) .OP_END_FACTORY_REG(AxpyV2) /** -* @brief Computes the result of x1 - x2. - -* @par Inputs: -* @li x1: An ND tensor of type float16, float, int32. -* @li x2: An ND tensor of type float16, float, int32. \n - -* @par Outputs: -* @li y: An ND tensor tensor with the same type as "x1". \n - -* @par Third-party framework compatibility -* Compatible with the Pytorch operator Sub. -*/ -REG_OP(PtSub) - .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) - .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) - .OP_END_FACTORY_REG(PtSub) - -/** * @brief Add the partial values of two tensors in format NC1HWC0. * @par Inputs: @@ -3502,7 +3481,7 @@ REG_OP(PtSub) * the difference between C1 and offset in "x1" and "x2". \n * @par Outputs: -* @li y: A Tensor of the same type as "x1", and the same shape as "x1", +* y: A Tensor of the same type as "x1", and the same shape as "x1", * except for the C1 value. Record the result after adding. \n */ REG_OP(StrideAdd) @@ -3523,7 +3502,7 @@ REG_OP(StrideAdd) * @li input_y: A Tensor. the second tensor. \n * @par Outputs: -* @li output_z: A Tensor. Bool type, compare result of the two inputs. \n +*output_z: A Tensor. Bool type, compare result of the two inputs. \n * @par Third-party framework compatibility * Compatible with the Pytorch equal operator. \n @@ -3535,21 +3514,21 @@ REG_OP(TensorEqual) .OP_END_FACTORY_REG(TensorEqual) /** - * @brief Element-wise min of each of the input tensors (with Numpy-style broadcasting support). - * All inputs and outputs must have the same data type. This operator supports multidirectional + * @brief Element-wise min of each of the input tensors (with Numpy-style broadcasting support). + * All inputs and outputs must have the same data type. This operator supports multidirectional * (i.e., Numpy-style) broadcasting - * - * @par inputs + * + * @par Inputs: * one input including: - * @li x: dynamic input A Tensor. 
Must be one of the following types: float32, float16, double, int32, int64 - * - * @par output + * x: dynamic input A Tensor. Must be one of the following types: float32, float16, double, int32, int64 + * + * @par Outputs: * one output including: - * @li y:A Tensor of the same type as x - * + * y:A Tensor of the same type as x + * */ REG_OP(MaxN) - .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_FLOAT64, DT_INT32, DT_INT64})) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_FLOAT64, DT_INT32, DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_FLOAT64, DT_INT32, DT_INT64})) .OP_END_FACTORY_REG(MaxN) @@ -3634,16 +3613,16 @@ REG_OP(DataCompare) *which Hardmax will be performed.The output tensor has the same shape and contains the Hardmax values of the *corresponding input. * -*@par inputs +*@par Inputs: *one input including: -*@li x: input A Tensor.Must be one of the following types:float32,float16 +*x: input A Tensor.Must be one of the following types:float32,float16 * *@par Attributes: -*@li axis:A required int attribute that decides which dimension will be used to cal the hard_max +*axis:A required int attribute that decides which dimension will be used to cal the hard_max * -*@par output: +*@par Outputs: *one output including: -*@li y:A Tensor of the same type as x +*y:A Tensor of the same type as x * */ REG_OP(HardMax) @@ -3661,7 +3640,7 @@ REG_OP(HardMax) * @li input_y: A Tensor. the second tensor must be 1d. \n * @par Outputs: -* @li output: A Tensor. Result of the two inputs, must be 1d. \n +* output: A Tensor. Result of the two inputs, must be 1d. \n * @par Third-party framework compatibility * Compatible with the Pytorch dot operator. \n @@ -3671,7 +3650,7 @@ REG_OP(Dot) .INPUT(input_y, TensorType({DT_FLOAT, DT_FLOAT16, DT_UINT8, DT_INT8, DT_INT32})) .OUTPUT(output, TensorType({DT_FLOAT, DT_FLOAT16, DT_UINT8, DT_INT8, DT_INT32})) .OP_END_FACTORY_REG(Dot) - + /** *@brief Returns a new tensor with boolean elements representing \n *if each element of input is “close” to the corresponding element of other \n @@ -3719,7 +3698,7 @@ REG_OP(IsClose) * *@attention Constraints: *@li indices: only support int32,and shape same to "updates" -*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". +*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". *@li y:A Tensor, the type and shape is same to "var" \n *@par Third-party framework compatibility @@ -3754,7 +3733,7 @@ REG_OP(ArgMaxGrad) *@attention Constraints: *@li indices: only support int32,and shape same to "updates" -*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". +*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". *@li y:A Tensor, the type and shape is same to "var" \n *@par Third-party framework compatibility @@ -3805,15 +3784,15 @@ REG_OP(AddMatMatElements) *@par Inputs: *Two inputs, including: -* @li input_x1: A tensor. Must be the following types: -* float32. \n +* @li input_x1: A tensor. Must be the following types: float32. +* @li input_x2: A tensor. Must of the following types: float32. \n -*@par Inputs: -*@li input_x2: A tensor. Must of the following types: -* float32. \n +* @par Attributes: +* @li dim:The type is Int and the default value is 1. +* @li eps:The type is Float and the default value is 1e-8. \n *@par Outputs: -*@li output_y: A Tensor with the same type of input_x's. 
\n
+* output_y: A Tensor with the same type as input_x. \n

 *@par Third-party framework compatibility
 *Compatible with the Pytorch operator CosineSimilarity. \n
@@ -3826,6 +3805,45 @@ REG_OP(CosineSimilarity)
     .ATTR(eps, Float, 1e-8)
     .OP_END_FACTORY_REG(CosineSimilarity)

+/**
+*@brief Computes the Adam optimizer update. \n
+
+*@par Inputs:
+*eleven inputs, including:
+* @li var: A Tensor. Support float16/float32.\n
+* @li m: A Tensor. Datatype and shape are same as exp_avg.\n
+* @li v: A Tensor. Datatype and shape are same as exp_avg.\n
+* @li lr: A Tensor. Datatype is same as exp_avg. Shape (1, ).\n
+* @li beta1: A Tensor. Datatype is same as exp_avg. Shape (1, ).\n
+* @li beta2: A Tensor. Datatype is same as exp_avg. Shape (1, ).\n
+* @li epsilon: A Tensor. Datatype is same as exp_avg. Shape (1, ).\n
+* @li grad: A Tensor. Datatype and shape are same as exp_avg.\n
+* @li max_grad_norm: A Tensor. Datatype is same as exp_avg. Shape (1, ).\n
+* @li global_grad_norm: A Tensor. Datatype is same as exp_avg. Shape (1, ).\n
+* @li weight_decay: A Tensor. Datatype is same as exp_avg. Shape (1, ).\n
+
+*@par Outputs:
+*three outputs, including:
+* @li var: A Tensor. Datatype and shape are same as exp_avg.\n
+* @li m: A Tensor. Datatype and shape are same as exp_avg.\n
+* @li v: A Tensor. Datatype and shape are same as exp_avg.\n
+*/
+REG_OP(ApplyAdamV2)
+    .INPUT(var, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(m, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(v, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(lr, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(beta1, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(beta2, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(epsilon, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(grad, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(max_grad_norm, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(global_grad_norm, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .INPUT(weight_decay, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .OUTPUT(var, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .OUTPUT(m, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .OUTPUT(v, TensorType({ DT_FLOAT, DT_FLOAT16 }))
+    .OP_END_FACTORY_REG(ApplyAdamV2)
 } // namespace ge

 #endif // OPS_BUILT_IN_OP_PROTO_INC_ELEWISE_CALCULATION_OPS_H_
diff --git a/third_party/fwkacllib/inc/ops/functional_ops.h b/third_party/fwkacllib/inc/ops/functional_ops.h
index b09ac058..7cfe39c4 100644
--- a/third_party/fwkacllib/inc/ops/functional_ops.h
+++ b/third_party/fwkacllib/inc/ops/functional_ops.h
@@ -163,9 +163,6 @@ REG_OP(Case)
 *      if it is not a scalar, non-empty means True and empty means False.
 *@li body: A subgraph takes 'input' and returns a another list of tensors . \n

-  *@par Attributes:
-  *parallel_iterations: An optional int, default as 10 . \n
-
 *@par Outputs:
 *output: The output tensors returned by "body". Has the same type as "input" . \n
diff --git a/third_party/fwkacllib/inc/ops/image_ops.h b/third_party/fwkacllib/inc/ops/image_ops.h
index 6909345a..2327e76e 100644
--- a/third_party/fwkacllib/inc/ops/image_ops.h
+++ b/third_party/fwkacllib/inc/ops/image_ops.h
@@ -28,7 +28,7 @@ namespace ge {
 *@brief Decode the frame(s) of a GIF-encoded image to a uint8 tensor . \n

 *@par Inputs:
-*@li contents:A Tensor of type string. 0-D. The GIF-encoded image. \n
+*contents:A Tensor of type string. 0-D. The GIF-encoded image. \n

 *@par Outputs:
 *image:A Tensor of type uint8. \n
@@ -128,8 +128,8 @@ crops from the input image tensor and resizes them using bilinear sampling or
nearest neighbor sampling to a common output size specified by crop_size . 
\n *@par Inputs: -*Input images must be a 4-D tensor. Inputs include: -*@li images:A Tensor. Must be one of the following types:uint8, uint16, int8, +*Input x must be a 4-D tensor. Inputs include: +*@li x:A Tensor. Must be one of the following types:uint8, uint16, int8, int16, int32, int64, float16, float, double. A 4-D tensor of shape [batch, image_height, image_width, depth]. The format must be NHWC. *@li boxes: A Tensor of type float. A 2-D tensor of shape [num_boxes, 4]. @@ -266,8 +266,9 @@ depth] containing the original image size. Both image_height and image_width need to be positive . \n *@par Attributes: -method: A string specifying the interpolation method. Only 'bilinear' is -supported for now . \n +*@li method: A string specifying the interpolation method. Only 'bilinear' is +supported for now . +*@li T: output of type \n *@par Outputs: *y:A 4-D tensor of shape [batch, image_height, image_width, depth]. The format @@ -585,9 +586,11 @@ REG_OP(ResizeNearestNeighborV2GradD) channels], The image tensor that was resized . \n *@par Attributes: -*align_corners: An optional bool. Defaults to False. If true, the centers of +*@li align_corners: An optional bool. Defaults to False. If true, the centers of the 4 corner pixels of the input and grad tensors are aligned. Defaults to -false . \n +false . +*@li half_pixel_centers: indicates if the offset coordinates are normalized. Defaults +to false . \n *@par Outputs: *y: A Tensor. Has the same type as original_image . \n @@ -617,9 +620,10 @@ REG_OP(ResizeBilinearV2Grad) size for the images . \n *@par Attributes: -*align_corners: If true, the centers of the 4 corner pixels of the input and +* @li align_corners: If true, the centers of the 4 corner pixels of the input and output tensors are aligned, preserving the values at the corner pixels. -Defaults to false . \n +Defaults to false . +* @li half_pixel_centers: An optional bool. Defaults to False . \n *@par Outputs: *y: 4-D with shape [batch, new_height, new_width, channels] . \n @@ -684,6 +688,9 @@ be non-negative. In the case of 0, the cropped area does not need to overlap any of the bounding boxes supplied . *@li aspect_ratio_range: The cropped area of the image must have an aspect ratio = width / height within this range. +*@li area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. The +cropped area of the image must contain a fraction of the supplied image +within this range. *@li max_attempts: Number of attempts at generating a cropped region of the image of the specified constraints. After max_attempts failures, return the entire image. @@ -740,6 +747,9 @@ generator is seeded by the given seed. Otherwise, it is seeded by a random seed. *@li seed2: A second seed to avoid seed collision. *@li aspect_ratio_range: The cropped area of the image must have an aspect ratio = width / height within this range. +*@li area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. The +cropped area of the image must contain a fraction of the supplied image +within this range. *@li max_attempts: Number of attempts at generating a cropped region of the image of the specified constraints. After max_attempts failures, return the entire image. @@ -787,9 +797,10 @@ REG_OP(SampleDistortedBoundingBoxExt2) The new size for the images . \n *@par Attributes: -*align_corners: If true, the centers of the 4 corner pixels of the input and +*@li align_corners: If true, the centers of the 4 corner pixels of the input and output tensors are aligned, preserving the values at the corner pixels. 
Defaults to false . \n +*@li half_pixel_centers: An optional bool. Defaults to False . \n *@par Outputs: *y: 4-D with shape [batch, new_height, new_width, channels] . \n @@ -999,10 +1010,6 @@ deciding whether boxes overlap too. *@li score_threshold: A 0-D float tensor representing the threshold for deciding when to remove boxes based on score . \n -*@par Attributes: -*pad_to_max_output_size: If true, the output selected_indices is padded -to be of length max_output_size. Defaults to false . \n - *@par Outputs: *selected_indices: A 1-D integer tensor of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size . \n @@ -1094,8 +1101,8 @@ REG_OP(EncodePng) *contents: 0-D. PNG-decoded image . *@par Attributes: -*channels: graph channels \n -*dtype: type of image +*@li channels: graph channels \n +*@li dtype: type of image *@par Outputs: *image: is a 3-D uint8 or uint16 Tensor of shape [height, width, channels] @@ -1116,10 +1123,10 @@ REG_OP(DecodePng) *@brief Bmp-decode an image. \n *@par Inputs: -*@li contents: A Tensor of type string. 0-D. The BMP-encoded image. \n +*contents: A Tensor of type string. 0-D. The BMP-encoded image. \n *@par Attributes: -*@li channels: Decode the desired number of color channels of the image. \n +*channels: Decode the desired number of color channels of the image. \n *@par Outputs: *image: A Tensor dtype of uint8. @@ -1253,6 +1260,7 @@ REG_OP(KeepRatioResizeBilinear) No default value. *@li align_corners: An optional bool. If "true", the centers of the corner pixels of the input and output tensors are aligned. Defaults to "false" . \n +*@li half_pixel_centers: An optional bool. Defaults to False . \n *@par Outputs: *y: A Tensor with the same type and format as input "images" . \n @@ -1381,6 +1389,7 @@ REG_OP(NonMaxSuppressionV5) *@li scale: A `Tensor` of type `float32`. *@li translation: A `Tensor` of type `float32` . \n +*@par Attributes: *@li kernel_type: type is string, default lanczos3 *@li antialias: type is bool, default true \n @@ -1411,6 +1420,7 @@ REG_OP(ScaleAndTranslate) *@li scale: A `Tensor` of type `float32`. *@li translation: A `Tensor` of type `float32` . \n +*@par Attributes: *@li kernel_type: type is string, default lanczos3 *@li antialias: type is bool, default true @@ -1460,9 +1470,10 @@ if they fall beyond [0, 1]. If false, do not do clipping and output the box coordinates as it is. If not specified, defaults to true . \n *@par Outputs: -*nmsed_boxes:type is float -*nmsed_scores:type is float -*nmsed_classes:type is float \n +*@li nmsed_boxes:type is float +*@li nmsed_scores:type is float +*@li nmsed_classes:type is float +*@li valid_detections:type is INT32 \n *@par Third-party framework compatibility * Compatible with tensorflow CombinedNonMaxSuppression operator. @@ -1508,6 +1519,9 @@ REG_OP(IMGWarp) *@par Outputs: *map_img: A Tensor after resize. \n + +*@par Restrictions: +*Warning:THIS FUNCTION IS EXPERIMENTAL. Please do not use. */ REG_OP(Remap) .INPUT(img, TensorType({DT_UINT8, DT_FLOAT16, DT_FLOAT32})) @@ -1524,7 +1538,7 @@ and 4 mean input[(h_top, w_left), (h_top, w_right), (h_bottom, w_left), (h_bott *@li warp_index: the resize offset A 4-D float tensor of shape `[n, 2, h, w]`, 2 means (x, y) for resize point. *@par Outputs: -*remap_img: A Tensor after ResizeBilinear, A 4-D tensor of shape `[n, c, h, w]`. \n +*warp_img: A Tensor after ResizeBilinear, A 4-D tensor of shape `[n, c, h, w]`. 
\n */ REG_OP(IMGWarpResize) .INPUT(img, TensorType({DT_FLOAT32})) @@ -1559,6 +1573,39 @@ REG_OP(SpatialTransformerD) .OP_END_FACTORY_REG(SpatialTransformerD) /** +*@brief Function spatial transformer . \n + +*@par Inputs: +*@li x: A Tensor dtype of float16, float32, double, uint8, int8, uint16, int16, int32, uint32, uint64, int64. +*@li theta: A Tensor dtype of float16, float32, double, uint8, int8, uint16, int16, int32, uint32, uint64, int64, + auxiliary coefficients . \n + +*@par Attributes: +*@li output_size: A tuple output size. +*@li default_theta: A tuple default theta +*@li use_default_theta: List use default theta + +*@par Outputs: +*y: A Tensor dtype of float16, float32, double, uint8, int8, uint16, int16, int32, uint32, uint64, int64, + should be same shape and type as x. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(SpatialTransformer) + .INPUT(x, TensorType({DT_FLOAT,DT_FLOAT16,DT_DOUBLE,DT_UINT8,DT_INT8,DT_UINT16, + DT_INT16,DT_INT32,DT_UINT32,DT_UINT64,DT_INT64})) + .OPTIONAL_INPUT(theta, TensorType({DT_FLOAT,DT_FLOAT16,DT_DOUBLE,DT_UINT8,DT_INT8, + DT_UINT16,DT_INT16,DT_INT32,DT_UINT32,DT_UINT64,DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT,DT_FLOAT16,DT_DOUBLE,DT_UINT8,DT_INT8,DT_UINT16, + DT_INT16,DT_INT32,DT_UINT32,DT_UINT64,DT_INT64})) + .ATTR(output_size, ListInt, {-1, -1}) + .ATTR(default_theta, ListFloat, {}) + .ATTR(align_corners, Bool, false) + .ATTR(use_default_theta, ListInt, {}) + .OP_END_FACTORY_REG(SpatialTransformer) + +/** * @brief Resize the input tensor. \n currently, only support resize image tensor using nearest neighbor and linear interpolation. @@ -1623,7 +1670,7 @@ REG_OP(Resize) *@brief Function parse image from string to int. \n *@par Inputs: -*@li contents: A Tensor of type string. 0-D. The JPEG-encoded image. \n +* contents: A Tensor of type string. 0-D. The JPEG-encoded image. \n *@par Attributes: *@li channels: An optional int. Defaults to 0. Number of color channels for the decoded image. @@ -1668,7 +1715,7 @@ REG_OP(DenseImageWarp) *@par Inputs: *One inputs, including: -* @li x: A tensor. Must be one of the following types: +* x: A tensor. Must be one of the following types: * float16, float32. \n *@par Attributes: @@ -1713,7 +1760,7 @@ REG_OP(ResizeD) *@par Inputs: *One inputs, including: -* @li grads: A tensor. Must be one of the following types: +* grads: A tensor. Must be one of the following types: * float16, float32. \n *@par Attributes: @@ -1762,8 +1809,8 @@ REG_OP(ResizeGradD) *@li flow: 4-D Tensor with shape `[batch, height, width, 2]`. \n *@par Outputs: -*grad_image: Returns 4-D with the same shape and dtype as `image`. -*grad_flow: Returns 4-D with the same shape and dtype as `flow`. \n +*@li grad_image: Returns 4-D with the same shape and dtype as `image`. +*@li grad_flow: Returns 4-D with the same shape and dtype as `flow`. \n */ REG_OP(DenseImageWarpGrad) .INPUT(grad, TensorType({DT_FLOAT, DT_FLOAT16})) @@ -1817,12 +1864,12 @@ REG_OP(GridSampler2D) *@li assist: Assist matrix, a 4-D tensor of type float16. *@par Attributes: -*@li align_corners: An optional bool. If "true", the centers of the corner +*align_corners: An optional bool. If "true", the centers of the corner pixels of the input and output tensors are aligned. Defaults to "false" . *@par Outputs: -*diff: Returns 4-D Tensor with the same shape and dtype as `grid`. -*position: Returns 4-D Tensor with the same shape as `grid`. +*@li diff: Returns 4-D Tensor with the same shape and dtype as `grid`. 
+*@li position: Returns 4-D Tensor with the same shape as `grid`. */ REG_OP(GridUnnormal) .INPUT(grid, TensorType({DT_FLOAT16, DT_FLOAT})) @@ -1840,10 +1887,13 @@ REG_OP(GridUnnormal) *@li position: 4-D Tensor with shape `[batch, output_height, output_width, 2]`. *@par Attributes: -*@li padding_mode: An optional string specifying the pad method. Only 'zeros' is supported for now . +*padding_mode: An optional string specifying the pad method. Only 'zeros' is supported for now . *@par Outputs: *y: Returns 4-D Tensor with the same dtype as `x`. + +*@par Restrictions: +*Warning:THIS FUNCTION IS EXPERIMENTAL. Please do not use. */ REG_OP(ImageUnfold) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) @@ -1936,5 +1986,204 @@ REG_OP(GridSampler3DGrad) .ATTR(align_corners, Bool, false) .OP_END_FACTORY_REG(GridSampler3DGrad) +/** +*@brief Upsample the 3-D data with the nearest neighbor ​interpolation algorithm. \n + +*@par Inputs: +*One inputs, including: +*x: A 5-D input tensor [N, C, D, H, W]. Must be one of the following types: +* float16, float32, float64. \n + +*@par Attributes: +*@li output_size: An optional listInt. Defaults to none. + contain 3 elements: output_depth, output_height, output_width. The number of elements of 'output_size' + should be the same as the rank of input 'x'. Only one of 'scales' and 'output_size' can be specified. \n +*@li scales: An optional listFloat. Defaults to none. + The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width. + The number of elements of 'scales' should be the same as the rank of input 'x'. One of 'scales' and + 'output_size' MUST be specified and it is an error if both are specified. \n + +*@par Outputs: +*y: A 5-D tensor. Has the same type as input x, shape depends on x and output_size/scales. \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n +*/ + +REG_OP(UpsampleNearest3d) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .ATTR(output_size, ListInt, {}) + .ATTR(scales, ListFloat, {}) + .OP_END_FACTORY_REG(UpsampleNearest3d) + +/** +*@brief Upsample the 3-D data with the trilinear ​interpolation algorithm. \n + +*@par Inputs: +*One inputs, including: +*x: A 5-D input tensor [N, C, D, H, W]. Must be one of the following types: +* float16, float32, float64. \n + +*@par Attributes: +*@li output_size: An optional listInt. Defaults to none. + contain 3 elements: output_depth, output_height, output_width. The number of elements of 'output_size' should + be the same as the rank of input 'x'. Only one of 'scales' and 'output_size' can be specified. \n +*@li scales: An optional listFloat. Defaults to none. + The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width. + The number of elements of 'scales' should be the same as the rank of input 'x'. + One of 'scales' and 'output_size' MUST be specified and it is an error if both are specified. \n +*@li align_corners: An optional bool. Defaults to false. + If true, the input and output tensors are aligned by the center points of their corner pixels, preserving the + values at the corner pixels. If false, the input and output tensors are aligned by the corner points of their + corner pixels, and the interpolation use edge value padding for out of boundary values. \n + +*@par Outputs: +*y: A 5-D tensor. Has the same type as input x, shape depends on x and output_size/scales. 
\n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n +*/ + +REG_OP(UpsampleTrilinear3d) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .ATTR(output_size, ListInt, {}) + .ATTR(scales, ListFloat, {}) + .ATTR(align_corners, Bool, false) + .OP_END_FACTORY_REG(UpsampleTrilinear3d) + +/** +*@brief Upsample the 3-D gradient data with the nearest neighbor ​interpolation algorithm. \n + +*@par Inputs: +*One inputs, including: +*grad_output: A 5-D input tensor [N, C, D, H, W]. Must be one of the following types: +* float16, float32, float64. \n + +*@par Attributes: +*@li input_size: An required listInt. + contain 5 elements: [min_batch, channels, depth, height, width]. Must: + input_size[0] == grad_output_tensor_size[0] + input_size[1] == grad_output_tensor_size[1]. \n +*@li output_size: An optional listInt. Defaults to none. + contain 3 elements: depth, height, width. The number of elements of 'output_size' should + be the same as the rank of input 'grad_output'. Only one of 'scales' and 'output_size' can be specified. Must: + grad_output_tensor_size[2] == floor(input_size[2] * scales[0]) == output_size[0] + grad_output_tensor_size[3] == floor(input_size[3] * scales[1]) == output_size[1] + grad_output_tensor_size[4] == floor(input_size[4] * scales[2]) == output_size[2]. \n +*@li scales: An optional listFloat. Defaults to none. + The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width. + The number of elements of 'scales' should be the same as the rank of input 'grad_output'. + One of 'scales' and 'output_size' MUST be specified and it is an error if both are specified. \n + +*@par Outputs: +*y: A 5-D tensor. Has the same type as input grad_output, shape depends on Attributes:input_size. \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ + +REG_OP(UpsampleNearest3dGrad) + .INPUT(grad_output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .REQUIRED_ATTR(input_size, ListInt) + .ATTR(output_size, ListInt, {}) + .ATTR(scales, ListFloat, {}) + .OP_END_FACTORY_REG(UpsampleNearest3dGrad) + +/** +*@brief Upsample the 3-D gradient data trilinear ​interpolation algorithm. \n + +*@par Inputs: +*One inputs, including: +*grad_output: A 5-D input tensor [N, C, D, H, W]. Must be one of the following types: +* float16, float32, float64. \n + +*@par Attributes: +*@li input_size: An required listInt. + contain 5 elements: [min_batch, channels, depth, height, width]. Must: + input_size[0] == grad_output_tensor_size[0] + input_size[1] == grad_output_tensor_size[1]. \n +*@li output_size: An optional listInt. Defaults to none. + contain 3 elements: depth, height, width. The number of elements of 'output_size' should + be the same as the rank of input 'grad_output'. Only one of 'scales' and 'output_size' can be specified. Must: + grad_output_tensor_size[2] == floor(input_size[2] * scales[0]) == output_size[0] + grad_output_tensor_size[3] == floor(input_size[3] * scales[1]) == output_size[1] + grad_output_tensor_size[4] == floor(input_size[4] * scales[2]) == output_size[2]. \n +*@li scales: An optional listFloat. Defaults to none. + The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width. + The number of elements of 'scales' should be the same as the rank of input 'grad_output'. 
+
+*@par Outputs:
+*y: A Tensor whose shape depends on input_size and output_size/scales. Must be one of the following
+ types: float16, float32, float64. \n
+
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/
+
+REG_OP(UpsampleTrilinear3dGrad)
+    .INPUT(grad_output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .REQUIRED_ATTR(input_size, ListInt)
+    .ATTR(output_size, ListInt, {})
+    .ATTR(scales, ListFloat, {})
+    .ATTR(align_corners, Bool, false)
+    .OP_END_FACTORY_REG(UpsampleTrilinear3dGrad)
+
+
+/**
+*@brief Upsample the 1-D data with the nearest neighbor interpolation algorithm. \n
+
+*@par Inputs:
+*x: A 3-D input tensor [N, C, W]. Must be one of the following types:
+* float16, float32, float64. \n
+
+*@par Attributes:
+*@li output_size: A required listInt containing output_width.
+*@li scales: An optional listFloat containing scale_width. Defaults to empty. \n
+
+*@par Outputs:
+*y: A 3-D tensor. Has the same type as input x, shape depends on x and output_size/scales. \n
+
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
+*/
+
+REG_OP(UpsampleNearest1d)
+    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .REQUIRED_ATTR(output_size, ListInt)
+    .ATTR(scales, ListFloat, {})
+    .OP_END_FACTORY_REG(UpsampleNearest1d)
+
+/**
+*@brief Upsample the 1-D gradient data with the nearest neighbor interpolation algorithm. \n
+
+*@par Inputs:
+*grad_output: A 3-D input tensor [N, C, W]. Must be one of the following types:
+* float16, float32, float64. \n
+
+*@par Attributes:
+*@li output_size: A required listInt containing output_width.
+*@li scales: An optional listFloat containing scale_width. Defaults to empty.
+*@li input_size: A required listInt containing the shape of the original input [min_batch, channels, width]. \n
+
+*@par Outputs:
+*y: A 3-D tensor. Has the same type as input grad_output, shape depends on the attribute 'input_size'. \n
+
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
+*/
+
+REG_OP(UpsampleNearest1dGrad)
+    .INPUT(grad_output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .REQUIRED_ATTR(input_size, ListInt)
+    .REQUIRED_ATTR(output_size, ListInt)
+    .ATTR(scales, ListFloat, {})
+    .OP_END_FACTORY_REG(UpsampleNearest1dGrad)
 } // namespace ge

 #endif  // OPS_BUILT_IN_OP_PROTO_INC_IMAGE_OPS_H_
diff --git a/third_party/fwkacllib/inc/ops/linalg_ops.h b/third_party/fwkacllib/inc/ops/linalg_ops.h
index 69c77bf6..f6cc8694 100644
--- a/third_party/fwkacllib/inc/ops/linalg_ops.h
+++ b/third_party/fwkacllib/inc/ops/linalg_ops.h
@@ -347,6 +347,9 @@ REG_OP(SelfAdjointEig)
     .OP_END_FACTORY_REG(SelfAdjointEig)

 /**
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+
 *@brief Computes the sign and the log of the absolute value of the determinant
 of one or more square matrices . \n

@@ -382,9 +385,10 @@ REG_OP(Slogdet)
 *x:Tensor of shape [..., M, N]. Let P be the minimum of M and N . \n

 *@par Attributes:
-*compute_uv:If True then left and right singular vectors will be computed and
+*@li compute_uv:If True then left and right singular vectors will be computed and
 returned in u and v, respectively. Otherwise, only the singular values will
-be computed, which can be significantly faster . \n
+be computed, which can be significantly faster.
+*@li full_matrices: If true, compute full-sized u and v; if false, compute only the leading P singular vectors. \n

 *@par Outputs:
 *@li sigma:Singular values. Shape is [..., P]. The values are sorted in
@@ -427,6 +431,9 @@ denotes the lower triangular factor `L` with unit diagonal.
 *@li p: upper triangular part denotes the upper triangular factor `U`.Permutation
 of the rows encoded as a list of indices in `0..M-1`. Shape is `[..., M]` . \n

+*@par Attributes:
+*output_idx_type: An optional DType from: int32, int64.
+
 *@par Third-party framework compatibility
 * Compatible with TensorFlow Lu operator.
 */
@@ -467,6 +474,12 @@ left-hand side . \n
 *@par Outputs:
 y: Tensor of shape `[..., M, K]` containing the solutions \n

+*@par Attributes:
+*partial_pivoting: Whether to perform partial pivoting. `True` by default.
+Partial pivoting makes the procedure more stable, but slower. Partial
+pivoting is unnecessary in some cases, including diagonally dominant and
+symmetric positive definite matrices.
+
 *@par Third-party framework compatibility
 * Compatible with TensorFlow TridiagonalSolve operator.
 */
diff --git a/third_party/fwkacllib/inc/ops/list_ops.h b/third_party/fwkacllib/inc/ops/list_ops.h
index a1b622e9..0aa94e73 100644
--- a/third_party/fwkacllib/inc/ops/list_ops.h
+++ b/third_party/fwkacllib/inc/ops/list_ops.h
@@ -35,10 +35,10 @@ namespace ge {
 *@li max_num_elements: The maximum number of elements. \n

 *@par Attributes:
-*@li element_dtype: The type of elements in the list. \n
+*element_dtype: The type of elements in the list. \n

 *@par Outputs:
-*@li handle: An empty tensor list . \n
+*handle: An empty tensor list . \n

 *@par Third-party framework compatibility.
 *Compatible with tensorflow EmptyTensorList operator.
@@ -59,10 +59,10 @@ and the other elements of the given list in `input_handle`. \n
 *@li tensor: The tensor to put on the list. \n

 *@par Attributes:
-*@li element_dtype: The type of elements in the list. \n
+*element_dtype: The type of elements in the list. \n

 *@par Outputs:
-*@li output_handle:A list with the elements of old list followed by tensor. \n
+*output_handle: A list with the elements of the old list followed by tensor. \n

 *@par Third-party framework compatibility.
 *Compatible with tensorflow TensorListPushBack operator.
@@ -86,7 +86,7 @@ list with all but that element. \n
 *@li element_shape: A shape compatible with that of elements in the list. \n

 *@par Attributes:
-*@li element_dtype: The type of elements in the list. \n
+*element_dtype: The type of elements in the list. \n

 *@par Outputs:
 *@li output_handle:A list with the elements of the old list followed by tensor.
@@ -110,10 +110,10 @@ REG_OP(TensorListPopBack)
 *@brief The number of tensors in the input tensor list. \n

 *@par Inputs:
-*@li input_handle: The input list. \n
+*input_handle: The input list. \n

 *@par Outputs:
-*@li length:The number of tensors in the list. \n
+*length: The number of tensors in the list. \n

 *@par Third-party framework compatibility.
 *Compatible with tensorflow TensorListLength operator.
@@ -127,13 +127,13 @@ REG_OP(TensorListLength)
 *@brief The shape of elements in the input tensor list. \n

 *@par Inputs:
-*@li input_handle: The input list. \n
+*input_handle: The input list. \n

 *@par Attributes:
-*@li shape_type: The type of shape in the list. \n
+*shape_type: The type of shape in the list. \n

 *@par Outputs:
-*@li element_shape:A shape compatible with that of elements in the list. \n
+*element_shape: A shape compatible with that of elements in the list. \n

 *@par Third-party framework compatibility.
*Compatible with tensorflow TensorListElementShape operator. @@ -156,7 +156,7 @@ REG_OP(TensorListElementShape) *@li shape_type: The type of shape in the list. \n *@par Outputs: -*@li handle: An output tensor list . \n +*handle: An output tensor list . \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListReserve operator. @@ -178,10 +178,10 @@ REG_OP(TensorListReserve) *@li element_shape: A shape compatible with that of elements in the list. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li item: An output tensor value of index position . \n +*item: An output tensor value of index position . \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListGetItem operator. @@ -206,10 +206,10 @@ REG_OP(TensorListGetItem) *@li item: The element to be assigned to that position. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li output_handle: An output tensor list . \n +*output_handle: An output tensor list . \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListSetItem operator. @@ -233,10 +233,10 @@ REG_OP(TensorListSetItem) *@li tensor: The tensor push into tensor list. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li output_handles: The output tensor lists. \n +*output_handles: The output tensor lists. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListPushBackBatch operator. @@ -263,7 +263,7 @@ REG_OP(TensorListPushBackBatch) *@li num_elements: The number of elements in the list. \n *@par Outputs: -*@li tensor: The tensor of list. \n +*tensor: The tensor of list. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListStack operator. @@ -293,7 +293,7 @@ the leading dim of input_handle.element_shape or the element_shape input arg is not already set. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: *@li tensor: The concated result. @@ -324,10 +324,10 @@ REG_OP(TensorListConcatV2) *@li lengths: Vector of sizes of the 0th dimension of tensors in the list. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li output_handle: The list. \n +*output_handle: The list. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListSplit operator. @@ -351,10 +351,10 @@ REG_OP(TensorListSplit) *@li element_shape: The shape of elements in the list. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li output_handle: An output tensor list . \n +*output_handle: An output tensor list . \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListFromTensor operator. @@ -377,7 +377,7 @@ REG_OP(TensorListFromTensor) *@li size: size of the output list. \n *@par Outputs: -*@li output_handle: The output tensor list. \n +*output_handle: The output tensor list. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListResize operator. 
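Editor's note: the TensorList* ops in these hunks track TensorFlow's resource-backed list ops. Purely as a behavioral reference (GE manipulates these through a resource handle in the graph, not through a C++ container), the push-back / length / get-item semantics correspond to:

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical, host-side analogue of a tensor-list resource; the comments
// name the op each member mirrors.
template <typename Tensor>
class TensorListRef {
 public:
  void PushBack(const Tensor &t) { items_.push_back(t); }  // TensorListPushBack
  std::size_t Length() const { return items_.size(); }     // TensorListLength
  const Tensor &GetItem(std::size_t i) const {             // TensorListGetItem
    if (i >= items_.size()) throw std::out_of_range("index");
    return items_[i];
  }

 private:
  std::vector<Tensor> items_;
};
```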
@@ -397,10 +397,10 @@ REG_OP(TensorListResize) *@li element_shape: The shape of elements in the list. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li values: The tensor. \n +*values: The tensor. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListGather operator. @@ -429,10 +429,10 @@ the largest index in indices. If -1, the list is just large enough to include the largest index in indices. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li output_handle: The TensorList. \n +*output_handle: The TensorList. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListScatterV2 operator. @@ -458,10 +458,10 @@ REG_OP(TensorListScatterV2) *@li indices: The indices used to index into the list. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li output_handle: The TensorList. \n +*output_handle: The TensorList. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListScatterIntoExistingList operator. @@ -485,10 +485,10 @@ REG_OP(TensorListScatterIntoExistingList) *@li input_b: The input tensor list B. \n *@par Attributes: -*@li element_dtype: The type of elements in the list. \n +*element_dtype: The type of elements in the list. \n *@par Outputs: -*@li output: The output list. \n +*output: The output list. \n *@par Third-party framework compatibility. *Compatible with tensorflow TensorListConcatLists operator. diff --git a/third_party/fwkacllib/inc/ops/lookup_ops.h b/third_party/fwkacllib/inc/ops/lookup_ops.h index 5d928e5a..b1fc254f 100644 --- a/third_party/fwkacllib/inc/ops/lookup_ops.h +++ b/third_party/fwkacllib/inc/ops/lookup_ops.h @@ -77,8 +77,8 @@ REG_OP(LookupTableInsert) *handle: A Tensor of type resource. Handle to the table . \n *@par Attributes: -*@li Tkeys: A DType. -*@li Tvalues: A DType . \n +*@li Tkeys: A DType of keys. +*@li Tvalues: A DType of values. *@par Outputs: *@li keys: A Tensor of type Tkeys. diff --git a/third_party/fwkacllib/inc/ops/math_ops.h b/third_party/fwkacllib/inc/ops/math_ops.h index 319bcf70..6eb418d8 100644 --- a/third_party/fwkacllib/inc/ops/math_ops.h +++ b/third_party/fwkacllib/inc/ops/math_ops.h @@ -227,10 +227,10 @@ REG_OP(Bucketize) *@par Inputs: *One inputs, including: -* @li input_x: A tensor. Must be one of the following types: float16, float32, int8, uint8, int32. \n +*input_x: A tensor. Must be one of the following types: float16, float32, int8, uint8, int32. \n *@par Outputs: -*y: A tensor with the same type and shape of input_x \n +*output_y: A tensor with the same type and shape of input_x \n *@par Third-party framework compatibility *Compatible with the Pytorch operator Trunc. \n @@ -298,7 +298,7 @@ REG_OP(SparseSegmentMean) *@par Inputs: *The input grad must have be type float or double. Inputs include: -*@li grad: A Tensor. Must be one of the following types: float, double. +*@li x: A Tensor. Must be one of the following types: float, double. gradient propagated to the SparseSegmentMean op. *@li indices: A Tensor. Must be one of the following types: int32, int64. indices passed to the corresponding SparseSegmentMean op. @@ -365,6 +365,7 @@ REG_OP(InitData) component of an element of this dataset. 
 *@li output_shapes: A nested structure of TensorShape objects corresponding to each
 component of an element of this dataset.
+*@li output_num: The number of outputs.
 *@li channel_name: A string. Default "" . \n

 *@par Outputs:
@@ -538,11 +539,11 @@ REG_OP(NextAfter)

 *@par Inputs:
 *One inputs, including:
-* @li input_x: A tensor. Must be one of the following types:
+* input_x: A tensor. Must be one of the following types:
 * float16, float32. \n

 *@par Attributes:
-*@li p: An optional float.Defaults to 2. \n
+*p: An optional float. Defaults to 2. \n

 *@par Outputs:
 *y: A Tensor with the same type and shape of input_x's. \n
@@ -560,10 +561,10 @@ REG_OP(Pdist)
 *@brief Compute element-wise finiteness, return a boolean tensor.

 *@par Inputs:
- *x:A Tensor.
+ *x: A Tensor of type float16, float32, double.

 *@par Outputs:
- *y:A Tensor. Has the same shape as x.
+ *y: A Tensor. Returns which elements of x are finite.

 *@par Third-party framework compatibility.
 *Compatible with tensorflow IsFinite operator.
@@ -577,10 +578,10 @@
 *@brief Compute element-wise infiniteness, return a boolean tensor.

 *@par Inputs:
- *x:A Tensor.
+ *x: A Tensor of type float16, float32, double.

 *@par Outputs:
- *y:A Tensor. Has the same shape as x.
+ *y: A Tensor. Has the same shape as x. Returns which elements of x are infinite.

 *@par Third-party framework compatibility.
 *Compatible with tensorflow IsInf operator.
@@ -594,7 +595,11 @@
 *@brief Computes the complex absolute value of a tensor.

 *@par Inputs:
- *x:A Tensor.
+ *x: A Tensor of complex numbers. This operation returns a tensor of type
+ float or double that is the absolute value of each element in x .
+
+* @par Attributes:
+* Tout: The type of the output.

 *@par Outputs:
 *y:A tensor of type `float` or `double` that is the absolute value of
 each element in `x`.
@@ -612,10 +617,10 @@
 *@brief Returns which elements of x are NaN.

 *@par Inputs:
- *x:A Tensor.
+ *x: A Tensor of type float16, float32, double.

 *@par Outputs:
- *y:A Tensor. Has the same shape as x.
+ *y: A Tensor. Has the same shape as x. Returns which elements of x are NaN.

 *@par Third-party framework compatibility.
 *Compatible with tensorflow IsNan operator.
@@ -629,7 +634,10 @@
 *@brief Returns the real part of a complex number.

 *@par Inputs:
- *input:A Tensor.
+ *input: A Tensor. Must have numeric type.
+
+ *@par Attributes:
+ *Tout: Type of outputs. \n

 *@par Outputs:
 *output:A Tensor. Has the same shape as input.
@@ -670,7 +678,8 @@ REG_OP(Conj)
 *@li weight: A Tensor dtype of float32 . \n

 *@par Attributes:
-*reduction: An optional attribute. Defaults to "mean" . \n
+*@li reduction: An optional attribute. Defaults to "mean" .
+*@li ignore_index: An optional attribute. Defaults to -100 . \n

 *@par Outputs:
 *@li y: A Tensor dtype of float32.
@@ -700,7 +709,8 @@ REG_OP(NLLLoss)
 *@li total_weight:A Tensor dtype of float32 . \n

 *@par Attributes:
-*reduction: An optional attribute. Defaults to "mean" . \n
+*@li reduction: An optional attribute. Defaults to "mean" .
+*@li ignore_index: An optional attribute. Defaults to -100 . \n

 *@par Outputs:
 *x_grad: A Tensor. Must be the following type: float32 . \n
@@ -720,24 +730,24 @@ REG_OP(NLLLossGrad)
     .OP_END_FACTORY_REG(NLLLossGrad)

 /**
-*@brief The ifmr . \n
+*@brief IFMR (Input Feature Map Reconstruction). \n

 *@par Inputs:
-*@li data:A Tensor of feature map
-*@li data_min:A Tensor of min value of feature map.
-*@li data_max:A Tensor of max value of feature map.
-*@li cumsum:A Tensor of cumsum bin of data . \n
+*@li data: A Tensor of feature map.
+*@li data_min: A Tensor of min value of feature map.
+*@li data_max: A Tensor of max value of feature map.
+*@li cumsum: A Tensor of cumsum bin of data . \n

 *@par Attributes:
-*min_percentile: min init percentile.
-*max_percentile: max init percentile.
-*search_range: search range.
-*search_step: step size of searching.
-*with_offset: whether using offset . \n
+*@li min_percentile: min init percentile.
+*@li max_percentile: max init percentile.
+*@li search_range: search range.
+*@li search_step: step size of searching.
+*@li with_offset: whether to use offset . \n

 *@par Outputs:
-*scale: optimal scale.
-*offset: optimal offset . \n
+*@li scale: optimal scale.
+*@li offset: optimal offset . \n

 *@par Third-party framework compatibility
 *Compatible with mindspore
@@ -758,16 +768,16 @@ REG_OP(IFMR)
     .OP_END_FACTORY_REG(IFMR)

 /**
-*@brief weights adaptive range quantization. \n
+*@brief Weights Adaptive Range Quantization. \n

 *@par Inputs:
-*@li w:A Tensor of weights. \n
-*@li w_min:A Tensor of weights reduce_min. \n
-*@li w_max:A Tensor of weights reduce_max. \n
+*@li w: A Tensor of weights. \n
+*@li w_min: A Tensor of weights reduce_min. \n
+*@li w_max: A Tensor of weights reduce_max. \n

 *@par Attributes:
-*num_bits: the bits num used for quantize.
-*offset_flag: whether using offset. \n
+*@li num_bits: the bits num used for quantize.
+*@li offset_flag: whether to use offset. \n

 *@par Outputs:
 *y: fake quantized weights. \n
@@ -789,22 +799,22 @@ REG_OP(WtsARQ)
     .OP_END_FACTORY_REG(WtsARQ)

 /**
-*@brief The acts_ulq. \n
+*@brief Activations Universal Linear Quantization. \n

 *@par Inputs:
-*@li x:A Tensor of feature map
-*@li clamp _min:A Tensor of min clamp value of feature map.
-*@li clamp _max:A Tensor of max clamp value of feature map.
+*@li x: A Tensor of feature map.
+*@li clamp_min: A Tensor of min clamp value of feature map.
+*@li clamp_max: A Tensor of max clamp value of feature map.

 *@par Attributes:
-*fixed_min: fix min to zero.
-*num_bits: quant bits. \n
+*@li fixed_min: fix min to zero.
+*@li num_bits: quant bits. \n

 *@par Outputs:
-*y: output fake quant feature map.
-*clamp_min_mask: where x > clamp_min
-*clamp_min_mask: where x < clamp_max
-*x_clamped_loss: clamp loss. \n
+*@li y: output fake quant feature map.
+*@li clamp_min_mask: where x > clamp_min.
+*@li clamp_max_mask: where x < clamp_max.
+*@li x_clamped_loss: clamp loss. \n

 *@par Third-party framework compatibility
 *Compatible with mindspore
@@ -826,12 +836,12 @@ REG_OP(ActsULQ)
     .OP_END_FACTORY_REG(ActsULQ)

 /**
-*@brief The acts_ulq_input_grad. \n
+*@brief The gradient of Activations Universal Linear Quantization. \n

 *@par Inputs:
-*@li y_grad: A Tensor of gradient
-*@li clamp_min_mask: A Tensor of boolean mask indicating whether an additional one is needed'
-*@li clamp_max_mask: A Tensor of boolean mask indicating whether an additional one is needed'
+*@li y_grad: A Tensor of gradient.
+*@li clamp_min_mask: A Tensor of boolean mask indicating whether an additional one is needed.
+*@li clamp_max_mask: A Tensor of boolean mask indicating whether an additional one is needed.

 *@par Outputs:
 *x_grapd: The gradient of inpust. \n
@@ -851,10 +861,10 @@ REG_OP(ActsULQInputGrad)
     .OP_END_FACTORY_REG(ActsULQInputGrad)

 /**
-*@brief The act_ulq_clamp_max_grad. \n
+*@brief The gradient of Activations Universal Linear Quantization clamp max. \n

 *@par Inputs:
-*@li y_grad: A Tensor of gradient
+*@li y_grad: A Tensor of gradient.
 *@li clamp_max_mask: A Tensor of boolean mask indicating whether an additional one is needed.
 *@li x_clamped_loss: A Tensor of gradient. \n
@@ -876,10 +886,10 @@ REG_OP(ActULQClampMaxGrad)
     .OP_END_FACTORY_REG(ActULQClampMaxGrad)

 /**
-*@brief The act_ulq_clamp_min_grad. \n
+*@brief The gradient of Activations Universal Linear Quantization clamp min. \n

 *@par Inputs:
-*@li y_grad: A Tensor of gradient
+*@li y_grad: A Tensor of gradient.
 *@li clamp_min_mask: A Tensor of boolean mask indicating whether an additional one is needed.
 *@li x_clamped_loss: A Tensor of gradient. \n
@@ -904,7 +914,7 @@ REG_OP(ActULQClampMinGrad)

 * @brief Computes Lp norm.

 * @par Inputs:
-* @li x: An ND tensor of type float16, float32. \n
+* x: An ND tensor of type float16, float32. \n
 *
 * @par Attributes:
 * @li p: Int, "inf" or "-inf", default value is 2.
@@ -913,7 +923,7 @@ REG_OP(ActULQClampMinGrad)
 * @li epsilon: Float, default is 1e-12. \n

 * @par Outputs:
-* @li y: An ND tensor of type float16, float32. The shape of y is depending
+* y: An ND tensor of type float16, float32. The shape of y depends
 * on axes and keepdim. \n

 * @par Third-party framework compatibility
@@ -932,11 +942,13 @@ REG_OP(LpNorm)

 * @brief get complex.

 * @par Inputs:
-* @li real: An ND tensor of type float32. double
-* @li imag: An ND tensor of type float32. double \n
+* @li real: An ND tensor of type float32 or double, representing the real part of a complex number.
+* @li imag: An ND tensor of type float32 or double, representing the imaginary part of a complex number. \n
 *
+* @par Attributes:
+* Tout: The type of the output.
 * @par Outputs:
-* @li out: An ND tensor of type complex64, complex128 \n
+* out: An ND tensor of type complex64 or complex128 \n
 */
 REG_OP(Complex)
     .INPUT(real, TensorType({DT_FLOAT, DT_DOUBLE}))
@@ -949,10 +961,13 @@ REG_OP(Complex)

 * @brief deal complex.

 * @par Inputs:
-* @li input: An ND tensor of type complex64, complex128 \n
-*
+* input: An ND tensor of type complex64, complex128 \n

+* @par Attributes:
+* Tout: The type of the output.
+
 * @par Outputs:
-* @li output: An ND tensor of type float32. double \n
+* output: An ND tensor of type float32 or double \n
 */
 REG_OP(Imag)
     .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128}))
@@ -988,7 +1003,7 @@ REG_OP(Angle)
 * float16, float32. \n

 *@par Attributes:
-* @li reduction: Specifies the reduction to apply to the output:
+* reduction: Specifies the reduction to apply to the output:
 * 'none' | 'mean' | 'sum'. Default: 'mean'. \n

 *@par Outputs:
diff --git a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h
index b317be37..81c6a29e 100644
--- a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h
+++ b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h
@@ -61,21 +61,28 @@ REG_OP(MatMul)
 *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n

 *@par Inputs:
-*Two inputs, including:
-* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
-* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
-* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
-* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
-* @li bias: A 1D Tensor. Must be one of the following types: float16,
-* float32, int32. Has format [ND, NHWC] . \n
+*Four inputs, including:
+* @li x1: A matrix Tensor. 2D. Must be one of the following types: float32,
+ float16, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
+* @li x2: A matrix Tensor. 2D. Must be one of the following types: float32,
+ float16, int32, int8. Has format [ND, NHWC, FRACTAL_NZ].
+* @li bias: A 1D Tensor. Must be one of the following types: float32,
+ float16, int32. Has format [ND, NHWC].
+* @li offset_w: An optional 1D Tensor for quantized inference. Type is int8.
+ Reserved. \n

 *@par Attributes:
-*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
-*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n
+* @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
+ [M, K].
+* @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
+[K, N].
+* @li offset_x: An optional integer for quantized MatMulV2.
+* The negative offset added to the input x1 for int8 type. Ensure offset_x
+ within the effective range of int8 [-128, 127]. Defaults to "0". \n

 *@par Outputs:
-*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
-* float32, int32. Has format [ND, NHWC, FRACTAL_NZ] . \n
+*y: The result matrix Tensor. 2D. Must be one of the following types: float32,
+ float16, int32. Has format [ND, NHWC, FRACTAL_NZ]. \n

 *@par Third-party framework compatibility
 * Compatible with the TensorFlow operator BatchMatmul.
@@ -95,19 +102,27 @@ REG_OP(MatMulV2)
 *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n

 *@par Inputs:
-*Two inputs, including:
+*Five inputs, including:
 * @li x1: A matrix Tensor. 2D. Must be one of the following types: int8.
 * @li x2: A matrix Tensor. 2D. Must be one of the following types: int8.
 * @li compress_index: A compress index matrix of type int8.
-* @li bias: A 1D Tensor. Must be one of the following types: int32, float16.
+* @li bias: An optional Tensor. 1D. Must be one of the following types: int32,
+ float16.
+* @li offset_w: An optional matrix Tensor. 2D. Must be one of the following
+ types: int8. \n

 *@par Attributes:
-*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
-*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n
+*@li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
+ [M, K].
+*@li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
+ [K, N].
+*@li offset_x: An optional integer for quantized MatMulV2Compress.
+*The negative offset added to the input x1 for int8 type. Ensure offset_x
+ within the effective range of int8 [-128, 127]. Defaults to "0". \n

 *@par Outputs:
-*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
-* int32. \n
+*y: The result matrix Tensor. 2D. Must be one of the following types: int32,
+* float16. \n
 */

 REG_OP(MatMulV2Compress)
@@ -488,13 +503,13 @@ REG_OP(ScatterElements)

 *@par Inputs:
 * Three inputs, including:
-*@li var: An ND Tensor . \n
+*@li var: An ND Tensor .

 *Must be one of the following types: float16, float32, int32, int8, uint8
 *@li indices: An ND Tensor of type int32 or int64
-*@li updates: An Tensor. format:NCHW, NHWC . \n
+*@li updates: A Tensor. format:NCHW, NHWC .

 *Must be one of the following types: float16, float32, int32, int8, uint8
@@ -517,6 +532,61 @@ REG_OP(ScatterAdd)
     .OP_END_FACTORY_REG(ScatterAdd)

 /**
+*@brief Use a scalar to modify the tensor. \n
+
+*@par Inputs:
+*One input, including:
+*@li index: An ND Tensor . \n
+
+*Must be one of the following types: float16, float32, int32, int8, uint8
+
+*@par Attributes:
+* dim: the axis along which to index .
+* value: the source element(s) to scatter . \n
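Editor's note: the dim/value attributes just listed follow PyTorch's scatter-with-a-scalar-source convention (the compatibility note below names the PyTorch operator). A hedged reference loop for the 2-D, dim == 0 case, assuming for brevity that index and the output share one [rows, cols] shape:

```cpp
#include <cstdint>
#include <vector>

// For dim == 0: out[index[i][j]][j] = value. Both out and index are
// row-major [rows, cols] here; this is an illustration, not the GE kernel.
void ScatterScalarDim0(std::vector<float> &out, int64_t rows, int64_t cols,
                       const std::vector<int64_t> &index, float value) {
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      const int64_t target_row = index[i * cols + j];  // row coord is replaced
      out[target_row * cols + j] = value;              // column coord j is kept
    }
  }
}
```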
+
+*@par Outputs:
+*y: A Tensor. Has the same type and format as input "index" . \n
+
+*@par Third-party framework compatibility
+* Compatible with the Pytorch operator ScatterScalar.
+*/
+REG_OP(ScatterScalar)
+    .INPUT(index, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
+    .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
+    .REQUIRED_ATTR(dim, Int)
+    .REQUIRED_ATTR(value, Float)
+    .OP_END_FACTORY_REG(ScatterScalar)
+
+/**
+*@brief Use a tensor to modify the tensor . \n
+
+*@par Inputs:
+* Two inputs, including:
+*@li index: An ND Tensor . \n
+
+*Must be one of the following types: float16, float32, int32, int8, uint8
+
+*@li src: An ND Tensor . \n
+
+*Must be one of the following types: float16, float32, int32, int8, uint8
+
+*@par Attributes:
+* dim: the axis along which to index . \n
+
+*@par Outputs:
+*y: A Tensor. Has the same type and format as input "index" . \n
+
+*@par Third-party framework compatibility
+* Compatible with the Pytorch operator ScatterTensor.
+*/
+REG_OP(ScatterTensor)
+    .INPUT(index, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
+    .INPUT(src, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
+    .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
+    .REQUIRED_ATTR(dim, Int)
+    .OP_END_FACTORY_REG(ScatterTensor)
+
+/**
 *@brief Divides a variable reference by sparse updates . \n

 *@par Inputs:
@@ -530,7 +600,7 @@ REG_OP(ScatterAdd)

 *Must be one of the following types: float16, float, int32, int8, uint8

 *@par Attributes:
-*@li use_locking: An optional bool. Defaults to "False". If "True",
+*use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock . \n

 *@par Outputs:
@@ -752,10 +822,12 @@ REG_OP(DiagPart)

 *@par Attributes:
 *@li num_output: Reserved.
-*@li transpose: A bool, specifying weight whether to transpose, either "true" or "false". Defaults to "false".
+*@li transpose: A bool, specifying whether to transpose the input w, either "true" or "false". Defaults to "false".
 *@li axis: Optional. A int, 1 or 2, specifying which dimension the input "K" starts from. Defaults to 1.
 * The product of the subsequent dimensions starting form first dimension or the second dimension is "K".
-*@li offset_x: Reserved . \n
+*@li offset_x: An optional integer for quantized FullyConnection.
+*The negative offset added to the input image for int8 type. Ensure offset_x within the
+*effective range of int8 [-128, 127]. Defaults to "0". \n

 *@par Outputs:
 *y: The result tensor of type float16, int32, float32 . \n
@@ -779,27 +851,34 @@ REG_OP(FullyConnection)
     .OP_END_FACTORY_REG(FullyConnection)

 /**
-*@brief Also known as a "fully-connected-compress" layer, computes an inner product with a set of learned weights, and (optionally) adds biases . \n
+*@brief Also known as a "fully-connected-compress" layer, computes an inner
+product with a set of learned weights, and (optionally) adds biases . \n

 *@par Inputs:
-* Four inputs, including:
+* Five inputs, including:
 *@li x: A Tensor of type uint8, int8.
-*@li w: A weight matrix of type int8, int8.
-*@li w: A compress index matrix of type int8, int8.
-*@li b: A Tensor of type float16, int32, int32.
-*@li offset_w: A Tensor of type int8.i
+*@li w: A weight matrix of type int8.
+*@li compress_index: A compress index matrix of type int8.
+*@li b: A Tensor of type int32.
+*@li offset_w: A Tensor of type int8.

 *@par Attributes:
-*@li num_output: Reserved.
-*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false".
-*@li axis: Reserved.
-*@li offset_x: Reserved . \n
+*@li num_output: An int, specifying the number of outputs.
+*@li transpose: A bool, specifying whether to transpose the input w, either "true"
+ or "false". Defaults to "false".
+*@li axis: Optional. An int, 1 or 2, specifying which dimension the input "K"
+starts from. Defaults to "1".
+* The product of the subsequent dimensions starting from the first dimension or the
+second dimension is "K".
+*@li offset_x: An optional integer for quantized FullyConnectionCompress.
+*The negative offset added to the input image for int8 type. Ensure offset_x
+within the effective range of int8 [-128, 127]. Defaults to "0". \n

 *@par Outputs:
-*y: The result tensor of type int32 . \n
+*y: The result tensor of type int32. \n

 *@par Third-party framework compatibility
-* Compatible with the Caffe operator InnerProduct . \n
+* Compatible with the Caffe operator InnerProduct. \n

 *@par Quantization supported or not
 * Yes
@@ -925,13 +1004,13 @@ REG_OP(ScatterMin)

 *@par Inputs:
 * Three inputs, including:
-*@li var: An ND Tensor . \n
+*@li var: An ND Tensor .

 *Must be one of the following types: float16, float, int32, int8, uint8

 *@li indices: An NCHW, NHWC, or ND Tensor . \n

 *Must be one of the following types: int32 or int64

-*@li updates: An NCHW, NHWC, or ND Tensor . \n
+*@li updates: An NCHW, NHWC, or ND Tensor .

 *Must be one of the following types: float16, float, int32, int8, uint8

@@ -958,13 +1037,13 @@ REG_OP(ScatterMax)

 *@par Inputs:
 * Three inputs, including:
-*@li var: An ND Tensor . \n
+*@li var: An ND Tensor .

 *Must be one of the following types: float16, float, int32, int8, uint8

 *@li indices: An ND Tensor . \n

 *Must be one of the following types: int32 or int64

-*@li updates: An ND Tensor . \n
+*@li updates: An ND Tensor .

 *Must be one of the following types: float16, float, int32, int8, uint8

@@ -1113,14 +1192,46 @@ REG_OP(IndexAdd)
     .OP_END_FACTORY_REG(IndexAdd)

 /**
+* @brief Replaces the values of x1 at the positions given by "indices"
+*with the corresponding values in x2.
+
+* @par Inputs:
+* Three inputs, including:
+* @li x1: A Tensor. Must be one of the following types:
+* float16, float32, int32, int8, uint8.
+* @li x2: A Tensor of the same type as "x1".
+* @li indices: A Tensor of the indices, type should be int32.
+
+* @par Attributes:
+* @li accumulate: Whether to accumulate into x1 instead of replacing. Defaults to 0.
+
+* @par Outputs:
+* @li y: A Tensor. Same as input "x1".
+
+* @par Third-party framework compatibility
+* Compatible with the Pytorch operator index_put.
+
+* @par Restrictions:
+* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/
+REG_OP(IndexPut)
+    .INPUT(x1, TensorType({DT_INT64, DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
+    .INPUT(x2, TensorType({DT_INT64, DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
+    .INPUT(indices, TensorType({DT_INT64, DT_INT32}))
+    .OUTPUT(y, TensorType({DT_INT64, DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
+    .ATTR(accumulate, Int, 0)
+    .OP_END_FACTORY_REG(IndexPut)
+
+/**
 *@brief: Returns the upper triangular part of a matrix (2-D tensor) or batch
 of matrices input \n

 *@par Inputs:
-* Two inputs, including:
-*@li x: A Tensor. Must be one of the following types:
-* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
-* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
-*@li diagonal:(int, optional) – the diagonal to consider。\n
+*x: A Tensor. Must be one of the following types:
+*float16, float32, double, int32, uint8, int16, int8, complex64, int64,
+*qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
+
+*@par Attributes:
+*diagonal: An optional attribute indicating the diagonal to consider. \n

 *@par Outputs:
 *y: A Tensor. Has the same type as "x" . \n
@@ -1138,11 +1249,12 @@ REG_OP(Triu)
 *@brief: Returns the upper triangular part of a matrix (2-D tensor) or batch
 of matrices input \n

 *@par Inputs:
-* Two inputs, including:
-*@li x: A Tensor. Must be one of the following types:
-* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
-* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
-*@li diagonal:(int, optional) – the diagonal to consider。\n
+*x: A Tensor. Must be one of the following types:
+*float16, float32, double, int32, uint8, int16, int8, complex64, int64,
+*qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
+
+*@par Attributes:
+*diagonal: An optional attribute indicating the diagonal to consider. \n

 *@par Outputs:
 *y: A Tensor. Has the same type as "x" . \n
@@ -1213,6 +1325,30 @@ REG_OP(Eye)
     .ATTR(dtype, Int, 0)
     .OP_END_FACTORY_REG(Eye)

+/**
+*@brief: Fill the diagonal of at least 2-dimensional tensors with value . \n
+
+*@par Inputs:
+*x: A Tensor. Must be one of the following types:
+* float32, int32, int64 . \n
+
+*@par Outputs:
+*y: A Tensor. Has the same type as "x" . \n
+
+*@par Attributes:
+*fill_value: The value to fill in.
+*wrap: An optional bool. Defaults to "False". If "True", use recursive fill. \n
+
+*@par Third-party framework compatibility
+* Compatible with the Pytorch operator FillDiagonal.
+*/
+REG_OP(FillDiagonal)
+    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
+    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
+    .REQUIRED_ATTR(fill_value, Float)
+    .ATTR(wrap, Bool, false)
+    .OP_END_FACTORY_REG(FillDiagonal)
+
 } // namespace ge

 #endif  // OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
diff --git a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h
index 98473c65..a55cebe2 100644
--- a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h
@@ -195,7 +195,7 @@ REG_OP(DepthwiseConv2DBackpropInput)
     .INPUT(input_size, TensorType({DT_INT32, DT_INT64}))
     .INPUT(filter, TensorType({DT_FLOAT16}))
     .INPUT(out_backprop, TensorType({DT_FLOAT16}))
-    .OUTPUT(input_grad, TensorType({DT_FLOAT16}))
+    .OUTPUT(input_grad, TensorType({DT_FLOAT16, DT_FLOAT32}))
     .REQUIRED_ATTR(strides, ListInt)
     .ATTR(dilations, ListInt, {1, 1, 1, 1})
     .REQUIRED_ATTR(pads, ListInt)
@@ -255,7 +255,7 @@ REG_OP(DepthwiseConv2DBackpropInput)
 REG_OP(DepthwiseConv2DBackpropInputD)
     .INPUT(filter, TensorType({DT_FLOAT16}))
     .INPUT(out_backprop, TensorType({DT_FLOAT16}))
-    .OUTPUT(input_grad, TensorType({DT_FLOAT16}))
+    .OUTPUT(input_grad, TensorType({DT_FLOAT16, DT_FLOAT32}))
     .REQUIRED_ATTR(input_size, ListInt)
     .REQUIRED_ATTR(strides, ListInt)
     .ATTR(dilations, ListInt, {1, 1, 1, 1})
@@ -367,19 +367,19 @@ REG_OP(BiasAddGrad)
 * Gradients with respect to the output of the convolution.
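Editor's note, referring back to FillDiagonal registered just above (before the nn_calculation_ops hunks): the terse "recursive fill" wording for wrap matches PyTorch's fill_diagonal_, where stepping through the flattened 2-D tensor with stride cols + 1 naturally restarts the diagonal below the wrap point of a tall matrix. A hedged 2-D sketch:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// t is row-major [rows, cols]. Without wrap, only the min(rows, cols)
// main-diagonal entries are written; with wrap, the stride keeps going,
// restarting the diagonal one row further down in tall matrices.
void FillDiagonal2d(std::vector<float> &t, int64_t rows, int64_t cols,
                    float fill_value, bool wrap) {
  const int64_t step = cols + 1;  // one row down, one column right
  const int64_t end = wrap ? rows * cols : std::min(rows, cols) * step;
  for (int64_t pos = 0; pos < end && pos < rows * cols; pos += step) {
    t[pos] = fill_value;
  }
}
```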
*\n *\n - * The following are the supported data types and data formats: -*@verbatim - | Tensor | out_bckprop | filter | y - ------------|-------------|---------|-------- - | Data Type | float16 | float16 | float16 - | |-------------|---------|-------- - | | float32 | float32 | float32 - | |-------------|---------|-------- - | | float64 | float64 | float64 - ------------|-------------|---------|-------- - | Format | NCHW | NCHW | NCHW - | | NHWC | HWCN | NHWC -@endverbatim + * The following are the supported data types and data formats:\n + *\n + | Tensor | out_bckprop | filter | y\n + ------------|-------------|---------|--------\n + | Data Type | float16 | float16 | float16\n + | |-------------|---------|--------\n + | | float32 | float32 | float32\n + | |-------------|---------|--------\n + | | float64 | float64 | float64\n + ------------|-------------|---------|--------\n + | Format | NCHW | NCHW | NCHW\n + | | NHWC | HWCN | NHWC\n + *\n * For float32 and float64 type, the actual calculation on the chip is based on * float16. *\n @@ -398,36 +398,37 @@ REG_OP(BiasAddGrad) * "NHWC". Specify the data format of the input and output data. *\n *\n - * The following value range restrictions must be met: -*@verbatim - | Name | Field | Scope - -------------------|----------|-------------- - | input_size | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | Filter | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | out_backprop | H*strideH| [1, 4096] - | | W*strideW| [1, 4096] - -------------------|----------|-------------- - | y(fmap) | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | Stride | H | [1, 63] - | | W | [1, 63] - -------------------|----------|-------------- - | Padding | Top | [0, 255] - | | Bottom | [0, 255] - | | Left | [0, 255] - | | Right | [0, 255] - -------------------|----------|-------------- - | Dilation | H | [1, 255] - | | W | [1, 255] + * The following value range restrictions must be met:\n + *\n + | Name | Field | Scope\n + -------------------|----------|--------------\n + | input_size | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | Filter | H | [1, 255]\n + | | W | [1, 255]\n + -------------------|----------|--------------\n + | out_backprop | H*strideH| [1, 200000]\n + | | W*strideW| [1, 4096]\n + -------------------|----------|--------------\n + | y(fmap) | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | Stride | H | [1, 63]\n + | | W | [1, 63]\n + -------------------|----------|--------------\n + | Padding | Top | [0, 255]\n + | | Bottom | [0, 255]\n + | | Left | [0, 255]\n + | | Right | [0, 255]\n + -------------------|----------|--------------\n + | Dilation | H | [1, 255]\n + | | W | [1, 255]\n + *\n -@endverbatim * In Ascend910, fmap or out_backprop's H and W not support 1 when * fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 + * and filter_width > fmap_width * If filter_h = 1 and filter_w = 1, out_backprop_w * stride_h * stride_w < 4096 *\n * @@ -496,7 +497,7 @@ REG_OP(Conv2DBackpropInput) REG_OP(Conv2DBackpropInputD) .INPUT(filter, TensorType({DT_FLOAT16, DT_INT8})) .INPUT(out_backprop, TensorType({DT_FLOAT16, DT_INT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32})) .REQUIRED_ATTR(input_size, ListInt) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(pads, ListInt) @@ -508,7 
+509,7 @@ REG_OP(Conv2DBackpropInputD) /** *@brief Computes the Deconvolution with respect to the input. *@par Inputs: - * Three inputs: + * Two required inputs: * @li x: A Tensor of type float16 or int8. 4D with shape * [batch, out_channels, out_height, out_width]. Gradients with respect * to the output of the convolution. @@ -520,16 +521,16 @@ REG_OP(Conv2DBackpropInputD) * Type is int8. Reserved.\n *\n *\n - * The following are the supported data types and data formats: -*@verbatim - | Tensor | x | filter | bias | y - ------------|---------|---------|---------|-------- - | Data Type | float16 | float16 | float16 | float16 - | |---------|---------|---------|-------- - | | int8 | int8 | int32 | int32 - ------------|---------|---------|---------|-------- - | Format | NCHW | NCHW | ND | NCHW -@endverbatim + * The following are the supported data types and data formats:\n + *\n + | Tensor | x | filter | bias | y\n + ------------|---------|---------|---------|--------\n + | Data Type | float16 | float16 | float16 | float16\n + | |---------|---------|---------|--------\n + | | int8 | int8 | int32 | int32\n + ------------|---------|---------|---------|--------\n + | Format | NCHW | NCHW | ND | NCHW\n + *\n * For int8, a dequant or requant operator must be followed. *\n * @@ -550,35 +551,35 @@ REG_OP(Conv2DBackpropInputD) * within the effective range of int8 [-128, 127]. Defaults to "0". *\n *\n - * The following value range restrictions must be met: -*@verbatim - | Name | Field | Scope - -------------------|----------|-------------- - | x (out_backprop) | H*strideH| [1, 4096] - | | W*strideW| [1, 4096] - -------------------|----------|-------------- - | Filter | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | y (fmap) | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | Stride | H | [1, 63] - | | W | [1, 63] - -------------------|----------|-------------- - | Padding | Top | [0, 255] - | | Bottom | [0, 255] - | | Left | [0, 255] - | | Right | [0, 255] - -------------------|----------|-------------- - | Dilation | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | Offset_x | | [-128, 127] - -@endverbatim + * The following value range restrictions must be met:\n + *\n + | Name | Field | Scope\n + -------------------|----------|--------------\n + | x (out_backprop) | H*strideH| [1, 200000]\n + | | W*strideW| [1, 4096]\n + -------------------|----------|--------------\n + | Filter | H | [1, 255]\n + | | W | [1, 255]\n + -------------------|----------|--------------\n + | y (fmap) | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | Stride | H | [1, 63]\n + | | W | [1, 63]\n + -------------------|----------|--------------\n + | Padding | Top | [0, 255]\n + | | Bottom | [0, 255]\n + | | Left | [0, 255]\n + | | Right | [0, 255]\n + -------------------|----------|--------------\n + | Dilation | H | [1, 255]\n + | | W | [1, 255]\n + -------------------|----------|--------------\n + | Offset_x | | [-128, 127]\n + *\n * In Ascend910, fmap or out_backprop's H and W not support 1 when * fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 + * and filter_width > fmap_width * If filter_h = 1 and filter_w = 1, out_backprop_w * stride_h * stride_w < 4096 *\n * @@ -628,19 +629,19 @@ REG_OP(Deconvolution) * convolution. 
*\n *\n - * The following are the supported data types and data formats: -*@verbatim - | Tensor | x | out_backprop | y - ------------|---------|--------------|--------- - | Data Type | float16 | float16 | float16 - | |---------|--------------|--------- - | | float32 | float32 | float32 - | |---------|--------------|--------- - | | float64 | float64 | float64 - |-----------|---------|--------------|--------- - | Format | NCHW | NCHW | NCHW - | | NHWC | NHWC | HWCN -@endverbatim + * The following are the supported data types and data formats:\n + *\n + | Tensor | x | out_backprop | y\n + ------------|---------|--------------|---------\n + | Data Type | float16 | float16 | float16\n + | |---------|--------------|---------\n + | | float32 | float32 | float32\n + | |---------|--------------|---------\n + | | float64 | float64 | float64\n + |-----------|---------|--------------|---------\n + | Format | NCHW | NCHW | NCHW\n + | | NHWC | NHWC | HWCN\n + *\n * For float32 and float64 type of x and outbackprop, the actual calculation on the chip * is based on float16. *\n @@ -658,39 +659,34 @@ REG_OP(Deconvolution) * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to * "NHWC". Specify the data format of the input and output data. *\n -*\n -* The following value range restrictions must be met: -*@verbatim - | Name | Field | Scope - -------------------|----------|-------------- - | x(fmap) | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | Filter Size | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | out_backprop | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | y | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | Stride | H | [1, 63] - | | W | [1, 63] - -------------------|----------|-------------- - | Padding | Top | [0, 255] - | | Bottom | [0, 255] - | | Left | [0, 255] - | | Right | [0, 255] - -------------------|----------|-------------- - | Dilation | H | [1, 255] - | | W | [1, 255] - -@endverbatim - * In Ascend910, out_backprop's H and W not support 1 when - * fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 *\n - * + * The following value range restrictions must be met:\n + *\n + | Name | Field | Scope\n + -------------------|----------|--------------\n + | x(fmap) | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | Filter Size | H | [1, 255]\n + | | W | [1, 255]\n + -------------------|----------|--------------\n + | out_backprop | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | y | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | Stride | H | [1, 63]\n + | | W | [1, 63]\n + -------------------|----------|--------------\n + | Padding | Top | [0, 255]\n + | | Bottom | [0, 255]\n + | | Left | [0, 255]\n + | | Right | [0, 255]\n + -------------------|----------|--------------\n + | Dilation | H | [1, 255]\n + | | W | [1, 255]\n + *\n *@par Outputs: * y: A Tensor. Has the same type as x, has the same format as filter_size. 
*\n @@ -780,16 +776,16 @@ REG_OP(Conv2DBackpropFilterD) *\n *\n * The following are the supported data types and data formats: -*@verbatim - | Tensor | x | filter | bias | y - ------------|---------|---------|---------|-------- - | Data Type | float16 | float16 | float16 | float16 - | | float32 | float32 | float32 | float32 - | | int8 | int8 | int32 | int32 - ------------|---------|---------|---------|-------- - | Format | NCHW | NCHW | ND | NCHW - | | NHWC | HWCN | | NHWC -@endverbatim +*\n +*\n +| Tensor | x | filter | bias | y |\n +| :-------: | :-----: | :-----: | :-----: | :-----: |\n +| Data Type | float16 | float16 | float16 | float16 |\n +| | float32 | float32 | float32 | float32 |\n +| | int8 | int8 | int32 | int32 |\n +| Format | NCHW | NCHW | ND | NCHW |\n +| | NHWC | HWCN | | NHWC |\n +*\n * For float32 type, the actual calculation on the chip is based on * float16. *\n @@ -813,35 +809,30 @@ REG_OP(Conv2DBackpropFilterD) *\n *\n * The following value range restrictions must be met: -*@verbatim - | Name | Field | Scope - -------------------|----------|-------------- - | Input Image Size | H | [1, 100000] - | | W | [1, 4096] - -------------------|----------|-------------- - | Filter Size | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | Stride | H | [1, 63] - | | W | [1, 63] - -------------------|----------|-------------- - | Padding | Top | [0, 255] - | | Bottom | [0, 255] - | | Left | [0, 255] - | | Right | [0, 255] - -------------------|----------|-------------- - | Dilation | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | Offset_x | | [-128, 127] - -@endverbatim +*\n +*\n +| Name | Field | Scope |\n +| :--------------: | :------: | :---------: |\n +| Input Image Size | H | [1, 100000] |\n +| | W | [1, 4096] |\n +| Filter Size | H | [1, 255] |\n +| | W | [1, 255] |\n +| Stride | H | [1, 63] |\n +| | W | [1, 63] |\n +| Padding | Top | [0, 255] |\n +| | Bottom | [0, 255] |\n +| | Left | [0, 255] |\n +| | Right | [0, 255] |\n +| Dilation | H | [1, 255] |\n +| | W | [1, 255] |\n +| Offset_x | - | [-128, 127] |\n +*\n * The W dimension of the input image supports cases exceeding 4096, but it may * cause compilation errors. *\n * *@par Outputs: -*@li y: A 4D Tensor of output feature map. Has the same type as "x". With the +* y: A 4D Tensor of output feature map. Has the same type as "x". With the * format "NHWC", the data is stored in the order of: [batch, out_height, * out_width, out_channels]. *\n @@ -956,16 +947,15 @@ REG_OP(Conv2DCompress) *\n *\n * The following are the supported data types and data formats: -*@verbatim - | Tensor | x | filter | offsets | bias | y - ------------|---------|---------|---------|----------|-------- - | Data Type | float16 | float16 | float16 | float16 | float16 - | |---------|---------|---------|----------|-------- - | | float32 | float32 | float32 | float32 | float32 - ------------|---------|---------|---------|----------|-------- - | Format | NCHW | NCHW | NCHW | ND | NCHW - | | NHWC | HWCN | NHWC | | NHWC -@endverbatim +*\n +*\n +| Tensor | x | filter | offsets | bias | y |\n +| :-------: | :-----: | :-----: | :-----: | :-----: | :-----: |\n +| Data Type | float16 | float16 | float16 | float16 | float16 |\n +| | float32 | float32 | float32 | float32 | float32 |\n +| Format | NCHW | NCHW | NCHW | ND | NCHW |\n +| | NHWC | HWCN | NCHW | | NHWC |\n +*\n * For float32 type, the actual convolution calculation part on the chip is * based on float16. 
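Editor's note: the range table above reduces to a handful of interval checks. A small standalone checker for one spatial pair, using exactly the limits listed for the input image, filter and stride (illustrative only, not GE validation code):

```cpp
#include <cstdint>

bool InRange(int64_t v, int64_t lo, int64_t hi) { return v >= lo && v <= hi; }

// Limits taken from the table above: Input Image H in [1, 100000],
// W in [1, 4096]; Filter H/W in [1, 255]; Stride H/W in [1, 63].
bool Conv2dBackpropFilterRangesOk(int64_t image_h, int64_t image_w,
                                  int64_t filter_h, int64_t filter_w,
                                  int64_t stride_h, int64_t stride_w) {
  return InRange(image_h, 1, 100000) && InRange(image_w, 1, 4096) &&
         InRange(filter_h, 1, 255) && InRange(filter_w, 1, 255) &&
         InRange(stride_h, 1, 63) && InRange(stride_w, 1, 63);
}
```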
 *\n
@@ -956,16 +947,15 @@ REG_OP(Conv2DCompress)
 *\n
 *\n
 * The following are the supported data types and data formats:
-*@verbatim
-    | Tensor    | x       | filter  | offsets | bias     | y
-    ------------|---------|---------|---------|----------|--------
-    | Data Type | float16 | float16 | float16 | float16  | float16
-    |           |---------|---------|---------|----------|--------
-    |           | float32 | float32 | float32 | float32  | float32
-    ------------|---------|---------|---------|----------|--------
-    | Format    | NCHW    | NCHW    | NCHW    | ND       | NCHW
-    |           | NHWC    | HWCN    | NHWC    |          | NHWC
-@endverbatim
+*\n
+*\n
+| Tensor    | x       | filter  | offsets | bias    | y       |\n
+| :-------: | :-----: | :-----: | :-----: | :-----: | :-----: |\n
+| Data Type | float16 | float16 | float16 | float16 | float16 |\n
+|           | float32 | float32 | float32 | float32 | float32 |\n
+| Format    | NCHW    | NCHW    | NCHW    | ND      | NCHW    |\n
+|           | NHWC    | HWCN    | NCHW    |         | NHWC    |\n
+*\n
 * For float32 type, the actual convolution calculation part on the chip is
 * based on float16.
 *\n
@@ -992,19 +982,18 @@ REG_OP(Conv2DCompress)
 *\n
 *\n
 * The following value range restrictions must be met:
-*@verbatim
-    | Name             | Field  | Scope
-    --------------------|--------|----------------------------
-    | Input Image Size | H      | [1, 100000 / filter_height]
-    |                  | W      | [1, 4096 / filter_width]
-    --------------------|--------|----------------------------
-    | Filter Size      | H      | [1, 63]
-    |                  | W      | [1, 63]
-@endverbatim
+*\n
+*\n
+| Name             | Field    | Scope                       |\n
+| :--------------: | :------: | :-------------------------: |\n
+| Input Image Size | H        | [1, 100000 / filter_height] |\n
+|                  | W        | [1, 4096 / filter_width]    |\n
+| Filter Size      | H        | [1, 63]                     |\n
+|                  | W        | [1, 63]                     |\n
 *\n
 *
 *@par Outputs:
-*@li y: A 4D Tensor of output feature map. Has the same type as "x". With the
+* y: A 4D Tensor of output feature map. Has the same type as "x". With the
 * format "NHWC", the data is stored in the order of: [batch, out_height,
 * out_width, out_channels].
 *\n
@@ -1042,41 +1031,38 @@ REG_OP(DeformableConv2D)

 /**
 *@brief Computes a 3D convolution given 5D "x" and "filter" tensors.
- *@par Inputs:
+
+*@par Inputs:
 * @li x: A 5D tensor. Must be one of the following types: float16,
 * (Currently does not support int8). The format of x is NCDHW or NDHWC.
 * @li filter: A 5D tensor of the same type as "x".
 * (Currently does not support int8).
- * The format is NCDHW, NDHWC or DHWCN . \n
-
-*@par Optional input:
- * @li bias: An optional 1D tensor of the same type as "x".
- * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved . \n
+ * The format is NCDHW, NDHWC or DHWCN.
+ * @li bias: Optional. A 1D tensor of the same type as "x".
+ * @li offset_w: Optional. A 1D tensor for quantized deconvolution. Reserved. \n

-*@par Required Attributes:
- * @li strides: A list of 5 integers. Specifies the stride of the sliding window
+*@par Attributes:
+ * @li strides: Required. A list of 5 integers. Specifies the stride of the sliding window
 * for each dimension of "x".
 * The N and C dimensions must be 1. Has the same format as "x".
- * @li pads: A list of 6 integers.
+ * @li pads: Required. A list of 6 integers.
 * Supports only padding along the D, H and W dimensions in sequence of head,
- * tail, top, bottom, left and right . \n
-
-*@par Attributes:
- * @li groups: Number of blocked connections from input channels to output
+ * tail, top, bottom, left and right.
+ * @li dilations: Optional. A list of 5 integers. Specifies the dilation factor for each
+ * dimension of "x".
+ * @li groups: Optional. Number of blocked connections from input channels to output
 * channels.
- * @li data_format: An optional string from: "NDHWC", "NCDHW".
+ * @li data_format: Optional. A string from: "NDHWC", "NCDHW".
 * Defaults to "NDHWC". Specify the data format of the input and output data.
- * @li dilations: A list of 5 integers. Specifies the dilation factor for each
- * dimension of "x".
 * The N, C and D dimensions must be 1. Has the same format as "x".
- * @li offset_x: An optional int. Input offset, used for quantized inference.
- * Defaults to 0. Reserved . \n
+ * @li offset_x: Optional. An int. Input offset, used for quantized inference.
+ * Defaults to 0. Reserved. \n

 *@par Outputs:
- *y: A Tensor. Has the same type and data format as "x". \n
+ * y: A Tensor. Has the same type and data format as "x". \n

 *@attention Constraints:
- *The image size after padding is greater than the filter size . \n
+ * The image size after padding is greater than the filter size. \n
 *@par Third-party framework compatibility
 * @li Compatible with the TensorFlow operator conv3d.
@@ -1085,9 +1071,9 @@ REG_OP(DeformableConv2D)
 REG_OP(Conv3D)
     .INPUT(x, TensorType({DT_FLOAT16}))
     .INPUT(filter, TensorType({DT_FLOAT16}))
-    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16}))
+    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT32}))
     .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
-    .OUTPUT(y, TensorType({DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32}))
     .REQUIRED_ATTR(strides, ListInt)
     .REQUIRED_ATTR(pads, ListInt)
     .ATTR(dilations, ListInt, {1, 1, 1, 1, 1})
@@ -1099,8 +1085,8 @@ REG_OP(Conv3D)

 /**
 *@brief Computes the gradients of convolution 3d with respect to the input.
+
 *@par Inputs:
- * Three inputs:
 * @li input_size: A Tensor of type int32, int64. An integer vector representing
 * the shape of input, where input is a 5-D tensor
 * [batch, depth, height, width, channels] or
 * [batch, channels, depth, height, width].
@@ -1110,28 +1096,25 @@ REG_OP(Conv3D)
 * @li out_backprop: A Tensor. Must have the same type as filter.
 * 5-D with shape [batch, depth, out_height, out_width, out_channels]
 * or [batch, out_channels, depth, out_height, out_width]. Gradients with
- * respect to the output of the convolution . \n
+ * respect to the output of the convolution. \n

-*@par Required Attributes:
- * @li strides: A list of 5 integers. Specifies the stride of the sliding window
+*@par Attributes:
+ * @li strides: Required. A list of 5 integers. Specifies the stride of the sliding window
 * for each dimension of "out_backprop".
 * The N and C dimensions must be 1. Has the same format as "out_backprop".
- * @li pads: A list of 6 integers.
+ * @li pads: Required. A list of 6 integers.
 * Supports only padding along the D, H and W dimensions in sequence of head,
- * tail, top, bottom, left and right . \n
-
-*@par Attributes:
- * Three attributes:
- * @li groups: Number of blocked connections from input channels to output
- * channels.
- * @li data_format: An optional string from: "NDHWC", "NCDHW".
- * Defaults to "NDHWC". Specify the data format of the input and output data.
- * @li dilations: A tuple/list of 5 integers, The dilation factor for each
+ * tail, top, bottom, left and right.
+ * @li dilations: Optional. A tuple/list of 5 integers, the dilation factor for each
 * dimension of the input.
 * The N, C and D dimensions must be 1. Has the same format as "out_backprop".
+ * @li groups: Optional. Number of blocked connections from input channels to output
+ * channels.
+ * @li data_format: Optional. A string from: "NDHWC", "NCDHW".
+ * Defaults to "NDHWC". Specify the data format of the input and output data. \n

 *@par Outputs:
- * y: A Tensor. Has the same type as filter,and has same format as "input_size"
+ * y: A Tensor. Has the same type as filter, and has the same format as "input_size". \n

 *@par Third-party framework compatibility
 * Compatible with Tensorflow's conv3d_backprop_input
@@ -1150,45 +1133,44 @@ REG_OP(Conv3DBackpropInput)

 /**
 *@brief Computes the gradients of convolution 3d with respect to the input.
+
 *@par Inputs:
- * Two inputs:
 * @li filter: A Tensor whose type is float16. The format of filter is NCDHW,
 * NDHWC or DHWCN.
 * @li out_backprop: A Tensor. Must have the same type as filter. The format is
- * NDHWC or NCDHW. \n
+ * NDHWC or NCDHW. \n

-*@par Required Attributes:
- * @li strides: A list of 5 integers. Specifies the stride of the sliding window
+*@par Attributes:
+ * @li input_size: Required. A tuple/list of type int32, int64. An integer vector
+ * representing the shape of input, where input is a 5-D tensor
+ * [batch, depth, height, width, channels] or
+ * [batch, channels, depth, height, width].
+ * @li strides: Required. A list of 5 integers. Specifies the stride of the sliding window
 * for each dimension of "out_backprop".
 * The N and C dimensions must be 1. Has the same format as "out_backprop".
- * @li pads: A list of 6 integers. Supports only padding along the D, H and W
+ * @li pads: Required. A list of 6 integers. Supports only padding along the D, H and W
 * dimensions in sequence of head, tail, top, bottom, left and right.
- * @li input_size: A tuple/list of type int32, int64. An integer vector
- * representing the shape of input, where input is a 5-D tensor
- * [batch, depth, height, width, channels] or
- * [batch, channels, depth, height, width] . \n
-
-*@par Attributes:
- * Three attributes:
- * @li groups: Number of blocked connections from input channels to output
- * channels.
- * @li data_format: An optional string from: "NDHWC", "NCDHW".
- * Defaults to "NDHWC". Specify the data format of the input and output data.
- * @li dilations: A tuple/list of 5 integers, The dilation factor for each
+ * @li dilations: Optional. A tuple/list of 5 integers. The dilation factor for each
 * dimension of input.
 * The N, C and D dimensions must be 1. Has the same format as "out_backprop".
+ * @li groups: Optional. Number of blocked connections from input channels to output
+ * channels.
+ * @li data_format: Optional. A string from: "NDHWC", "NCDHW".
+ * Defaults to "NDHWC". Specify the data format of the input and output data. \n
+
*@par Outputs:
- * y: A Tensor. Has the same type and data format as "out_backprop".
+ * y: A Tensor. Has the same type and data format as "out_backprop". \n
+
*@par Third-party framework compatibility
- * Compatible with Tensorflow's conv3d_backprop_input
+ * Compatible with Tensorflow's conv3d_backprop_input. \n

*@par Restrictions:
-* Warning: THIS FUNCTION IS DEPRECATED. Please use Conv3DBackpropInput instead.
+ * Warning: THIS FUNCTION IS DEPRECATED. Please use Conv3DBackpropInput instead.
*/
REG_OP(Conv3DBackpropInputD)
    .INPUT(filter, TensorType({DT_FLOAT16}))
    .INPUT(out_backprop, TensorType({DT_FLOAT16}))
-    .OUTPUT(y, TensorType({DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32}))
    .REQUIRED_ATTR(input_size, ListInt)
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
@@ -1242,8 +1224,8 @@ REG_OP(LSTM)
/**
*@brief Computes the gradients of convolution 3D with respect to the filter.
+
*@par Inputs:
- * Three inputs:
 * @li x: A Tensor. Must be one of the following types: float16, float32.
 * Currently does not support double.
 * 5-D with shape [batch, in_depth, in_height, in_width, in_channels]
@@ -1258,26 +1240,23 @@ REG_OP(LSTM)
 * or [batch, out_channels, out_depth, out_height, out_width].
 * Gradients with respect to the output of the convolution. \n

-*@par Required Attributes:
- * @li strides: A tuple/list of 5 integers. Specifies the stride of the sliding
+*@par Attributes:
+ * @li strides: Required. A tuple/list of 5 integers. Specifies the stride of the sliding
 * window for each dimension of "x". The N and C dimensions must be 1.
 * Has the same format as "x".
- * @li pads: A tuple/list of 6 integers, [front, back, top, bottom, left, right]
- * pads on feature map . \n
-
-*@par Attributes:
- * Three attributes:
- * @li dilations: A tuple/list of 5 integers, The dilation factor for each
+ * @li pads: Required. A tuple/list of 6 integers, [front, back, top, bottom, left, right]
+ * pads on feature map.
+ * @li dilations: Optional. A tuple/list of 5 integers. The dilation factor for each
 * dimension of input.
 * The N, C and D dimensions must be 1. Has the same format as "x".
- * @li groups: Number of blocked connections from input channels to output
+ * @li groups: Optional. Number of blocked connections from input channels to output
 * channels.
- * @li data_format: An optional string from: "NDHWC", "NCDHW".
- * Defaults to "NDHWC". Specify the data format of the input and output data.
+ * @li data_format: Optional. A string from: "NDHWC", "NCDHW".
+ * Defaults to "NDHWC". Specify the data format of the input and output data. \n

*@par Outputs:
- * y: A Tensor that has the same type as "x"
- * and the format is NDHWC, NCDHW or DHWCN.
+ * y: A Tensor that has the same type as "x" and the format is NDHWC, NCDHW or DHWCN. \n
+
*@par Third-party framework compatibility
 * Compatible with Tensorflow's conv3d_backprop_filter
 */
REG_OP(Conv3DBackpropFilter)
@@ -1295,8 +1274,8 @@
/**
*@brief Computes the gradients of convolution with respect to the filter.
+
*@par Inputs:
- * Two inputs:
 * @li x: A Tensor of type float16.
 * 5-D with shape [batch, in_depth, in_height, in_width, in_channels]
 * or [batch, in_channels, in_depth, in_height, in_width].
 * @li out_backprop: A Tensor of type float16. 5-D with shape [batch, out_depth, out_height, out_width, out_channels]
 * or [batch, out_channels, out_depth, out_height, out_width].
 * Gradients with respect to the output of the convolution. \n

-*@par Required Attributes:
- * @li filter_size: A tuple/list of type integers. An integer vector
+*@par Attributes:
+ * @li filter_size: Required. A tuple/list of type integers. An integer vector
 * representing the tensor shape of filter, where filter is a 5-D tensor
 * [filter_depth, filter_height, filter_width, in_channels, out_channels],
 * [out_channels, filter_depth, filter_height, filter_width, in_channels]
 * or [out_channels, in_channels, filter_depth, filter_height, filter_width].
- * @li strides: A tuple/list of 5 integers. Specifies the stride of the sliding
+ * @li strides: Required. A tuple/list of 5 integers. Specifies the stride of the sliding
 * window for each dimension of "x".
 * The N and C dimensions must be 1. Has the same format as "x".
- * @li pads: A tuple/list of 6 integers, [front, back, top, bottom, left, right]
- * pads on feature map. \n
-
-*@par Attributes:
- * Three attributes:
- * @li dilations: A tuple/list of 5 integers, The dilation factor for each
+ * @li pads: Required. A tuple/list of 6 integers, [front, back, top, bottom, left, right]
+ * pads on feature map.
+ * @li dilations: Optional. A tuple/list of 5 integers. The dilation factor for each
 * dimension of input.
 * The N, C and D dimensions must be 1. Has the same format as "x".
- * @li groups: Number of blocked connections from input channels to output
+ * @li groups: Optional. Number of blocked connections from input channels to output
 * channels.
- * @li data_format: An optional string from: "NDHWC", "NCDHW".
- * Defaults to "NDHWC". Specify the data format of the input and output data.
+ * @li data_format: Optional. A string from: "NDHWC", "NCDHW".
+ * Defaults to "NDHWC". Specify the data format of the input and output data. \n

*@par Outputs:
- * y: A Tensor of type float32 and the format is NDHWC, NCDHW or DHWCN.
+ * y: A Tensor of type float32 and the format is NDHWC, NCDHW or DHWCN. \n
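+ *\n
+ * A worked example (illustrative only, under the shape conventions above;
+ * the concrete sizes are hypothetical): for an NDHWC "x" of shape
+ * [8, 16, 56, 56, 64] and filter_size = [3, 3, 3, 64, 128] in DHWCN order,
+ * "y" holds the filter gradients with that same 5-D shape.
+ *\n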
*@par Third-party framework compatibility
- * Compatible with Tensorflow's conv3d_backprop_filter
+ * Compatible with Tensorflow's conv3d_backprop_filter. \n
+
*@par Restrictions:
-* Warning: THIS FUNCTION IS DEPRECATED. Please use Conv3DBackpropFilter instead.
+ * Warning: THIS FUNCTION IS DEPRECATED. Please use Conv3DBackpropFilter instead.
*/
-
-
REG_OP(Conv3DBackpropFilterD)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(out_backprop, TensorType({DT_FLOAT16}))
@@ -1350,37 +1326,32 @@ REG_OP(Conv3DBackpropFilterD)
/**
*@brief Computes the transpose of convolution 3d with respect to the input.
+
*@par Inputs:
- * Three inputs:
 * @li input_size: A Tensor of type int32. An integer vector representing the
 * shape of input.
 * @li x: A Tensor of type float16, currently does not support int8. The format
 * is NDHWC or NCDHW.
 * @li filter: A Tensor of type float16, currently does not support int8.
 * The format is NDHWC, NCDHW or DHWCN.
+ * @li bias: Optional. A 1D tensor of the same type as "x". Reserved.
+ * @li offset_w: Optional. A 1D tensor for quantized deconvolution. Reserved. \n

-*@par Optional input:
- * Two optional inputs
- * @li bias: An optional 1D tensor of the same type as "x". Reserved.
- * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved . \n
-
-*@par Required Attributes:
- * @li strides: A tuple/list of 5 integers. Specifies the stride of the sliding
+*@par Attributes:
+ * @li strides: Required. A tuple/list of 5 integers. Specifies the stride of the sliding
 * window for each dimension of "x".
 * The N and C dimensions must be 1. Has the same format as "x".
- * @li pads: A tuple/list of 6 integers
-
-*@par Attributes:
- * Five attributes:
- * @li groups: Number of blocked connections from input channels to output
- * channels.
- * @li dilations: A tuple/list of 5 integers,
+ * @li pads: Required. A tuple/list of 6 integers.
+ * @li dilations: Optional. A tuple/list of 5 integers.
 * The dilation factor for each dimension of input.
 * The N, C and D dimensions must be 1. Has the same format as "x".
- * @li data_format: An optional string from: "NDHWC", "NCDHW".
+ * @li groups: Optional. Number of blocked connections from input channels to output
+ * channels.
+ * @li data_format: Optional. A string from: "NDHWC", "NCDHW".
 * Defaults to "NDHWC". Specify the data format of the input and output data.
- * @li output_padding: The size will be added in the output shape.
- * @li offset_x: Input offset_x value. Reserved.
+ * @li output_padding: Optional. The size to be added to the output shape.
+ * @li offset_x: Optional. Input offset_x value. Reserved. \n
+
*@par Outputs:
 * y: A Tensor. Has the same type and format as "x".
*/
REG_OP(Conv3DTranspose)
    .INPUT(input_size, TensorType({DT_INT32, DT_INT64}))
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(filter, TensorType({DT_FLOAT16}))
-    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16}))
+    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
-    .OUTPUT(y, TensorType({DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32}))
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
    .ATTR(dilations, ListInt, {1, 1, 1, 1, 1})
@@ -1402,46 +1373,44 @@ REG_OP(Conv3DTranspose)
/**
*@brief Computes the transpose of convolution 3d with respect to the input.
+
*@par Inputs:
 * @li x: A Tensor of type float16, currently does not support int8.
 * The format is NDHWC or NCDHW.
 * @li filter: A Tensor of type float16, currently does not support int8.
 * The format is NDHWC, NCDHW or DHWCN.
+ * @li bias: Optional. A 1D tensor of the same type as "x". Reserved.
+ * @li offset_w: Optional. A 1D tensor for quantized deconvolution. Reserved. \n

-*@par Optional inputs:
- * @li bias: An optional 1D tensor of the same type as "x". Reserved.
- * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved . \n
-
-*@par Required Attributes:
- * @li input_size: A tuple/list of type int32.
- * An integer vector representing the shape of input
- * @li strides: A tuple/list of 5 integers.
+*@par Attributes:
+ * @li input_size: Required. A tuple/list of type int32.
+ * An integer vector representing the shape of input.
+ * @li strides: Required. A tuple/list of 5 integers.
 * Specifies the stride of the sliding window for each dimension of "x".
 * The N and C dimensions must be 1. Has the same format as "x".
- * @li pads: A tuple/list of 6 integers . \n
-
-*@par Attributes:
- * Five attributes:
- * @li dilations: A tuple/list of 5 integers, The dilation factor for each
+ * @li pads: Required. A tuple/list of 6 integers.
+ * @li dilations: Optional. A tuple/list of 5 integers. The dilation factor for each
 * dimension of input.
 * The N, C and D dimensions must be 1. Has the same format as "x".
- * @li groups: Number of blocked connections from input channels to output
+ * @li groups: Optional. Number of blocked connections from input channels to output
 * channels.
- * @li data_format: An optional string from: "NDHWC", "NCDHW".
+ * @li data_format: Optional. A string from: "NDHWC", "NCDHW".
 * Defaults to "NDHWC". Specify the data format of the input and output data.
- * @li output_padding: The size will be added in the output shape.
- * @li offset_x: Input offset_x value. Reserved.
+ * @li output_padding: Optional. The size to be added to the output shape.
+ * @li offset_x: Optional. Input offset_x value. Reserved. \n
+
*@par Outputs:
- * y: A Tensor. Has the same type and format as "x".
+ * y: A Tensor. Has the same type and format as "x". \n
+
*@par Restrictions:
-* Warning: THIS FUNCTION IS DEPRECATED. Please use Conv3DTranspose instead.
+ * Warning: THIS FUNCTION IS DEPRECATED. Please use Conv3DTranspose instead.
*/
REG_OP(Conv3DTransposeD)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(filter, TensorType({DT_FLOAT16}))
-    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16}))
+    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
-    .OUTPUT(y, TensorType({DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32}))
    .REQUIRED_ATTR(input_size, ListInt)
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
@@ -1469,17 +1438,17 @@ REG_OP(Conv3DTransposeD)
 * @li offset_w: An optional 1D tensor for quantized inference. Reserved.
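*\n
* As a worked example (illustrative; it assumes the standard transposed-
* convolution size relation, which this header does not spell out): with an
* H dimension of 14 in "x", stride_h = 2, pad_top = pad_bottom = 1,
* dilation_h = 1, filter_h = 4 and no output_padding, the H dimension of "y"
* is (14 - 1) * 2 - 1 - 1 + 1 * (4 - 1) + 1 = 28.
*\n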
*\n *\n - * The following are the supported data types and data formats: -*@verbatim - | Tensor | x | filter | bias | y - ------------|---------|---------|---------|-------- - | Data Type | float16 | float16 | float16 | float16 - | |---------|---------|---------|-------- - | | int8 | int8 | int32 | int32 - ------------|---------|---------|---------|-------- - | Format | NCHW | NCHW | ND | NCHW - | | NHWC | HWCN | | NHWC -@endverbatim + * The following are the supported data types and data formats:\n + *\n + | Tensor | x | filter | bias | y\n + ------------|---------|---------|---------|--------\n + | Data Type | float16 | float16 | float16 | float16\n + | |---------|---------|---------|--------\n + | | int8 | int8 | int32 | int32\n + ------------|---------|---------|---------|--------\n + | Format | NCHW | NCHW | ND | NCHW\n + | | NHWC | HWCN | | NHWC\n + *\n * For int8, a dequant or requant operator must be followed. *\n * @@ -1504,38 +1473,38 @@ REG_OP(Conv3DTransposeD) * within the effective range of int8 [-128, 127]. Defaults to "0". *\n *\n - * The following value range restrictions must be met: -*@verbatim - | Name | Field | Scope - -------------------|----------|-------------- - | input_size | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | x (out_backprop) | H*strideH| [1, 4096] - | | W*strideW| [1, 4096] - -------------------|----------|-------------- - | filter | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | y (fmap) | H | [1, 4096] - | | W | [1, 4096] - -------------------|----------|-------------- - | Stride | H | [1, 63] - | | W | [1, 63] - -------------------|----------|-------------- - | Padding | Top | [0, 255] - | | Bottom | [0, 255] - | | Left | [0, 255] - | | Right | [0, 255] - -------------------|----------|-------------- - | Dilation | H | [1, 255] - | | W | [1, 255] - -------------------|----------|-------------- - | Offset_x | | [-128, 127] - -@endverbatim + * The following value range restrictions must be met:\n + *\n + | Name | Field | Scope\n + -------------------|----------|--------------\n + | input_size | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | x (out_backprop) | H*strideH| [1, 200000]\n + | | W*strideW| [1, 4096]\n + -------------------|----------|--------------\n + | filter | H | [1, 255]\n + | | W | [1, 255]\n + -------------------|----------|--------------\n + | y (fmap) | H | [1, 200000]\n + | | W | [1, 4096]\n + -------------------|----------|--------------\n + | Stride | H | [1, 63]\n + | | W | [1, 63]\n + -------------------|----------|--------------\n + | Padding | Top | [0, 255]\n + | | Bottom | [0, 255]\n + | | Left | [0, 255]\n + | | Right | [0, 255]\n + -------------------|----------|--------------\n + | Dilation | H | [1, 255]\n + | | W | [1, 255]\n + -------------------|----------|--------------\n + | Offset_x | | [-128, 127]\n + *\n * In Ascend910, fmap or out_backprop's H and W not support 1 when * fmap_h + pad_top + pad_bottom != (filter_height - 1) * dilation_h + 1 + * and filter_width > fmap_width * If filter_h = 1 and filter_w = 1, out_backprop_w * stride_h * stride_w < 4096 *\n * @@ -1557,9 +1526,9 @@ REG_OP(Conv2DTranspose) .INPUT(input_size, TensorType({DT_INT32, DT_INT64})) .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) .INPUT(filter, TensorType({DT_FLOAT16, DT_INT8})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32})) 
.OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32})) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1}) @@ -1604,9 +1573,9 @@ REG_OP(Conv2DTranspose) REG_OP(Conv2DTransposeD) .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) .INPUT(filter, TensorType({DT_FLOAT16, DT_INT8})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32})) .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32})) .REQUIRED_ATTR(input_size, ListInt) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(pads, ListInt) @@ -1623,14 +1592,12 @@ REG_OP(Conv2DTransposeD) * Two inputs: * @li x: A Tensor of type float16,float32 * @li offsets: A Tensor of type float16,float32.Deformation offset parameter. -*@par Required Attributes: +*@par Attributes: * @li strides: A tuple/list of 4 integers.The stride of the sliding window for * height and width for H/W dimension. * @li pads: A tuple/list of 4 integers.Padding added to H/W dimension * of the input. * @li ksize: A tuple/list of 2 integers.kernel size. -*@par Attributes: - * Four attributes: * @li dilations: A tuple/list of 4 integers, The dilation factor for each dimension * of input. Defaults to [1, 1, 1, 1] * @li data_format: An optional string from: "NCHW", "NHWC". Defaults to "NCHW". Specify the data format of the input x. @@ -1659,22 +1626,20 @@ REG_OP(DeformableOffsets) * @li grad: A Tensor of type float16,float32. gradients with respect to DeformableOffsets output * @li x: A Tensor of type float16,float32. * @li offsets: A Tensor of type float16,float32.Deformation offset parameter. -*@par Required Attributes: +*@par Attributes: * @li strides: A tuple/list of 4 integers.The stride of the sliding window for * height and width for H/W dimension. * @li pads: A tuple/list of 4 integers.Padding added to H/W dimension * of the input. * @li ksize: A tuple/list of 2 integers.kernel size. -*@par Attributes: - * Three attributes: * @li dilations: A tuple/list of 4 integers, The dilation factor for each dimension * of input. Defaults to [1, 1, 1, 1] * @li data_format: An optional string from: "NCHW", "NHWC". Defaults to "NCHW". Specify the data format of the input x. * @li deformable_groups: Specify the c-axis grouping number of input x. * @li modulated: Specify version of DeformableConv2D, true means v2, false means v1. *@par Outputs: - * grad_x: A Tensor of type float16, float32. Gradients with respect to input_x - * grad_offsets: A Tensor of type float16, float32. Gradients with respect to input_offsets + * @li grad_x: A Tensor of type float16, float32. Gradients with respect to input_x + * @li grad_offsets: A Tensor of type float16, float32. Gradients with respect to input_offsets */ REG_OP(DeformableOffsetsGrad) .INPUT(grad, TensorType({DT_FLOAT16, DT_FLOAT})) @@ -1695,11 +1660,9 @@ REG_OP(DeformableOffsetsGrad) *@brief Computes the deformed dilation output with the expected input *@par Inputs: * One inputs: - * @li x: A Tensor of type int8, float16, float32 -*@par Required Attributes: - * @li dilations: A tuple/list of integers. + * x: A Tensor of type int8, float16, float32 *@par Attributes: - * Two attributes: + * @li dilations: A tuple/list of integers. * @li padding_value: default value filling in blank * @li pads: A tuple/list of integers. 
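* For example (an illustrative reading of these attributes, not normative):
* with dilations = [1, 1, 2, 2] and padding_value = 0.0, each 2x2 feature map
* of "x" would expand to 3x3 by inserting one row and one column of the
* padding value between neighbouring elements.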
*@par Outputs:
diff --git a/third_party/fwkacllib/inc/ops/nn_detect_ops.h b/third_party/fwkacllib/inc/ops/nn_detect_ops.h
index 5fa40ad6..bd14df77 100644
--- a/third_party/fwkacllib/inc/ops/nn_detect_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_detect_ops.h
@@ -153,6 +153,42 @@ REG_OP(Iou)
     .OP_END_FACTORY_REG(Iou)
 
 /**
+*@brief Computes GIoU: first calculate the smallest enclosing area of the two
+* boxes and their IoU, then the proportion of the enclosing area that belongs
+* to neither box, and finally subtract this proportion from the IoU to get
+* GIoU . \n

+*@par Inputs:
+* Two inputs, including:
+*@li bboxes: Bounding boxes, a 2D Tensor of type float16 or float32 with
+* shape (N, 4). "N" indicates the number of bounding boxes, and the value
+* "4" refers to [x1, y1, x2, y2] or [x, y, w, h].
+*@li gtboxes: Ground-truth boxes, a 2D Tensor of type float16 or float32
+* with shape (M, 4). "M" indicates the number of ground truth boxes, and
+* the value "4" refers to [x1, y1, x2, y2] or [x, y, w, h] . \n

+*@par Attributes:
+*@li trans: An optional bool, true for 'xywh', false for 'xyxy'.
+*@li is_cross: An optional bool, controls whether the output shape is [M, N] or [1, N].
+*@li mode: Computation mode, a character string with the value range of [iou, iof] . \n

+*@par Outputs:
+* overlap: A 2D Tensor of type float16 or float32 with shape [M, N] or [1, N],
+* specifying the IoU or IoF ratio . \n

+*@attention Constraints:
+* Only computation of float16 data is supported. To avoid overflow, the input
+* length and width are scaled by 0.2 internally.
+*/
+REG_OP(GIoU)
+    .INPUT(bboxes, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .INPUT(gtboxes, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .OUTPUT(overlap, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .ATTR(trans, Bool, false)
+    .ATTR(is_cross, Bool, true)
+    .ATTR(mode, String, "iou")
+    .OP_END_FACTORY_REG(GIoU)
+
+/**
 *@brief Performs the backpropagation of ROIAlign for training scenarios . \n
 
 *@par Inputs:
@@ -417,7 +453,7 @@ REG_OP(PSROIPooling)
*@brief Returns detection result . \n

*@par Inputs:
-* Four inputs, including:
+* Five inputs, including:
*@li rois: An NCHW tensor of type float16 or float32, output from operator proposal_d at the preceding layer, used as the input of operator FSRDetectionOutput.
*@li bbox_delta: An NCHWC0 tensor of type float16 or float32, specifying the prediction offset, used to update the coordinates [x1, y1, x2, y2] of each ROI.
*@li score: An NCHWC0 tensor of type float16 or float32, specifying the probability of each class. Class 0 is the background class.
@@ -459,7 +495,7 @@ REG_OP(FSRDetectionOutput)
*@brief Returns detection result . \n

*@par Inputs:
-* Four inputs, including:
+* Three inputs, including:
*@li bbox_delta: An ND tensor of type float16 or float32, specifying the box loc predictions, used as the input of operator SSDDetectionOutput.
*@li score: An ND tensor of type float16 or float32, specifying the box confidences data, used as the input of operator SSDDetectionOutput.
*@li anchors: An ND tensor of type float16 or float32, output from operator PriorBoxD, used as the input of operator SSDDetectionOutput.
@@ -474,7 +510,6 @@ REG_OP(FSRDetectionOutput)
*@li code_type: An optional int32, specify the code type. Defaults to 1(only supports 2). The corner is 1, center_size is 2, corner_size is 3
*@li keep_top_k: An optional int32, specify the topk value after nms. Defaults to -1
*@li confidence_threshold: An optional float32, specify the topk filter threshold. Only consider detections with confidence greater than the threshold
Only consider detections with confidence greater than the threshold -*@li kernel_name: An optional string, specifying the operator name. Defaults to "ssd_detection_output". *@par Outputs: *@li out_boxnum: A tensor of type int32, specifying the number of output boxes. *@li y: A tensor of type float16 or float32 with shape [batch,keep_top_k, 8], describing the information of each output box. @@ -989,26 +1024,26 @@ REG_OP(SPP) * feature map . \n *@attention Constraints: -*@li For the feature map input: -(1) If pooled_h = pooled_w = 2, the feature map size must not exceed 50. -(2) If pooled_h = pooled_w = 3, the feature map size must not exceed 60. -(3) If pooled_h = pooled_w = 4, the feature map size must not exceed 70. -(4) If pooled_h = pooled_w = 5, the feature map size must not exceed 70. -(5) If pooled_h = pooled_w = 6, the feature map size must not exceed 80. -(6) If pooled_h = pooled_w = 7, the feature map size must not exceed 80. -(7) If pooled_h = pooled_w = 8, the feature map size must not exceed 80. -(8) If pooled_h = pooled_w = 9, the feature map size must not exceed 70. -(9) If pooled_h = pooled_w = 10, the feature map size must not exceed 70. -(10) If pooled_h = pooled_w = 11, the feature map size must not exceed 70. -(11) If pooled_h = pooled_w = 12, the feature map size must not exceed 70. -(12) If pooled_h = pooled_w = 13, the feature map size must not exceed 70. -(13) If pooled_h = pooled_w = 14, the feature map size must not exceed 70. -(14) If pooled_h = pooled_w = 15, the feature map size must not exceed 70. -(15) If pooled_h = pooled_w = 16, the feature map size must not exceed 70. -(16) If pooled_h = pooled_w = 17, the feature map size must not exceed 50. -(17) If pooled_h = pooled_w = 18, the feature map size must not exceed 40. -(18) If pooled_h = pooled_w = 19, the feature map size must not exceed 40. -(19) If pooled_h = pooled_w = 20, the feature map size must not exceed 40. +* For the feature map input: +*@li If pooled_h = pooled_w = 2, the feature map size must not exceed 50. +*@li If pooled_h = pooled_w = 3, the feature map size must not exceed 60. +*@li If pooled_h = pooled_w = 4, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 5, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 6, the feature map size must not exceed 80. +*@li If pooled_h = pooled_w = 7, the feature map size must not exceed 80. +*@li If pooled_h = pooled_w = 8, the feature map size must not exceed 80. +*@li If pooled_h = pooled_w = 9, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 10, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 11, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 12, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 13, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 14, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 15, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 16, the feature map size must not exceed 70. +*@li If pooled_h = pooled_w = 17, the feature map size must not exceed 50. +*@li If pooled_h = pooled_w = 18, the feature map size must not exceed 40. +*@li If pooled_h = pooled_w = 19, the feature map size must not exceed 40. +*@li If pooled_h = pooled_w = 20, the feature map size must not exceed 40. *@par Third-party framework compatibility * It is a custom operator. It has no corresponding operator in Caffe. 
*/
@@ -1222,9 +1257,7 @@ REG_OP(RpnProposalsD)
 * @li box_filter: bool, mark of box_filter. Defaults to "true"
 * @li core_max_num: int, max number of core. Defaults to "8"
 *@par Outputs:
-* @li sorted_rois: A Tensor. Must be float16. N-D with shape [N, 4].
-* @li sorted_scores: A Tensor. Must be float16. N-D with shape [N, 1].
-* @li sorted_classes: A Tensor. Must be float16. N-D with shape [N, 1].
+*sorted_box: A Tensor. Must be float16. N-D with shape [N, 1].
*/
REG_OP(RpnProposalPostProcessing)
    .INPUT(sorted_proposal, TensorType({DT_FLOAT16}))
@@ -1382,7 +1415,7 @@ REG_OP(BatchMultiClassNonMaxSuppression)
* @li shape_hw: A 1D Tensor of type int32 . \n

* @par Attributes:
-* @li reversed_box: An optional bool, specifying the last two dims is "4,num" or
+* reversed_box: An optional bool, specifying whether the last two dims are "4,num" or
* "num,4", "true" for "4,num", "false" for "num,4". Defaults to "false" . \n

* @par Outputs:
@@ -1429,9 +1462,9 @@ REG_OP(NormalizeBBox)
* @li anchors: A Tensor. Must be int32.
*
*@par Attributes:
-* @li scales: optional, listfloat, .
+* @li scales: optional, listfloat.
* @li decode_clip: optional, float, threshold of decode process.
-* @li reversed_boxes: optional, bool,.
+* @li reversed_boxes: optional, bool.
*
*@par Outputs:
* y: A Tensor. Must have the same type as box_predictions.
@@ -1446,16 +1479,16 @@ REG_OP(DecodeBboxV2)
    .OP_END_FACTORY_REG(DecodeBboxV2)

/**
-*@brief Computes sort function.
+*@brief Sorts the input tensor and returns the sorted values and their indices.
*
*@par Inputs:
*Inputs include:
-* x: A Tensor. Dtype support: flaot16, flaot, int16, int8,
+* x: A Tensor. Dtype support: float16, float, int16, int8,
     uint8, int32, int64.
-*
+
*@par Attributes:
-* @li axis: optional, int.
-* @li descending: optional,bool.
+* @li axis: An optional attribute that indicates the sorting axis.
+* @li descending: An optional attribute that indicates whether to sort in descending order.
*
*@par Outputs:
* @li y1: A Tensor. Must have the same type as x.
@@ -1568,16 +1601,18 @@ deciding when to remove boxes based on score . \n
the last dim representing (batch_id,class_id,index_id) . \n

*@par Attributes:
-*center_point_box:Integer indicate the format of the box data.
+*@li center_point_box: Integer indicating the format of the box data.
The default is 0. 0 - the box data is supplied as [y1, x1, y2, x2]
where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair
of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval [0, 1]) or absolute.Mostly used for TF models.
1 - the box data is supplied as [x_center, y_center, width, height].
Mostly used for Pytorch models. \n
+*@li max_boxes_size: An optional attribute integer representing the real maximum
+*number of boxes to be selected by non max suppression . \n

*@par Outputs:
-*@li selected_indices: A 2-D integer tensor of shape [M] representing the
+*selected_indices: A 2-D integer tensor of shape [M] representing the
selected indices from the boxes tensor, where M <= max_output_size. \n

*@attention Constraints:
@@ -1603,7 +1638,7 @@ REG_OP(NonMaxSuppressionV7)
*@brief Obtains the ROI feature matrix from the feature map list. It is a customized fused operator for mmdetection. \n

*@par Inputs:
-* Three inputs, including:
+* Two inputs, including:
*@li features: A 5HD Tensor list of type float32 or float16.
*@li rois: ROI position. A 2D Tensor of float32 or float16 with shape (N, 5). "N" indicates the number of ROIs,
* the value "5" indicates the indexes of images where the ROIs are located, "x0", "y0", "x1", and "y1".
@@ -1760,7 +1795,7 @@ REG_OP(AnchorResponseFlags)
 * "N" indicates the number of ROIs. \n

*@par Attributes:
-*@li performance_mode: select performance mode, "high_precision" or "high_performance".
+*performance_mode: selects the performance mode, "high_precision" or "high_performance".
* select "high_precision" when input type is float32, the output tensor precision
* will be smaller than 0.0001, select "high_performance" when input type is float32,
* the ops will be best performance, but precision will be only smaller than 0.005.
@@ -1795,12 +1830,12 @@ REG_OP(YoloBoxesEncode)
*@li num_gts: A Tensor. Support int32. real k. shape (1, )

*@par Attributes:
-*@li output_dim: float. IOU threshold for positive bboxes.
-*@li group_size: float. minimum iou for a bbox to be considered as a positive bbox
-*@li spatial_scale: bool. whether to assign all bboxes with the same highest overlap with some gt to that gt.
+*@li pos_iou_thr: float. IOU threshold for positive bboxes.
+*@li min_pos_iou: float. minimum iou for a bbox to be considered as a positive bbox
+*@li gt_max_assign_all: bool. whether to assign all bboxes with the same highest overlap with some gt to that gt.

*@par Outputs:
-*@li assigned_gt_inds_pos: A Tensor. Support float16/float32. shape (n, ).
+* assigned_gt_inds_pos: A Tensor. Support float16/float32. shape (n, ).
*/
REG_OP(GridAssignPositive)
    .INPUT(assigned_gt_inds, TensorType({ DT_FLOAT, DT_FLOAT16 }))
@@ -1816,6 +1851,40 @@ REG_OP(GridAssignPositive)
     .REQUIRED_ATTR(min_pos_iou, Float)
     .REQUIRED_ATTR(gt_max_assign_all, Bool)
     .OP_END_FACTORY_REG(GridAssignPositive)
+
+/**
+*@brief Computes the gradient of GIoU . \n

+*@par Inputs:
+*@li dy: data of grad increment, a 1D Tensor of type float16 or float32 with
+* shape (N,).
+*@li bboxes: Bounding boxes, a 2D Tensor of type float16 or float32 with
+* shape (4, N). "N" indicates the number of bounding boxes, and the value
+* "4" refers to [x1, y1, x2, y2] or [x, y, w, h].
+*@li gtboxes: Ground-truth boxes, a 2D Tensor of type float16 or float32
+* with shape (4, M). "M" indicates the number of ground truth boxes, and
+* the value "4" refers to [x1, y1, x2, y2] or [x, y, w, h] . \n

+*@par Attributes:
+*@li trans: An optional attr, true for 'xywh', false for 'xyxy'; only true is supported now.
+*@li is_cross: An optional attr, if false M equals N; only false is supported now.
+*@li mode: An optional attr, a character string with the value range of ['iou', 'iof'];
+* only 'iou' is supported now. \n

+*@par Outputs:
+*@li dbboxes: A 2D Tensor of type float16 or float32 with shape [4, N].
+*@li dgtboxes: A 2D Tensor of type float16 or float32 with shape [4, M].
+*/
+REG_OP(GIoUGrad)
+    .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .INPUT(bboxes, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .INPUT(gtboxes, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .OUTPUT(dbboxes, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .OUTPUT(dgtboxes, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .ATTR(trans, Bool, false)
+    .ATTR(is_cross, Bool, true)
+    .ATTR(mode, String, "iou")
+    .OP_END_FACTORY_REG(GIoUGrad)
 } // namespace ge
 #endif  // OPS_BUILT_IN_OP_PROTO_INC_NN_DETECT_OPS_H_
diff --git a/third_party/fwkacllib/inc/ops/nn_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_norm_ops.h
index b44c0780..9ce7abfd 100644
--- a/third_party/fwkacllib/inc/ops/nn_norm_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_norm_ops.h
@@ -54,15 +54,16 @@ REG_OP(LogSoftmaxGrad)
*@par Inputs:
*Two inputs, including:
* @li features: A Tensor. Must be one of the following types: half, float32, double.
-* A "batch_size * num_classes" matrix.
+*A "batch_size * num_classes" matrix. * @li labels: A Tensor. Must be one of the following types: 'int32', 'int64'. -* batch_size vector with values in [0, num_classes). -* This is the label for the given minibatch entry. +*batch_size vector with values in [0, num_classes). +*This is the label for the given minibatch entry. \n *@par Outputs: -*loss: A Tensor for per example loss (a "batch_size" vector). Has the same type as "features". -*backprop: A Tensor for the backpropagated gradients (a batch_size * num_classes matrix). Has the same type as "features" . \n +*@li loss: A Tensor for per example loss (a "batch_size" vector). Has the same type as "features". +*@li backprop: A Tensor for the backpropagated gradients (a batch_size * num_classes matrix). +Has the same type as "features" . \n *@par Third-party framework compatibility *Compatible with the TensorFlow operator SparseSoftmaxCrossEntropyWithLogits. @@ -84,8 +85,8 @@ REG_OP(SparseSoftmaxCrossEntropyWithLogits) * @li labels: A Tensor of the same type as "features". A "batch_size * num_classes" matrix . \n *@par Outputs: -*loss: A Tensor for per example loss (a "batch_size" vector). Has the same type as "features". -*backprop: A Tensor for the backpropagated gradients (a batch_size * num_classes matrix). Has the same type as "features" . \n +* @li loss: A Tensor for per example loss (a "batch_size" vector). Has the same type as "features". +* @li backprop: A Tensor for the backpropagated gradients (a batch_size * num_classes matrix). Has the same type as "features" . \n *@par Third-party framework compatibility *Compatible with the TensorFlow operator SoftmaxCrossEntropyWithLogits. @@ -127,12 +128,13 @@ REG_OP(SoftmaxGrad) *@brief Computes the sigmoid cross entropy loss of "predict" and "target" . \n *@par Inputs: -* Two inputs, including: +* Three inputs, including: *@li predict: A multi-dimensional Tensor of type float16 or float32, specifying the predictive value. -*@li target: A multi-dimensional Tensor of type float16 or float32, specifying the target value . \n +*@li target: A multi-dimensional Tensor of type float16 or float32, specifying the target value . +*@li dout:A multi-dimensional Tensor of float16 or float32,specifying the gradient transferred from the upper layer. \n *@par Outputs: -*loss: Sigmoid cross entropy between the predictive value and target value. Has the same dimensions as "predict" . \n +*gradient: Sigmoid cross entropy between the predictive value and target value. Has the same dimensions as "predict" . \n *@par Third-party framework compatibility * Compatible with the scenario where "reduction" is set to "none"of PyTorch operator SigmoidCrossEntropyWithLogitsGrad. @@ -148,13 +150,12 @@ REG_OP(SigmoidCrossEntropyWithLogitsGrad) *@brief Performs the backpropagation of SigmoidCrossEntropyWithLogits for training scenarios . \n *@par Inputs: -* Three inputs, including: +* Two inputs, including: *@li predict: A multi-dimensional Tensor of type float16 or float32, specifying the predictive value. -*@li target: A multi-dimensional Tensor of type float16 or float32, specifying the target value. -*@li dout: A multi-dimensional Tensor of float16 or float32, specifying the gradient transferred from the upper layer . \n +*@li target: A multi-dimensional Tensor of type float16 or float32, specifying the target value. \n *@par Outputs: -*gradient: Return gradient. Has the same dimensions and type as "predict" . \n +*loss: Return loss. Has the same dimensions and type as "predict" . 
\n *@par Third-party framework compatibility * Compatible with the scenario where "reduction" is set to "none"of PyTorch operator SigmoidCrossEntropyWithLogits. @@ -572,7 +573,7 @@ REG_OP(LayerNorm) *@par Inputs: *One input, including: -* @li x: A Tensor. Must be one of the following types: float16, float32 . \n +* x: A Tensor. Must be one of the following types: float16, float32 . \n *@par Attributes: * @li p: Specify L_p norm, the type is float. @@ -581,7 +582,7 @@ REG_OP(LayerNorm) *@par Outputs: *One outputs, including: -* @li y: shape and dtype of output, should be same shape and type as input. +* y: shape and dtype of output, should be same shape and type as input. */ REG_OP(Renorm) .INPUT(x, TensorType::BasicType()) @@ -811,7 +812,7 @@ REG_OP(LayerNormBetaGammaBackpropV2) * shape of "keep_prob" should be (1,) or [1,]. * Has the same type as "x" . \n -*@par Output: +*@par Outputs: *y: A mutable Tensor. Has the same type as "x". */ REG_OP(DropOutDoMask) @@ -839,7 +840,7 @@ REG_OP(DropOutDoMask) * shape of "keep_prob" should be (1,) or [1,]. * Has the same type as "x" . \n -*@par Output: +*@par Outputs: *y: A mutable Tensor. Has the same type as "x". *@par Restrictions: *Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. @@ -1010,7 +1011,7 @@ REG_OP(LRNGrad) *@li grads: A Tensor. Has the same type as acts. *@par Attributes: - *@li blank_label: An optional attribute. Defaults to 0. + *blank_label: An optional attribute. Defaults to 0. *@par Third-party framework compatibility * Compatible with TensorFlow RNNTLoss operator. @@ -1198,13 +1199,11 @@ REG_OP(INInferV2D) * @li epsilon: An attribute of type Float. \n * @par Outputs: -*Three outputs, including: +* Three outputs, including: * @li y: A Tensor. Has the same type as "x". \n * @li mean: A Tensor. Has the same type as "x". \n * @li variance: A Tensor. Has the same type as "x". \n -* @par Third-party framework compatibility -* Can be used by onnx InstanceNormalization */ REG_OP(InstanceNorm) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) @@ -1218,24 +1217,22 @@ REG_OP(InstanceNorm) .OP_END_FACTORY_REG(InstanceNorm) /** -*@brief InstanceNormGrad operator interface implementation. +* @brief InstanceNormGrad operator interface implementation. -*@par Inputs: -*Five inputs, including: +* @par Inputs: +* Five inputs, including: * @li dy: A Tensor. Must be one of the following types: float16, float32. * @li x: A Tensor. Must be one of the following types: float16, float32. * @li variance: A Tensor. Must be one of the following types: float16, float32. * @li mean: A Tensor. Must be one of the following types: float16, float32. * @li gamma: A Tensor. Must be one of the following types: float16, float32 . \n -*@par Outputs: -*Three outputs, including: +* @par Outputs: +* Three outputs, including: * @li pd_x: A Tensor. Must be one of the following types: float16, float32. * @li pd_gamma: A Tensor. Must be one of the following types: float16, float32. * @li pd_beta: A Tensor. Must be one of the following types: float16, float32. -*@par Restrictions: -*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. */ REG_OP(InstanceNormGrad) .INPUT(dy, TensorType({DT_FLOAT, DT_FLOAT16})) @@ -1249,58 +1246,6 @@ REG_OP(InstanceNormGrad) .OP_END_FACTORY_REG(InstanceNormGrad) /** -*@brief InstanceNormXBackprop operator interface implementation. - -*@par Inputs: -*Five inputs, including: -* @li dy: A Tensor. Must be one of the following types: float16, float32. -* @li x: A Tensor. Must be one of the following types: float16, float32. 
-* @li variance: A Tensor. Must be one of the following types: float16, float32.
-* @li mean: A Tensor. Must be one of the following types: float16, float32.
-* @li gamma: A Tensor. Must be one of the following types: float16, float32 . \n
-
-*@par Outputs:
-*Two outputs, including:
-* @li pd_x: A Tensor. Must be one of the following types: float16, float32.
-* @li res_for_gamma: A Tensor. Must be one of the following types: float32.
-
-*@par Restrictions:
-*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
-*/
-REG_OP(InstanceNormXBackprop)
-    .INPUT(dy, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .INPUT(variance, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .INPUT(mean, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .INPUT(gamma, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .OUTPUT(pd_x, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .OUTPUT(res_for_gamma, TensorType({DT_FLOAT}))
-    .OP_END_FACTORY_REG(InstanceNormXBackprop)
-
-/**
-*@brief InstanceNormBetaGammaBackprop operator interface implementation.
-
-*@par Inputs:
-*Two inputs, including:
-* @li dy: A Tensor. Must be one of the following types: float16, float32.
-* @li res_for_gamma: A Tensor. Must be one of the following types: float32.\n
-
-*@par Outputs:
-*Two outputs, including:
-* @li pd_gamma: A Tensor. Must be one of the following types: float16, float32.
-* @li pd_beta: A Tensor. Must be one of the following types: float16, float32.
-
-*@par Restrictions:
-*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
-*/
-REG_OP(InstanceNormBetaGammaBackprop)
-    .INPUT(dy, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .INPUT(res_for_gamma, TensorType({DT_FLOAT}))
-    .OUTPUT(pd_gamma, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .OUTPUT(pd_beta, TensorType({DT_FLOAT, DT_FLOAT16}))
-    .OP_END_FACTORY_REG(InstanceNormBetaGammaBackprop)
-
-/**
* @brief Computes Kl_div_loss_grad or Kl_div_loss_backward. \n

* @par Inputs:
@@ -1340,10 +1285,10 @@ REG_OP(KlDivLossGrad)
* @li label: A Tensor. Has the same type as "grads". Required. \n

* @par Attributes:
-* @li reduction: An optional attribute of type String. Defaults to "mean". \n
+* reduction: An optional attribute of type String. Defaults to "mean". \n

* @par Outputs:
-* @li y: A Tensor. Has the same type as "x". \n
+* y: A Tensor. Has the same type as "x". \n

* @par Third-party framework compatibility
* Compatible with the Pytorch operator L1LossGrad.
@@ -1368,7 +1313,7 @@ REG_OP(L1LossGrad)
* @li reduction: An optional string. Defaults to "mean". \n

* @par Outputs:
-* @li y: An ND tensor tensor with the same shape and type as "predict". \n
+* y: An ND tensor with the same shape and type as "predict". \n

* @par Third-party framework compatibility
* Compatible with the Pytorch operator LpLoss.
@@ -1390,10 +1335,10 @@ REG_OP(LpLoss)
* @li dout: An ND tensor of type float16, float32. \n

* @par Attributes:
-* @li reduction: An optional string.Defaults to "mean". \n
+* reduction: An optional string. Defaults to "mean". \n

* @par Outputs:
-* @li y: An ND tensor tensor with the same shape and type as "predict". \n
+* y: An ND tensor with the same shape and type as "predict". \n

* @par Third-party framework compatibility
* Compatible with the Pytorch operator MseLossGrad.
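* For reference, a standard MSE backward formulation (assumed here, not
* stated in this header): with reduction = "mean" over N elements, each
* element of "y" is 2 * (predict - label) * dout / N; the 1/N factor is
* dropped when reduction = "sum".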
@@ -1414,10 +1359,10 @@ REG_OP(MseLossGrad)
* @li label: An ND Tensor of dtype float16 or float32.\n
*
* @par Attributes:
-* @li reduction:An optional str from sum, none, mean, Defaults to "mean".\n
+* reduction: An optional string from: sum, none, mean. Defaults to "mean".\n
*
* @par Outputs:
-* @li y: when reduction=sum/mean, y is scale. when reduction=none, y has
+* y: When reduction = sum/mean, y is a scalar. When reduction = none, y has
* same type and shape as "predict".\n
*/
REG_OP(MseLoss)
@@ -1445,7 +1390,7 @@ REG_OP(MseLoss)
* Must be one of the following: "none", "mean", "sum". \n

* @par Outputs:
-* @li gradient: A Tensor. Has the same type as "predict". \n
+* gradient: A Tensor. Has the same type as "predict". \n

* @par Third-party framework compatibility
* Compatible with the Pytorch operator SmoothL1LossBackward.
@@ -1480,7 +1425,7 @@ REG_OP(SmoothL1LossGradV2)
* the output,'sum': the output will be summed. Default: 'mean'. \n

* @par Outputs:
-* @li loss: Indicates the loss between the predictive value and target value.
+* loss: Indicates the loss between the predictive value and target value.
* Has the same dimensions as "predict". \n

* @par Third-party framework compatibility
@@ -1498,12 +1443,12 @@ REG_OP(SmoothL1LossV2)
* @brief Computes Centralization. result = x - mean(x, axes)

* @par Inputs:
-* @li x: An ND tensor of type float16, float32.
+* x: An ND tensor of type float16, float32.
* @par Attributes:
-* @li axes: The dimensions to reduce. Must be one of the following types: int, list, tuple, NoneType.
+* axes: The dimensions to reduce. Must be one of the following types: int, list, tuple, NoneType.
* Must be in the range [-rank(x), rank(x)).
* @par Outputs:
-* @li y: A Tensor. Has the same type as "x". \n
+* y: A Tensor. Has the same type as "x". \n

* @par Third-party framework compatibility
* custom operator \n
@@ -1521,7 +1466,7 @@ REG_OP(Centralization)

*@par Inputs:
*One inputs, including:
-* @li x: A tensor . Must be one of the following types:
+* x: A tensor . Must be one of the following types:
* float16, float32, int32, uint32, int8, uint8. \n

*@par Attributes:
@@ -1546,14 +1491,14 @@ REG_OP(Roll)
logistic loss between input_x and input_y (containing 1 or -1). \n

*@par Inputs:
- *One inputs, including:
+ *Two inputs, including:
 * @li input_x: A tensor. Must be one of the following types:
 * float16, float32. \n
 * @li input_y: A tensor. Must be one of the following types:
 * float16, float32. \n

*@par Attributes:
- *@li lambd: An optional string.Defaults to "mean". \n
+ *reduction: An optional string. Defaults to "mean". \n

*@par Outputs:
*output_z: while reduction == "none", A Tensor with the same type and shape of input_x's. \n
@@ -1580,10 +1525,10 @@ REG_OP(SoftMarginLoss)
* @li pos_weight: An optional ND tensor of type float16, float32. \n

* @par Attributes:
-* @li reduction: An optional string.Defaults to "mean". \n
+* reduction: An optional string. Defaults to "mean". \n

* @par Outputs:
-* @li gradient: An ND tensor tensor with the same shape and type as "predict". \n
+* gradient: An ND tensor with the same shape and type as "predict". \n

* @par Third-party framework compatibility
* Compatible with the Pytorch operator SigmoidCrossEntropyWithLogitsGrad.
@@ -1603,24 +1548,14 @@ REG_OP(SigmoidCrossEntropyWithLogitsGradV2)
* @par Inputs:
* Two inputs, including:
- * @li input_x: A tensor. Must be one of the following types:
- * float16, float32. \n
- *
- * @par Inputs:
- * @li target: A tensor. Must be one of the following types:
- * float16, float32. \n
+ * @li input_x: A tensor. Must be one of the following types: float16, float32.
+ * @li target: A tensor. Must be one of the following types: float16, float32. \n

* @par Attributes:
* Four attributes, including:
- * @li log_input: An optional bool. Defaults to "True" \n
- *
- * @par Attributes:
- * @li full: An optional bool. Defaults to "False" \n
- *
- * @par Attributes:
- * @li eps: An optional float. Defaults to "1e-8" \n
- *
- * @par Attributes:
+ * @li log_input: An optional bool. Defaults to "True"
+ * @li full: An optional bool. Defaults to "False"
+ * @li eps: An optional float. Defaults to "1e-8"
 * @li reduction: An optional string. Defaults to "mean" \n

* @par Outputs:
@@ -1641,14 +1576,14 @@ REG_OP(PoissonNllLoss)
/**
*@brief rnn_gen_mask
* @par Inputs:
- * @li seq_length: A ND Tensor of type int32. Recoed the current length of each batch.\n
+ * seq_length: A ND Tensor of type int32. Records the current length of each batch.\n
*
* @par Attributes:
* @li num_step: A required int.\n
* @li hidden_size: A required int. \n
*
*
- * @par Output:
+ * @par Outputs:
* y: A mutable Tensor of type float16, with the shape of [num_step, batch_size, hidden_size]. \n
*
@@ -1666,18 +1601,16 @@ REG_OP(RnnGenMask)
* @par Inputs:
* Two inputs, including:
* @li x: A tensor. Must be one of the following types:
-* float16, float32. \n
-*
-* @par Inputs:
+* float16, float32.
* @li target: A tensor. Must be the following types:
* int32. \n

* @par Attributes:
-* @li reduction: An optional string. Defaults to "mean" \n
+* reduction: An optional string. Defaults to "mean" \n

* @par Outputs:
-* y: A Tensor has same element type as input x. \n
-* is_target: A Tensor has same element type as input target. \n
+* @li y: A Tensor with the same element type as input x. \n
+* @li is_target: A Tensor with the same element type as input target. \n

* @par Third-party framework compatibility
* Compatible with the Pytorch operator MultiLabelMarginLoss. \n
diff --git a/third_party/fwkacllib/inc/ops/nn_ops.h b/third_party/fwkacllib/inc/ops/nn_ops.h
index 49fd02fa..5b1a4dd0 100644
--- a/third_party/fwkacllib/inc/ops/nn_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_ops.h
@@ -106,16 +106,16 @@ REG_OP(FusedBatchNormV2)
     .OP_END_FACTORY_REG(FusedBatchNormV2)

/**
- * @brief: Large amount of data sort.First operator of TopK.
+ * @brief Large amount of data sort. First operator of TopK.
 * @par Inputs:
 * two input, including:
 * @li input_data: A Tensor. Data to be sorted. Support float16
 * @li input_index: A Tensor. Range(0, 2048). Datatype and format is same as input_data.
 * @par Attributes:
- * @li k_num: Int.Number to be sorted.
+ * k_num: Int. Number to be sorted.
 * @par Outputs:
- * 1 output, including:
- * @li output_proposal: A Tensor. Datatype and format is same as input_data. Proposal sorted for each channel.
+ * One output, including:
+ * output_proposal: A Tensor. Datatype and format are the same as input_data. Proposal sorted for each channel.
*/
REG_OP(SegmentSort)
    .INPUT(input_data, TensorType({DT_FLOAT16}))
@@ -127,13 +127,13 @@ REG_OP(SegmentSort)
/**
* @brief Large amount of data sort. Second operator of TopK.
* @par Inputs:
- * two input, including:
- * @li input_proposal: A Tensor. Proposal sorted for each channel. Support float16
+ * One input, including:
+ * input_proposal: A Tensor. Proposal sorted for each channel. Supports float16.
* @par Attributes:
- * @li k_num: Int.Number to be sorted.
+ * k_num: Int. Number to be sorted.
* @par Outputs:
- * 1 output, including:
- * @li output_proposal: A Tensor. Datatype and format are the same as input_data. Proposal sorted for each channel.
+ * One output, including:
+ * output_proposal: A Tensor. Datatype and format are the same as input_data. Proposal sorted for each channel.
*/
REG_OP(MultiMerge)
    .INPUT(input_proposal, TensorType({DT_FLOAT16}))
@@ -142,14 +142,14 @@ REG_OP(MultiMerge)
     .OP_END_FACTORY_REG(MultiMerge)

/**
- * @brief: Large amount of data sort.Third operator of TopK.
+ * @brief Large amount of data sort. Third operator of TopK.
* @par Inputs:
- * two input, including:
- * @li input_proposal: A Tensor. Proposal sorted for each channel. Support float16
+ * One input, including:
+ * input_proposal: A Tensor. Proposal sorted for each channel. Supports float16.
* @par Attributes:
- * @li k_num: Int.Number to be sorted.
+ * k_num: Int. Number to be sorted.
* @par Outputs:
- * 2 output, including:
+ * Two outputs, including:
* @li output_data: A Tensor. Datatype and format is same as input_data. Data sorted.
* @li output_index: A Tensor. int32. Data index.
*/
diff --git a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h
index 80a21333..72363d18 100644
--- a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h
@@ -29,7 +29,7 @@ namespace ge {
/**
*@brief Performs pooling on the input.
*@par Inputs:
-*@li x: An NCHW tensor of type float16, float32, int8.
+* x: An NCHW tensor of type float16, float32, int8.
*@par Attributes:
*@li mode: An optional int32, specifying the pooling algorithm, either "0" (max pooling) or "1" (avg pooling). Defaults to "0".
*@li global_pooling: An optional bool. Defaults to "false".
@@ -50,6 +50,7 @@ namespace ge {
*dilation[2]: An optional int32, specifying the left dilation. Defaults to "1".
*dilation[3]: An optional int32, specifying the right dilation. Defaults to "1".
*@li ceil_mode: An optional int32, either "0" (ceil mode) or "1" (floor mode). Defaults to "0".
+*@li data_format: An optional string, specifying the data format of the input and output data. Defaults to "NCHW".
*@par Outputs:
*y: An NCHW tensor of type float16, float32, int32.
*@attention Constraints:
@@ -204,7 +205,7 @@ REG_OP(AvgPool3D)
*y: The average pooled output tensor . \n

*@attention Constraints:
-*@li "ksize" is in the range [1, 255]. "strides" is in the range [1, 63]
+*"ksize" is in the range [1, 255]. "strides" is in the range [1, 63]

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator AvgPool3D.
@@ -281,10 +282,10 @@ REG_OP(AvgPool3DGrad)
* @li data_format: A string, format of input data . \n

* @par Outputs:
-* @output: The average pooled output tensor . \n
+* output: The average pooled output tensor . \n

* @attention Constraints:
-* @li "ksize" is in the range [1, 255]. "strides" is in the range [1, 63]
+* "ksize" is in the range [1, 255]. "strides" is in the range [1, 63]

* @par Third-party framework compatibility
* Compatible with the TensorFlow operator AvgPool3DGradD.
@@ -431,6 +432,47 @@ REG_OP(MaxPool3D)
     .OP_END_FACTORY_REG(MaxPool3D)

/**
+* @brief Performs max pooling 3d, producing both max values and indices.
+*
+* @par Inputs:
+* One input:
+* x: A 6D tensor. Supported type: float16. Format as NDC1HWC0.
+* @par Attributes:
+* @li ksize: A required list of int32 values,
+* specifying the size of the window for each dimension of the input tensor.
+* No default value.
+* @li strides: A required list of int32 values,
+* specifying the stride of the sliding window for each dimension of
+* the input tensor. No default value.
+* @li pads: A required 3*2-dimension list of int32 values,
+* specifying the pad of three dimensions of input, implemented with 0.
+* @li dilation: Dilation of the kernel. Default value is {1,1,1,1,1}.
+* @li ceil_mode: Default value is false.
+* @li data_format: The format of the torch input. Default value is "NCDHW".
+* @li argmax_type: Determines the type of the output argmax.
+* "bitmask" is the default value, in which case argmax returns
+* an img2col bitmask; "index_int32" and "index_int64" represent the torch
+* output indices.
+* @par Outputs:
+* @li y: A 6D tensor, the max pool3d output (max value), format as NDoC1HoWoC0.
+* @li argmax: A 5D uint16 tensor, the indices output,
+* format as NC1HWC0; it actually represents N, Do, C1*ksize, Ho*Wo//16, 16.
+*/
+REG_OP(MaxPool3DWithArgmax)
+    .INPUT(x, TensorType::RealNumberType())
+    .OUTPUT(y, TensorType::RealNumberType())
+    .OUTPUT(argmax, TensorType::IndexNumberType())
+    .REQUIRED_ATTR(ksize, ListInt)
+    .REQUIRED_ATTR(strides, ListInt)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(dilation, ListInt, {1, 1, 1, 1, 1})
+    .ATTR(ceil_mode, Bool, false)
+    .ATTR(data_format, String, "NCDHW")
+    .ATTR(argmax_type, String, "bitmask")
+    .OP_END_FACTORY_REG(MaxPool3DWithArgmax)
+
+/**
*@brief Applies a 2D adaptive max pooling over an input signal composed of several input planes. \n

 The output is of size H x W, for any input size.
@@ -522,8 +564,7 @@ REG_OP(MaxPool3DGradGrad)
* y: A mutable tensor. Has the same shape and type as "x1" . \n

* @attention Constraints:
-* @li Computing gradients of global pooling is not supported, which means
-* "ksize < x1".
+* @li ksize is limited by buffer with full tiling.
* @li "ksize" is in the range [1, 255]. "strides" is in the range [1, 63]

* @par Third-party framework compatibility
@@ -568,7 +609,7 @@ REG_OP(MaxPoolGrad)
* @li Other dimensions of ksize and strides is 1 . \n

* @par Outputs:
-* @li y: Has the same type and format as input "x1" . \n
+* y: Has the same type and format as input "x1" . \n

* @par Third-party framework compatibility
* @li Compatible with the TensorFlow operator MaxPoolGradGrad.
@@ -588,7 +629,7 @@ REG_OP(MaxPoolGradGrad)
*@brief Performs max_pool_ext2 on the input . \n

*@par Inputs:
-* Two inputs:
+* Three inputs:
*@li x: An NC1HWC0 Tensor of type float16.
*@li strides: A required type of int32 values, specifying the stride of the sliding window for each dimension of the input tensor. No default value.
*@li ksize: A required type of int32 values, specifying the size of the window for each dimension of the input tensor. No default value.
@@ -635,7 +676,8 @@ REG_OP(MaxPoolV2)
*@li strides: A required list of int8, int16, int32, or int64 values,
* specifying the stride of the sliding window for each dimension of
* the input tensor. No default value.
-*@li padding: A required string. No default value . \n
+*@li padding: A required string. No default value .
+*@li Targmax: An optional int with default value 7 . \n

*@par Outputs:
*@li y: A Tensor. Has the same type and format as input "x".
@@ -645,7 +687,7 @@ REG_OP(MaxPoolV2)
* ksize[1] * ksize[2] <= 255.
*@li "stride is a list that has length 4: strides[0] = 1 or strides[3] = 1,
* strides[1] <= 63, strides[0] >= 1, strides[2] <= 63, strides[2] >= 1.
-*@li "padding" is either "SAME" or "VALID" . \n
+*@li "padding" is either "SAME" or "VALID" .

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MaxPoolWithArgmax.
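* A minimal usage sketch for the MaxPool3DWithArgmax op registered above
* (illustrative only; the attribute values are hypothetical and the generated
* ge::op wrapper class is assumed):
*     auto pool = ge::op::MaxPool3DWithArgmax("max_pool3d")
*                     .set_input_x(x)
*                     .set_attr_ksize({1, 1, 2, 2, 2})
*                     .set_attr_strides({1, 1, 2, 2, 2})
*                     .set_attr_pads({0, 0, 0, 0, 0, 0})
*                     .set_attr_argmax_type("bitmask");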
@@ -710,14 +752,15 @@ REG_OP(MaxPoolGradWithArgmax)
 *@brief Performs transform mask to argmax . \n
 
 *@par Inputs:
-* Two input:
-*x: An NC1HWC0 Tensor of type float16.
-*mask: An NC1HWC0 Tensor of type uint16 . \n
+* Two inputs:
+*@li x: An NC1HWC0 Tensor of type float16.
+*@li mask: An NC1HWC0 Tensor of type uint16 . \n
 
 *@par Attributes:
 *@li ksize: A required list of int8, int16, int32, or int64 values, specifying the size of the window for each dimension of the input tensor. No default value.
 *@li strides: A required list of int8, int16, int32, or int64 values, specifying the stride of the sliding window for each dimension of the input tensor. No default value.
-*@li padding: A required string. No default value . \n
+*@li padding: A required string. No default value .
+*@li originshape: A required list of int8, int16, int32, or int64 values. No default value. \n
 
 *@par Outputs:
 *argmax: An NC1HWC0 Tensor of type int32 . \n
@@ -754,7 +797,7 @@ REG_OP(Mask2Argmax)
 * @li strides: A required list, specifying the stride of the sliding window.
 * @li padding: A required string, window sliding mode. Either SAME or VALID.
 * @par Outputs:
-* @li y:Result tensor. Supported type: float, double, int32,
+* y: Result tensor. Supported type: float, double, int32,
 *       uint8, int16, int8, int64, uint16, half, uint32, uint64
 
 * @attention Constraints:
@@ -767,7 +810,7 @@ REG_OP(Mask2Argmax)
 *     (shape_max_pool[2] * shape_max_pool[3] + 31) // 16, 16), else failed . \n
 
 * @par Third-party framework compatibility
-* @li Compatible with the TensorFlow operator MaxPoolGradGradWithArgmax.
+* Compatible with the TensorFlow operator MaxPoolGradGradWithArgmax.
 */
 REG_OP(MaxPoolGradGradWithArgmax)
     .INPUT(x, TensorType::RealNumberType())
@@ -931,11 +974,11 @@ REG_OP(AvgPoolV2GradD)
     .OP_END_FACTORY_REG(AvgPoolV2GradD)
 
 /**
-*@brief :upsample the layer
+*@brief Upsamples the layer, similar to the nearest-neighbor interpolation algorithm.
 
 *@par Inputs:
 * one input, including:
-*@li x: A tensor of type float16 or float32.
+* x: A tensor of type float16 or float32.
 *@par Attributes:
 *@li scale: A optional float32, scale factor of x. Defaults to "1.0".
 *@li stride_h: An optional int32, broadcast the axis of h. Defaults to "2".
@@ -1419,7 +1462,7 @@ REG_OP(MaxPoolV3)
 * the floor function will be used. Default False \n
 
 * @par Outputs:
-* y: A mutable tensor. Has the same shape and type as "x1" . \n
+* out_grad: A mutable tensor. Has the same shape and type as "x1" . \n
 
 * @attention Constraints:
 * @li Computing gradients of global pooling is not supported, which means
@@ -1447,8 +1490,8 @@ REG_OP(MaxPoolV3Grad)
 *@brief Performs Dilation2D on the input . \n
 
 *@par Inputs:
-*x: A tensor of shape is 4d, format is support NHWC.
-*filter: A tensor of shape is 3d, the type is same with x, and the c dimension is same with x. \n
+*@li x: A 4D tensor. The supported format is NHWC.
+*@li filter: A 3D tensor. Has the same type as "x", and its C dimension is the same as that of "x". \n
 
 *@par Attributes:
 *@li strides: A required list of 4 ints, specifying the stride of the sliding window. The strides of the N and C dimensions are 1.
@@ -1480,9 +1523,9 @@ REG_OP(Dilation2D)
 *@brief Performs Dilation2DBackpropFilter on the input. \n
 
 *@par Inputs:
-*x: A tensor of shape is 4d, format is support NHWC.
-*filter: A tensor of shape is 3d, the type is same with x, and the c dimension is same with x.
-*out_backprop: Has the same type and format as input x and the c dimension is same with x. 
\n
+*@li x: A 4D tensor. The supported format is NHWC.
+*@li filter: A 3D tensor. Has the same type as "x", and its C dimension is the same as that of "x".
+*@li out_backprop: Has the same type and format as input "x", and its C dimension is the same as that of "x". \n
 
 *@par Attributes
 *@li strides: A required list of 4 ints, specifying the stride of the sliding window. The strides of the N and C dimension are 1.
@@ -1519,9 +1562,9 @@ REG_OP(Dilation2DBackpropFilter)
 *@brief Performs Dilation2DBackpropInput on the input. \n
 
 *@par Inputs:
-*x: A tensor of shape is 4d, format is support NHWC.
-*filter: A tensor of shape is 3d, the type is same with x, and the c dimension is same with x.
-*out_backprop: Has the same type and format as input x and the c dimension is same with x. \n
+*@li x: A 4D tensor. The supported format is NHWC.
+*@li filter: A 3D tensor. Has the same type as "x", and its C dimension is the same as that of "x".
+*@li out_backprop: Has the same type and format as input "x", and its C dimension is the same as that of "x". \n
 
 *@par Attributes
 *@li strides: A required list of 4 ints, specifying the stride of the sliding window. The strides of the N and C dimension are 1.
diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h
index 75e91aee..9dd502cd 100644
--- a/third_party/fwkacllib/inc/ops/nn_training_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h
@@ -289,7 +289,8 @@ REG_OP(SparseApplyAdagradV2D)
 *     Should be from a Variable().
 *@li lr: A scalar. Has the same type as "var".
 *@li grad: A tensor for the gradient. Has the same type as "var".
-*
+*@li momentum: Momentum. Must be a scalar.
+
 *@par Attributes:
 *@li use_nesterov: An optional bool. Defaults to "False".
 *     If "True", the tensor passed to compute grad will be
@@ -701,7 +702,7 @@ REG_OP(ApplyPowerSignD)
 /**
 *@brief Updates "var" as FOBOS algorithm with fixed learning rate.
 *    prox_v = var - alpha * delta
-*    var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
+*    var = sign(prox_v)/(1+alpha * l2) * max{|prox_v|-alpha * l1,0}
 *
 *@attention Constraints:
 *  the input tensors must have the same shape.
@@ -2128,10 +2129,12 @@ REG_OP(FusedMulApplyMomentumExtern)
 *     otherwise the behavior is undefined, but may exhibit less contention.
 *
 *@par Outputs:
-* var: A mutable tensor. Has the same type as input "var".
+* @li var: A mutable tensor. Has the same type as input "var".
+* @li accum: A mutable tensor. Has the same type as input "accum".
 *
 *@attention Constraints:
 * The input tensors must have the same shape.
 *
 *@par Third-party framework compatibility
 * Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
diff --git a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h
index ca1c24eb..01ff77cb 100644
--- a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h
+++ b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h
@@ -28,8 +28,8 @@ namespace ge {
 *@brief Computes the for the gelu of "x" . \n
 
 *@par Inputs:
-*Two inputs, including:
-* @li x: A Tensor. Must be one of the following types: float16, float32
+*One input, including:
+*x: A Tensor. Must be one of the following types: float16, float32
 
 *@par Outputs:
 *y: A Tensor. Has the same type as "x".
@@ -66,8 +66,8 @@ REG_OP(GeluGrad)
 *@brief Computes the for the fast_gelu of "x" . 
\n
 
 *@par Inputs:
-*Two inputs, including:
-* @li x: A Tensor. Must be one of the following types: float16, float32
+*One input, including:
+*x: A Tensor. Must be one of the following types: float16, float32
 
 *@par Outputs:
 *y: A Tensor. Has the same type as "x".
@@ -83,7 +83,7 @@ REG_OP(FastGelu)
 *@brief Computes the gradient for the fast_gelu of "x" . \n
 
 *@par Inputs:
-*Three inputs, including:
+*Two inputs, including:
 * @li dy: A Tensor. Must be one of the following types: float16, float32
 * @li x: A Tensor of the same type as "dy" . \n
 
@@ -169,7 +169,7 @@ REG_OP(Relu)
 * x: A Tensor of type RealNumberType . \n
 
 * @par Outputs:
-* y: A Tensor of type RealNumberType . \n
+* y: A Tensor with the same type as x . \n
 
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator Relu6.
@@ -209,8 +209,12 @@ REG_OP(Relu6D)
 *     backprops = gradients * (features > 0) * (features < 6) . \n
 
 * @par Inputs:
-* @li features: A Tensor of type RealNumberType.
-* @li gradients: A Tensor of type RealNumberType . \n
+* @li gradients: A Tensor of type RealNumberType. The backpropagated
+    gradients to the corresponding Relu6 operation.
+* @li features: A Tensor with the same type as gradients. The features passed
+    as input to the corresponding Relu6 operation, or its output;
+    using either one produces the same result. \n
+
 * @par Outputs:
 * backprops: A Tensor of type RealNumberType . \n
@@ -228,7 +232,7 @@ REG_OP(Relu6Grad)
 *Applies the element-wise function:
 *    Computes the backward for the elu: if x>0, 1; otherwise elu() + alpha .
 *@par Inputs:
-*One inputs, including:
+*Two inputs, including:
 * @li grads: A tensor. Must be one of the following types:
 *     float16, float32.
 * @li activations: A tensor. Must be one of the following types:
@@ -238,7 +242,7 @@ REG_OP(Relu6Grad)
 *y: A Tensor with the same type and shape of grads's.
 *
 *@par Attributes:
-*@li alpha: scalar parameter, default value = 1.0
+*alpha: scalar parameter, default value = 1.0
 */
 REG_OP(EluGradV2)
     .INPUT(grads, TensorType({DT_FLOAT, DT_FLOAT16}))
@@ -539,13 +543,9 @@ REG_OP(Elu)
 *x: A float16, float32, for the input data type . \n
 
 *@par Attributes:
-*alpha1: A float32. Defines at which negative value the ELU saturates. Defaults to "1.0" . \n
-
-*@par Attributes:
-*alpha2: A float32. Defines at which negative value the ELU saturates. Defaults to "1.0" . \n
-
-*@par Attributes:
-*alpha3: A float32. Defines at which positive value the ELU saturates. Defaults to "1.0" . \n
+*@li alpha1: A float32. Defines at which negative value the ELU saturates. Defaults to "1.0" .
+*@li alpha2: A float32. Defines at which negative value the ELU saturates. Defaults to "1.0" .
+*@li alpha3: A float32. Defines at which positive value the ELU saturates. Defaults to "1.0" . \n
 
 *@par Outputs:
 *y: A float16, float32, for the normalized result . \n
@@ -706,8 +706,8 @@ REG_OP(Mish)
 * @li x: A Tensor. Must be one of the following types: float16, float32
 * @li tanhx: A Tensor. shape, datatype and format is same as x
 * @par Outputs:
- * 1 output, including:
- * @li x_grad: A Tensor. shape, datatype and format is same as x
+ * One output, including:
+ * x_grad: A Tensor. shape, datatype and format is same as x
 */
 
 REG_OP(MishGrad)
@@ -721,20 +721,20 @@ REG_OP(MishGrad)
 * @brief pytorch hardtanh_backward operator.
 *
 * @par Inputs:
- * 2 inputs, including:
+ * Two inputs, including:
 * @li result, minimum tensor of the linear region range,
 * datatype: float16/float32, format:ND/5HD. 
* @li grad, maximum tensor of the linear region range,
 * datatype:float16/float32, format:ND/5HD. \n
 
 * @par Attributes:
- * 2 attributes, including:
+ * Two attributes, including:
 * @li min_val, minimum value of the linear region range, datatype:float.
 * @li max_val, maximum value of the linear region range, datatype:float. \n
 
 * @par Outputs:
- * 1 output, including:
- * @li y, hardtanh_backward output tensor, datatype and format is same as
+ * One output, including:
+ * y, hardtanh_backward output tensor, datatype and format is same as
 * input result. \n
 
 * @attention Constraints:
@@ -756,7 +756,7 @@ REG_OP(HardtanhGrad)
 
 * @par Inputs:
 * One inputs, including:
-* @li x: A mutable Tensor. Must be one of the following types:
+* x: A mutable Tensor. Must be one of the following types:
 *     float16, float32. \n
 
 * @par Attributes:
@@ -765,7 +765,7 @@ REG_OP(HardtanhGrad)
 * @li threshold: An optional float. Defaults to "20.0" \n
 
 * @par Outputs:
-* @li y: A mutable Tensor. Has the same type as "x" \n
+* y: A mutable Tensor. Has the same type as "x" \n
 
 * @par Third-party framework compatibility
 * Compatible with the Pytorch operator Softplus.
@@ -792,7 +792,7 @@ REG_OP(SoftplusV2)
 * @li threshold: An optional float. Defaults to "20.0" \n
 
 * @par Outputs:
-* @li output_backprops: A mutable Tensor. Has the same type as "input_gradients" \n
+* output_backprops: A mutable Tensor. Has the same type as "input_gradients" \n
 
 * @par Third-party framework compatibility
 * Compatible with the Pytorch operator SoftplusGrad.
@@ -809,13 +809,16 @@ REG_OP(SoftplusV2Grad)
 * @brief ThresholdedRelu takes one input data (Tensor) and produces one output data (Tensor)
 * where the rectified linear function, y = x for x > alpha, y = 0 otherwise, is applied to the tensor elementwise.
 *
- * @par inputs
+ * @par Inputs:
 * one input including:
- * @li x: input A Tensor. Must be one of the following types: float32, float16
+ * x: input A Tensor. Must be one of the following types: float32, float16
 *
- * @par output
+ * @par Attributes:
+ * alpha: An optional float. Defaults to 1.0. \n
+ *
+ * @par Outputs:
 * one output including:
- * @li y:A Tensor of the same type as x
+ * y: A Tensor of the same type as x
 *
 */
 REG_OP(ThresholdedRelu)
@@ -829,14 +832,14 @@ REG_OP(ThresholdedRelu)
 
 * @par Inputs:
 * One inputs, including:
-* @li input_x: A tensor. 
Must be one of the following types:
+* input_x: A tensor. Must be one of the following types:
 *     float16, float32. \n
 
 * @par Attributes:
-* @li lambd: An optional float. Defaults to 0.5. \n
+* lambd: An optional float. Defaults to 0.5. \n
 
 * @par Outputs:
-* y: A Tensor with the same dtype and shape of input_x's. \n
+* output_y: A Tensor with the same dtype and shape as input_x. \n
 
 * @par Third-party framework compatibility
 * Compatible with the Pytorch operator Hardshrink. \n
@@ -863,7 +866,7 @@ REG_OP(HardShrink)
 *backprops: A Tensor with the same type and shape of features's. \n
 *
 *@par Attributes:
-*@li lambd: An optional float.Defaults to 0.5. \n
+*lambd: An optional float. Defaults to 0.5. \n
 *
 *@par Third-party framework compatibility
 *Compatible with the Pytorch operator Hardshrink_backward. \n
@@ -880,7 +883,7 @@ REG_OP(HardShrink)
 
 * @par Inputs:
 * One inputs, including:
-* @li input_x: A tensor. Must be one of the following types:
+* input_x: A tensor. Must be one of the following types:
 *     float16, float32, int32. \n
 
 * @par Attributes:
@@ -905,11 +908,11 @@ REG_OP(HardSigmoid)
 
 * @par Inputs:
 * One inputs, including:
-* @li input_x: A tensor. Must be one of the following types:
+* input_x: A tensor. Must be one of the following types:
 *     float16, float32. \n
 
 * @par Attributes:
-* @li lambd: An optional float. Defaults to 0.5. \n
+* lambd: An optional float. Defaults to 0.5. \n
 
 * @par Outputs:
 * y: A Tensor with the same dtype and shape of input_x's. \n
@@ -933,7 +936,7 @@ REG_OP(SoftShrink)
 * @li input_x: A tensor of the same dtype as "input_grad". \n
 
 * @par Attributes:
-* @li lambd: An optional float. Defaults to 0.5. \n
+* lambd: An optional float. Defaults to 0.5. \n
 
 * @par Outputs:
 * y: A Tensor of the same dtype and shape as "input_graxd". \n
@@ -976,12 +979,12 @@ REG_OP(LogSigmoidGrad)
 
 *@par Inputs:
 *One inputs, including:
-* @li x: A tensor. Must be one of the following types:
+* x: A tensor. Must be one of the following types:
 *     float16, float32. \n
 
 *@par Outputs:
 *One outputs, including:
-* @li y: A tensor with the same type and shape of x's. \n
+* y: A tensor with the same type and shape of x's. \n
 
 *@par Third-party framework compatibility
 *Compatible with the Pytorch operator LogSigmoid. \n
@@ -1003,7 +1006,7 @@ REG_OP(LogSigmoid)
 
 *@par Outputs:
 *One outputs, including:
-* @li y: A tensor with the same type and shape of x's. \n
+* y: A tensor with the same type and shape of x's. \n
 
 * @par Attributes:
 * @li alpha: An optional float. Defaults to 0.16666666. \n
diff --git a/third_party/fwkacllib/inc/ops/pad_ops.h b/third_party/fwkacllib/inc/ops/pad_ops.h
index 6854c866..9d0e7a62 100644
--- a/third_party/fwkacllib/inc/ops/pad_ops.h
+++ b/third_party/fwkacllib/inc/ops/pad_ops.h
@@ -33,8 +33,8 @@ namespace ge {
 *@li value: A 0D scalar. Specifies the value to fill the returned tensor.
 *     Must be one of the following types:
-*     float16, float32, double, int32, uint8, int16, int8, complex64, int64,
-*     qint8, quint8, qint32, uint16, complex128, uint32, uint64.
+*     float16, float32, double, int32, uint8, int16, int8, complex64, int64, bool,
+*     qint8, quint8, qint32, qint16, quint16, uint16, complex128, uint32, uint64.
 *
 *@par Outputs:
 * y: A tensor. Has the same type as "value".
@@ -46,8 +46,14 @@ namespace ge {
 */
 REG_OP(Fill)
     .INPUT(dims, TensorType::IndexNumberType())
-    .INPUT(value, TensorType::BasicType())
-    .OUTPUT(y, TensorType::BasicType())
+    .INPUT(value, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
+                              DT_INT8, DT_COMPLEX64, DT_INT64, DT_BOOL, DT_QINT8,
+                              DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, DT_UINT16,
+                              DT_COMPLEX128, DT_FLOAT16, DT_UINT32, DT_UINT64}))
+    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
+                           DT_INT8, DT_COMPLEX64, DT_INT64, DT_BOOL, DT_QINT8,
+                           DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, DT_UINT16,
+                           DT_COMPLEX128, DT_FLOAT16, DT_UINT32, DT_UINT64}))
     .OP_END_FACTORY_REG(Fill)
 
 /**
@@ -213,11 +219,11 @@ REG_OP(PadV2)
 *@brief Pads a tensor . \n
 
 *@par Inputs:
-*x: A Tensor. Must be one of the following types: float16, float32, int32 . \n
-*constant_values: A Tensor. Must have the same type as input.
+*@li x: A Tensor. Must be one of the following types: float16, float32, int32 . \n
+*@li constant_values: A Tensor. Must have the same type as input.
 
 *@par Attributes:
-*paddings: An optional "vector>". Defaults to "{}".
+*paddings: A required attribute.
 *     For each dimension D of input, paddings[D, 0] indicates how many
 *     values to add before the contents of tensor in that dimension,
 *     and paddings[D, 1] indicates how many values to add after the
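A usage sketch for the widened Fill registration above, which now accepts a bool "value" (illustrative only, not part of the patch; it assumes the generated ge::op wrappers, and the constant-tensor setup is elided):

    // Sketch only: Fill with dims = {2, 3} and a bool scalar value, which the
    // widened TensorType list above now permits.
    #include "graph/graph.h"
    #include "all_ops.h"  // assumed generated wrappers

    ge::Graph BuildFillGraph(const ge::op::Const &dims,    // int32 [2]: {2, 3}
                             const ge::op::Const &value) { // bool scalar
      auto fill = ge::op::Fill("fill")
                      .set_input_dims(dims)
                      .set_input_value(value);

      ge::Graph graph("fill_graph");
      graph.SetInputs({dims, value}).SetOutputs({fill});
      return graph;
    }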
@@ -461,7 +467,7 @@ REG_OP(FillV2)
 * @li dims: An required listInt to specify the shape that the value to fill. 
 
 * @par Outputs:
-* @li y: A Tensor. Has the shape specify by attr shape, and full of the value specify by attr value.
+* y: A Tensor. Has the shape specified by attr shape, and is full of the value specified by attr value.
 
 * @par Third-party framework compatibility
 * Compatible with the ONNX operator ConstantOfShape.
diff --git a/third_party/fwkacllib/inc/ops/parsing_ops.h b/third_party/fwkacllib/inc/ops/parsing_ops.h
index b625180a..e578997c 100644
--- a/third_party/fwkacllib/inc/ops/parsing_ops.h
+++ b/third_party/fwkacllib/inc/ops/parsing_ops.h
@@ -54,27 +54,26 @@ REG_OP(StringToNumber)
 /**
 *@brief Convert serialized tensorflow.TensorProto prototype to Tensor.
 *@brief Parse an Example prototype.
-*@par Input:
-*serialized: A Tensor of type string.
-*dense_defaults: DYNAMIC INPUT Tensor type as string, float, int64. \n
+*@par Inputs:
+*@li serialized: A Tensor of type string.
+*@li dense_defaults: DYNAMIC INPUT Tensor type as string, float, int64. \n
 
 *@par Attributes:
-*num_sparse: type int num of inputs sparse_indices , sparse_values, sparse_shapes
-*out_type: output type
-*sparse_keys: ListString
-*sparse_types: types of sparse_values
-*dense_keys: ListString
-*dense_shapes: output of dense_defaults shape
-*dense_types: output of dense_defaults type \n
+*@li num_sparse: type int num of inputs sparse_indices , sparse_values, sparse_shapes
+*@li sparse_keys: ListString
+*@li sparse_types: types of sparse_values
+*@li dense_keys: ListString
+*@li Tdense: output of dense_defaults type
+*@li dense_shapes: output of dense_defaults shape \n
 
 *@par Outputs:
-*sparse_indices: A Tensor of type string.
-*sparse_values: Has the same type as sparse_types.
-*sparse_shapes: A Tensor of type int64
-*dense_values: Has the same type as dense_defaults.
+*@li sparse_indices: A Tensor of type string.
+*@li sparse_values: Has the same type as sparse_types.
+*@li sparse_shapes: A Tensor of type int64
+*@li dense_values: Has the same type as dense_defaults.
 
 *Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
-**/
+*/
 REG_OP(ParseSingleExample)
     .INPUT(serialized, TensorType({DT_STRING}))
     .DYNAMIC_INPUT(dense_defaults, TensorType({DT_STRING,DT_FLOAT,DT_INT64}))
@@ -92,16 +91,16 @@ REG_OP(ParseSingleExample)
 /**
 *@brief Decodes raw file into tensor . \n
 
-*@par Input:
+*@par Inputs:
 *bytes: A Tensor of type string.
 
 *@par Attributes:
-*little_endian: bool ture
-*out_type: output type
+*@li little_endian: bool true
+*@li out_type: output type
 
 *@par Outputs:
 *Output: A Tensor
-**/
+*/
 REG_OP(DecodeRaw)
     .INPUT(bytes, TensorType({DT_STRING}))
     .OUTPUT(output, TensorType({DT_BOOL,DT_FLOAT16,DT_DOUBLE,DT_FLOAT,
@@ -147,18 +146,20 @@ REG_OP(ParseTensor)
 
 *@par Inputs:
 *Inputs include:
-*records: Each string is a record/row in the csv and all records should have the
+*@li records: Each string is a record/row in the csv and all records should have the
 *same format. \n
-*record_defaults: One tensor per column of the input record, with either a
+*@li record_defaults: One tensor per column of the input record, with either a
 *scalar default value for that column or an empty vector if the column is
 *required. \n
 
 *@par Attributes:
-*OUT_TYPE: The numeric type to interpret each string in string_tensor as . \n
-*field_delim: char delimiter to separate fields in a record. \n
-*use_quote_delim: If false, treats double quotation marks as regular characters
 *inside of the string fields (ignoring RFC 4180, Section 2, Bullet 5). \n
-*na_value: Additional string to recognize as NA/NaN. \n
+*@li OUT_TYPE: The numeric type to interpret each string in string_tensor as . \n
+*@li field_delim: char delimiter to separate fields in a record. 
\n
+*@li use_quote_delim: If false, treats double quotation marks as regular characters
+*inside of the string fields (ignoring RFC 4180, Section 2, Bullet 5). \n
+*@li na_value: Additional string to recognize as NA/NaN. \n
+*@li select_cols: Optional sorted list of column indices to select. If specified,
+*only this subset of columns will be parsed and returned.
 
 *@par Outputs:
 *output: A Tensor. Has the same type as x . \n
@@ -186,25 +187,25 @@ REG_OP(DecodeCSV)
 /**
 *@brief Convert serialized tensorflow.TensorProto prototype to Tensor.
 *@brief Parse an Example prototype.
-*@par Input:
-*serialized: A Tensor of type string. \n
-*name:A Tensor of type string. \n
-*sparse_keys: Dynamic input tensor of string. \n
-*dense_keys: Dynamic input tensor of string \n
-*dense_defaults: Dynamic input tensor type as string, float, int64. \n
+*@par Inputs:
+*@li serialized: A Tensor of type string. \n
+*@li name: A Tensor of type string. \n
+*@li sparse_keys: Dynamic input tensor of string. \n
+*@li dense_keys: Dynamic input tensor of string \n
+*@li dense_defaults: Dynamic input tensor type as string, float, int64. \n
 
 *@par Attributes:
-*Nsparse: Number of sparse_keys, sparse_indices and sparse_shapes \n
-*Ndense: Number of dense_keys \n
-*sparse_types: types of sparse_values \n
-*Tdense: Type of dense_defaults dense_defaults and dense_values \n
-*dense_shapes: output of dense_defaults shape \n
+*@li Nsparse: Number of sparse_keys, sparse_indices and sparse_shapes \n
+*@li Ndense: Number of dense_keys \n
+*@li sparse_types: types of sparse_values \n
+*@li Tdense: Type of dense_defaults dense_defaults and dense_values \n
+*@li dense_shapes: output of dense_defaults shape \n
 
 *@par Outputs:
-*sparse_indices: A Tensor of type string. \n
-*sparse_values: Has the same type as sparse_types. \n
-*sparse_shapes: A Tensor of type int64 \n
-*dense_values: Has the same type as dense_defaults. \n
+*@li sparse_indices: A Tensor of type string. \n
+*@li sparse_values: Has the same type as sparse_types. \n
+*@li sparse_shapes: A Tensor of type int64 \n
+*@li dense_values: Has the same type as dense_defaults. \n
 
 *@par Third-party framework compatibility \n
 *@li compatible with tensorflow StringToNumber operator. \n
 */
@@ -228,37 +229,37 @@ REG_OP(ParseExample)
 /**
 *@brief Transforms a scalar brain.SequenceExample proto (as strings) into typed
 *tensors.
-*@par Input:
-*serialized: A Tensor of type string. \n
-*feature_list_dense_missing_assumed_empty:A Tensor of type string. \n
-*context_sparse_keys: Dynamic input tensor of string. \n
-*context_dense_keys: Dynamic input tensor of string \n
-*feature_list_sparse_keys: Dynamic input tensor of string \n
-*feature_list_dense_keys: Dynamic input tensor of string \n
-*context_dense_defaults: Dynamic input tensor of string, float, int64 \n
-*debug_name: A Tensor of type string. \n
+*@par Inputs:
+*@li serialized: A Tensor of type string. \n
+*@li feature_list_dense_missing_assumed_empty: A Tensor of type string. \n
+*@li context_sparse_keys: Dynamic input tensor of string. \n
+*@li context_dense_keys: Dynamic input tensor of string \n
+*@li feature_list_sparse_keys: Dynamic input tensor of string \n
+*@li feature_list_dense_keys: Dynamic input tensor of string \n
+*@li context_dense_defaults: Dynamic input tensor of string, float, int64 \n
+*@li debug_name: A Tensor of type string. 
\n *@par Attributes: -*Ncontext_sparse: Number of context_sparse_keys, context_sparse_indices and context_sparse_shapes \n -*Ncontext_dense: Number of context_dense_keys \n -*Nfeature_list_sparse: Number of feature_list_sparse_keys \n -*Nfeature_list_dense: Number of feature_list_dense_keys \n -*context_sparse_types: Types of context_sparse_values \n -*Tcontext_dense: Number of dense_keys \n -*feature_list_dense_types: Types of feature_list_dense_values \n -*context_dense_shapes: Shape of context_dense \n -*feature_list_sparse_types: Type of feature_list_sparse_values \n -*feature_list_dense_shapes: Shape of feature_list_dense \n +*@li Ncontext_sparse: Number of context_sparse_keys, context_sparse_indices and context_sparse_shapes \n +*@li Ncontext_dense: Number of context_dense_keys \n +*@li Nfeature_list_sparse: Number of feature_list_sparse_keys \n +*@li Nfeature_list_dense: Number of feature_list_dense_keys \n +*@li context_sparse_types: Types of context_sparse_values \n +*@li Tcontext_dense: Number of dense_keys \n +*@li feature_list_dense_types: Types of feature_list_dense_values \n +*@li context_dense_shapes: Shape of context_dense \n +*@li feature_list_sparse_types: Type of feature_list_sparse_values \n +*@li feature_list_dense_shapes: Shape of feature_list_dense \n *@par Outputs: -*context_sparse_indices: Dynamic output tensor of type int64. \n -*context_sparse_values: Dynamic output tensor of type string, float, int64. \n -*context_sparse_shapes: Dynamic output tensor of type int64 \n -*context_dense_values: Dynamic output tensor of type string, float, int64. \n -*feature_list_sparse_indices: Dynamic output tensor of type int64. \n -*feature_list_sparse_values: Dynamic output tensor of type string, float, int64. \n -*feature_list_sparse_shapes: Dynamic output tensor of type int64 \n -*feature_list_dense_values: Dynamic output tensor of type string, float, int64. \n +*@li context_sparse_indices: Dynamic output tensor of type int64. \n +*@li context_sparse_values: Dynamic output tensor of type string, float, int64. \n +*@li context_sparse_shapes: Dynamic output tensor of type int64 \n +*@li context_dense_values: Dynamic output tensor of type string, float, int64. \n +*@li feature_list_sparse_indices: Dynamic output tensor of type int64. \n +*@li feature_list_sparse_values: Dynamic output tensor of type string, float, int64. \n +*@li feature_list_sparse_shapes: Dynamic output tensor of type int64 \n +*@li feature_list_dense_values: Dynamic output tensor of type string, float, int64. \n *@par Third-party framework compatibility \n *@li compatible with tensorflow StringToNumber operator. \n */ diff --git a/third_party/fwkacllib/inc/ops/quantize_ops.h b/third_party/fwkacllib/inc/ops/quantize_ops.h index 69d5e67e..0636833c 100644 --- a/third_party/fwkacllib/inc/ops/quantize_ops.h +++ b/third_party/fwkacllib/inc/ops/quantize_ops.h @@ -63,10 +63,11 @@ REG_OP(Dequantize) /** *@brief Quantizes the input . \n *@par Inputs: -*x: shape and dtype of input_x. \n -*scales: shape and dtype of input_scales. \n -*zero_points: shape and dtype of input_zero_points \n +*@li x: shape and dtype of input_x. \n +*@li scales: shape and dtype of input_scales. \n +*@li zero_points: shape and dtype of input_zero_points \n *@par Attributes: +*@li dtype: required, type. *@li axis: the processed dim. \n *@par Outputs: *y: shape and dtype of output_y, should be same shape as input, dtype is same as the quantified type . \n @@ -91,7 +92,8 @@ REG_OP(Quantize) *@li offset: A required float16, specifying the offset. 
*@li sqrt_mode: A optional bool, specifying whether to perform square root on "scale", either "True" or "False". Defaults to "False".
 *@li round_mode: An optional string, specifying the float16 to int8 cast type.
-* The value range is [Round, Floor, Ceiling, Truncate]. Defaults to "Round" . \n
+* The value range is [Round, Floor, Ceil, Truncate]. Defaults to "Round" .
+*@li dst_type: An optional int32, specifying the output data type. Defaults to "DT_INT8" . \n
 
 *@par Outputs:
 *y: The quantized output tensor of type int8 and with format NC1HWC0 . \n
diff --git a/third_party/fwkacllib/inc/ops/ragged_array_ops.h b/third_party/fwkacllib/inc/ops/ragged_array_ops.h
index 20484623..5af2dd74 100644
--- a/third_party/fwkacllib/inc/ops/ragged_array_ops.h
+++ b/third_party/fwkacllib/inc/ops/ragged_array_ops.h
@@ -37,13 +37,18 @@ namespace ge {
 *deprecated name.
 *@li indices: Indices in the outermost dimension of `params` of the values that should be
 *gathered.
+
+*@par Attributes:
+*@li PARAMS_RAGGED_RANK: The ragged rank of the params_nested_splits.
+*@li Tsplits: The type of output_nested_splits.
 *@li OUTPUT_RAGGED_RANK: The ragged rank of the output RaggedTensor. `output_nested_splits` will contain
 *this number of `row_splits` tensors. This value should equal
 *`indices.shape.ndims + params.ragged_rank - 1` . \n
 
 *@par Outputs:
-*y:A Returns The `nested_row_splits` tensors that define the row-partitioning for the
-*returned RaggedTensor.The `flat_values` for the returned RaggedTensor . \n
+*@li output_nested_splits: The `nested_row_splits` tensors that define the row-partitioning for the
+*returned RaggedTensor.
+*@li output_dense_values: The `flat_values` for the returned RaggedTensor. \n
 
 *@par Third-party framework compatibility
 * Compatible with tensorflow RaggedGather operator.
diff --git a/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h b/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h
index 020e3da4..ceaa64e4 100644
--- a/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h
+++ b/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h
@@ -61,7 +61,6 @@ REG_OP(RaggedTensorToSparse)
 *@brief Create a dense tensor from a ragged tensor, possibly altering its shape . \n
 
 *@par Inputs:
-*Six inputs, including:
 *@li shape:A `Tensor`. Must be one of the following types: `int64`, `int32`.
 *@li values:A 1D tensor representing the values of the ragged tensor.
 *@li default_value:A `Tensor`. Must have the same type as `values`.
@@ -78,7 +77,7 @@ The types of the row partition tensors. At present, these can be:
 *    is preceeded by "FIRST_DIM_SIZE" . \n
 
 *@par Outputs:
-*@li result: A `Tensor`. Has the same type as `values`.
+*result: A `Tensor`. Has the same type as `values`.
 */
 REG_OP(RaggedTensorToTensor)
     .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
diff --git a/third_party/fwkacllib/inc/ops/ragged_math_ops.h b/third_party/fwkacllib/inc/ops/ragged_math_ops.h
index 258b0ca1..4376437f 100644
--- a/third_party/fwkacllib/inc/ops/ragged_math_ops.h
+++ b/third_party/fwkacllib/inc/ops/ragged_math_ops.h
@@ -35,7 +35,11 @@ namespace ge {
 *@li deltas: The deltas of each range . \n
 
 *@par Outputs:
-*y:A Returns The `row_splits` for the returned `RaggedTensor`.The `flat_values` for the returned `RaggedTensor` . \n
+*@li rt_dense_values: The `flat_values` for the returned `RaggedTensor`.
+*@li rt_nested_splits: The `row_splits` for the returned `RaggedTensor`. \n
+
+*@par Attributes:
+*Tsplits: The type of rt_nested_splits. 
 
 *@attention Constraints:
 *The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
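To make the relationship between the two RaggedRange outputs concrete, here is a small reference sketch (illustrative only, not part of the patch; names are hypothetical):

    // Reference semantics of RaggedRange's outputs. For starts={2,5},
    // limits={8,5}, deltas={2,1} the per-row ranges are [2,4,6] and [],
    // so rt_nested_splits={0,3,3} and rt_dense_values={2,4,6}.
    #include <cstdint>
    #include <vector>

    void RaggedRangeRef(const std::vector<int64_t> &starts,
                        const std::vector<int64_t> &limits,
                        const std::vector<int64_t> &deltas,
                        std::vector<int64_t> &rt_nested_splits,   // row_splits
                        std::vector<int64_t> &rt_dense_values) {  // flat_values
      rt_nested_splits.assign(1, 0);
      for (size_t i = 0; i < starts.size(); ++i) {
        for (int64_t v = starts[i];
             deltas[i] > 0 ? v < limits[i] : v > limits[i]; v += deltas[i]) {
          rt_dense_values.push_back(v);  // append this row's values
        }
        // each row ends at the current length of the flat values
        rt_nested_splits.push_back(static_cast<int64_t>(rt_dense_values.size()));
      }
    }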
diff --git a/third_party/fwkacllib/inc/ops/random_ops.h b/third_party/fwkacllib/inc/ops/random_ops.h
index b65a68f1..66f9b65f 100644
--- a/third_party/fwkacllib/inc/ops/random_ops.h
+++ b/third_party/fwkacllib/inc/ops/random_ops.h
@@ -148,6 +148,32 @@ REG_OP(RandomGamma)
     .OP_END_FACTORY_REG(RandomGamma)
 
 /**
+*@brief Returns the random permutation of integers from 0 to n-1. \n
+
+*@par Attributes:
+*@li n: A required int.
+*@li dtype: An optional type. Defaults to int64 .
+*@li layout: An optional int. Defaults to 0 . \n
+
+*@par Outputs:
+*out: A required Tensor. Must be one of the following types:
+    float16, float32, int8, uint8, int16, int32, int64. \n
+
+*@attention Constraints:
+*The implementation for Randperm on Ascend uses AICPU, with bad performance.
+
+*@par Third-party framework compatibility
+*@li compatible with Pytorch Randperm operator.
+*/
+REG_OP(Randperm)
+    .OUTPUT(out, TensorType({DT_INT64, DT_INT32, DT_INT16,
+        DT_UINT8, DT_INT8, DT_FLOAT16, DT_FLOAT32, DT_DOUBLE}))
+    .REQUIRED_ATTR(n, Int)
+    .ATTR(layout, Int, 0)
+    .ATTR(dtype, Type, DT_INT64)
+    .OP_END_FACTORY_REG(Randperm)
+
+/**
 *@brief Outputs random values from the Poisson distribution(s) described by rate . \n
 
 *@par Inputs:
@@ -157,11 +183,12 @@ REG_OP(RandomGamma)
 
 *@par Attributes:
 *@li dtype: An optional type from: half, float32, float64, int32, int64. Defaults to int64.
-*@li seed: An optional int. Defaults to 0.
-*@li seed2: An optional int. Defaults to 0 . \n
+*@li seed: An optional int. Defaults to 0. If either seed or seed2 are set to be non-zero,
+the random number generator is seeded by the given seed. Otherwise, it is seeded by a random seed.
+*@li seed2: An optional int. Defaults to 0. A second seed to avoid seed collision. \n
 
 *@par Outputs:
-*y: A Tensor of type dtype . \n
+*y: A Tensor of type dtype: float16, float, double, int32, or int64. \n
 
 *@attention Constraints:
 *The implementation for RandomPoisson on Ascend uses AICPU, with bad performance.
@@ -188,11 +215,13 @@ REG_OP(RandomPoisson)
 *x: A Tensor. The tensor to be shuffled . \n
 
 *@par Attributes:
-*@li seed: An optional int. Defaults to 0.
-*@li seed2: An optional int. Defaults to 0 . \n
+*@li seed: An optional int. Defaults to 0. If either seed or seed2 are set to be non-zero,
+the random number generator is seeded by the given seed. Otherwise, it is seeded by a random seed.
+*@li seed2: An optional int. Defaults to 0. A second seed to avoid seed collision. \n
 
 *@par Outputs:
-*y: A Tensor. Has the same type as x . \n
+*y: A Tensor. Has the same type as x: float16, float,
+*double, int32, int64, int16, uint16, int8, uint8. \n
 
 *@attention Constraints:
 *The implementation for RandomShuffle on Ascend uses AICPU, with bad performance.
@@ -220,11 +249,12 @@ REG_OP(RandomShuffle)
 
 *@par Attributes:
 *@li dtype: A type from: half, float16, float32, float64. The type of the output.
-*@li seed: An optional int. Defaults to 0.
-*@li seed2: An optional int. Defaults to 0 . \n
+*@li seed: An optional int. Defaults to 0. If either seed or seed2 are set to be non-zero,
+the random number generator is seeded by the given seed. Otherwise, it is seeded by a random seed.
+*@li seed2: An optional int. Defaults to 0. A second seed to avoid seed collision. \n
 
 *@par Outputs:
-*y: A Tensor of type dtype . \n
+*y: A Tensor of type float32, float16, double. 
\n
 
 *@attention Constraints:
 *The implementation for RandomStandardNormal on Ascend uses AICPU, with bad performance.
@@ -241,6 +271,28 @@ REG_OP(RandomStandardNormal)
     .OP_END_FACTORY_REG(RandomStandardNormal)
 
 /**
+*@brief Outputs random values from separate normal distributions. \n
+
+*@par Inputs:
+*Inputs include:
+*@li mean: A tensor with the mean of each output element’s normal distribution .
+*@li std: A tensor with the standard deviation of each output element’s normal distribution. \n
+
+*@par Outputs:
+*y: A Tensor of type dtype . \n
+
+*@attention Constraints:
+*The implementation for Normal on Ascend uses AICPU, with bad performance.
+
+*@par Third-party framework compatibility
+*@li compatible with Pytorch Normal operator.
+*/
+REG_OP(Normal)
+    .INPUT(mean, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .INPUT(std, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OP_END_FACTORY_REG(Normal)
+
+/**
 *@brief Outputs random integers from a uniform distribution . \n
 
 *@par Inputs:
@@ -250,8 +302,9 @@ REG_OP(RandomStandardNormal)
 * @li max: A Tensor. Must have the same type as minval. 0-D . \n
 
 *@par Attributes:
-*@li seed: An optional int. Defaults to 0.
-*@li seed2: An optional int. Defaults to 0 . \n
+*@li seed: An optional int. Defaults to 0. If either seed or seed2 are set to be non-zero,
+the random number generator is seeded by the given seed. Otherwise, it is seeded by a random seed.
+*@li seed2: An optional int. Defaults to 0. A second seed to avoid seed collision. \n
 
 *@par Outputs:
 *y: A Tensor. Has the same type as min . \n
@@ -280,8 +333,9 @@ REG_OP(RandomUniformInt)
 
 *@par Attributes:
 *@li dtype: A type from: half, float16, float32, float64. The type of the output.
-*@li seed: An optional int. Defaults to 0.
-*@li seed2: An optional int. Defaults to 0 . \n
+*@li seed: An optional int. Defaults to 0. If either seed or seed2 are set to be non-zero,
+the random number generator is seeded by the given seed. Otherwise, it is seeded by a random seed.
+*@li seed2: An optional int. Defaults to 0. A second seed to avoid seed collision. \n
 
 *@par Outputs:
 *y: A Tensor of type dtype . \n
@@ -308,11 +362,14 @@ REG_OP(RandomUniform)
 *shape: A Tensor. Must be one of the following types: int32, int64 . \n
 
 *@par Attributes:
-*@li seed: An optional int. Defaults to 0.
-*@li seed2: An optional int. Defaults to 0 . \n
+*@li seed: An optional int. Defaults to 0. If either `seed` or `seed2`
+are set to be non-zero, the random number generator is seeded by the given
+seed. Otherwise, it is seeded by a random seed.
+*@li seed2: An optional int. Defaults to 0. A second seed to avoid seed collision. \n
 
 *@par Outputs:
-*size: A Tensor of types: float16, float32, double . \n
+*y: A Tensor of types: float16, float32, double. A tensor of the specified shape
+filled with random truncated normal values. \n
 
 *@attention Constraints:
 *The implementation for TruncatedNormal on Ascend uses AICPU, with bad performance.
@@ -505,15 +562,15 @@ REG_OP(RandomChoiceWithMask)
 
 *@par Inputs:
 *Inputs including:
-* @li x: A required Tensor. Must be one of the following types:
-    float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64 . \n
+* x: A required Tensor. Must be one of the following types:
+    float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64 . \n
 
 *@par Attributes:
-*@li group: A required int32, specifying the number of groups to split the channel dimension into. Defaults to "1" . \n
+* group: A required int32, specifying the number of groups to split the channel dimension into. Defaults to "1" . \n
 
 *@par Outputs:
-*y: A required Tensor. Has same type and shape as "x". Must be one of the following types:
-    float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64 . \n
+* y: A required Tensor. Has same type and shape as "x". Must be one of the following types:
+    float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64 . \n
 
 *@attention Constraints:
 *@li "group" must be greater than 0 and must evenly divide the channel dimension size.
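The seed/seed2 wording added across the random ops above follows the usual two-seed contract: both zero means a nondeterministic stream, any non-zero value makes the draws reproducible. A small sketch (illustrative only, not part of the patch; it assumes the generated ge::op wrapper and its set_input_*/set_attr_* accessors):

    // Sketch only: a seeded RandomStandardNormal node. With seed == 0 and
    // seed2 == 0 the generator would instead be seeded randomly.
    #include "graph/graph.h"
    #include "all_ops.h"  // assumed generated wrappers

    ge::op::RandomStandardNormal MakeSeededNormal(const ge::Operator &shape) {
      return ge::op::RandomStandardNormal("normal")
          .set_input_shape(shape)          // int32/int64 shape tensor
          .set_attr_dtype(ge::DT_FLOAT)
          .set_attr_seed(87)               // non-zero: deterministic stream
          .set_attr_seed2(11);             // second seed, avoids collisions
    }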
@@ -584,6 +641,50 @@ REG_OP(DropoutV2)
     .OUTPUT(seed, TensorType({ DT_FLOAT }))
     .REQUIRED_ATTR(p, Float)
     .OP_END_FACTORY_REG(DropoutV2)
+
+/**
+* @brief The Bernoulli distribution with probability p. \n
+
+* @par Inputs:
+* @li x: A ND Tensor. Must be one of the following data types:
+    int8, uint8, int16, int32, int64, bool, float32, float64 .
+* @li p: A ND Tensor. The probability of an element to be zeroed.
+    Must be one of the following data types: float32, float64. \n
+
+* @par Attributes:
+* seed: An Integer, the seed of the random generator. Default value -1
+    to use current timestamp, otherwise it should be a positive integer.
+
+* @par Outputs:
+* y: A tensor with the same shape and type as "x".
+*/
+
+REG_OP(Bernoulli)
+    .INPUT(x, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_BOOL, DT_FLOAT, DT_DOUBLE}))
+    .INPUT(p, TensorType({ DT_FLOAT, DT_DOUBLE }))
+    .OUTPUT(y, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_BOOL, DT_FLOAT, DT_DOUBLE}))
+    .ATTR(seed, Int, -1)
+    .OP_END_FACTORY_REG(Bernoulli)
+
+/**
+ * @brief Fill the input tensor with values drawn from the uniform distribution U(from, to). \n
+
+ * @par Inputs:
+ * x: A Tensor. Must be one of the following types: float16, float, double. \n
+
+ * @par Attributes:
+ * @li from: The lower bound of the uniform. Defaults: 0.0
+ * @li to: The upper bound of the uniform. Defaults: 1.0 \n
+
+ * @par Outputs:
+ * y: A Tensor has the same type as x. \n
+ */
+REG_OP(Uniform)
+    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .ATTR(from, Float, 0.0)
+    .ATTR(to, Float, 1.0)
+    .OP_END_FACTORY_REG(Uniform)
 } // namespace ge
 
 #endif  // OPS_BUILT_IN_OP_PROTO_INC_RANDOM_OPS_H_
diff --git a/third_party/fwkacllib/inc/ops/reduce_ops.h b/third_party/fwkacllib/inc/ops/reduce_ops.h
index 97c7b8e1..1578ba59 100644
--- a/third_party/fwkacllib/inc/ops/reduce_ops.h
+++ b/third_party/fwkacllib/inc/ops/reduce_ops.h
@@ -576,7 +576,7 @@ REG_OP(ReduceAll)
 *@li axis: A mutable Tensor. The dimensions to reduce . \n
 
 *@par Attributes:
-*@li keep_dims: A bool. If true, retains reduced dimensions with length 1. Defaults to "False" . \n
+*keep_dims: A bool. If true, retains reduced dimensions with length 1. Defaults to "False" . \n
 
 *@par Outputs:
 *y: A Tensor. Has the same type and format as input "x" . \n
@@ -967,9 +967,9 @@ REG_OP(EuclideanNormD)
 Defaults to "0.00001" . \n
 
 *@par Outputs:
-*y: A Tensor of type float16 or float32 for the normalized "x".
-*batch_mean: A Tensor of type float32 for the result mean.
-*batch_ variance: A Tensor of type float32 for the result variance . \n
+*@li y: A Tensor of type float16 or float32 for the normalized "x".
+*@li batch_mean: A Tensor of type float32 for the result mean.
+*@li batch_variance: A Tensor of type float32 for the result variance . 
\n
 
 *@attention Constraints:
 *For Ascend 310, the result accuracy fails to reach 0.001 due to the square root instruction.
@@ -987,7 +987,7 @@ REG_OP(INInferV2)
     .OP_END_FACTORY_REG(INInferV2)
 
 /**
-*@brief Performs reduced instance normalization . \n
+*@brief Performs the reduction step of instance normalization. \n
 
 *@par Inputs:
 *x: A Tensor of type float16 or float32. \n
@@ -1008,32 +1008,31 @@ REG_OP(INTrainingReduceV2)
 
 /**
-*@brief Performs update instance normalization . \n
+*@brief Performs the update step of instance normalization. \n
 
 *@par Inputs:
-* Seven inputs, including: (NC1HWC0supported)
+* Seven inputs, including:
 *@li x: A Tensor of type float16 or float32.
 *@li sum: A Tensor of type float32 for the output of operator INTrainingReduceV2.
 *@li square_sum: A Tensor of type float32 for the output of operator INTrainingReduceV2.
 *@li gamma: A Tensor of type float32, for the scaling gamma.
 *@li beta: A Tensor of type float32, for the scaling beta.
 *@li mean: A Tensor of type float32, for the updated mean.
-*@li variance: A Tensor of type float32, for the updated variance . \n
+*@li variance: A Tensor of type float32, for the updated variance. \n
 
 *@par Attributes:
 *@li momentum: A required float32, specifying the momentum to update mean and var.
-*@li epsilon: A required float32, specifying the small value added to variance to avoid dividing by zero . \n
+*@li epsilon: A required float32, specifying the small value added to variance to avoid dividing by zero. \n
 
 *@par Outputs:
 * Three outputs
 *@li y: A Tensor of type float16 or float32, for normalized "x".
 *@li batch_mean: A Tensor of type float32, for the updated mean.
-*@li batch_variance: A Tensor of type float32, for the updated variance . \n
+*@li batch_variance: A Tensor of type float32, for the updated variance. \n
 
 *@attention Constraints:
-*@li This operator is a InstanceNorm fusion operator for updating the moving averages for training.
+* This operator is an InstanceNorm fusion operator for updating the moving averages for training.
 * This operator is used in conjunction with INTrainingReduceV2.
-*@li For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
 */
 REG_OP(INTrainingUpdateV2)
     .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
@@ -1052,6 +1051,80 @@ REG_OP(INTrainingUpdateV2)
 
 /**
+*@brief Performs the backpropagation of InstanceNorm. \n
+
+*@par Inputs:
+* Seven inputs, including:
+*@li dy: A Tensor of type float16 or float32.
+*@li x: A Tensor of type float16 or float32.
+*@li variance: A Tensor of type float32, for the variance of "x".
+*@li mean: A Tensor of type float32, for the mean of "x".
+*@li res_gamma: A Tensor of type float32.
+*@li res_beta: A Tensor of type float32.
+*@li gamma: A Tensor of type float32. \n
+
+*@par Outputs:
+*pd_x: A Tensor of type float16 or float32, for the offset of "x". \n
+
+*@attention Constraints:
+* The preceding layer of this operator must be INTrainingUpdateGrad. \n
+*/
+REG_OP(INTrainingReduceGrad)
+    .INPUT(dy, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(variance, TensorType({DT_FLOAT}))
+    .INPUT(mean, TensorType({DT_FLOAT}))
+    .INPUT(res_gamma, TensorType({DT_FLOAT}))
+    .INPUT(res_beta, TensorType({DT_FLOAT}))
+    .INPUT(gamma, TensorType({DT_FLOAT}))
+    .OUTPUT(pd_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OP_END_FACTORY_REG(INTrainingReduceGrad)
+
+/**
+*@brief Performs the backpropagation of InstanceNorm. \n
+
+*@par Inputs:
+* Four inputs, including:
+*@li dy: A Tensor of type float16 or float32, for the gradient. 
+*@li x: A Tensor of type float16 or float32.
+*@li variance: A Tensor of type float32, for the variance of "x".
+*@li mean: A Tensor of type float32, for the mean of "x". \n
+
+*@par Outputs:
+*@li res_gamma: A Tensor of type float32.
+*@li res_beta: A Tensor of type float32. \n
+
+*/
+REG_OP(INTrainingUpdateGrad)
+    .INPUT(dy, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(variance, TensorType({DT_FLOAT}))
+    .INPUT(mean, TensorType({DT_FLOAT}))
+    .OUTPUT(res_gamma, TensorType({DT_FLOAT}))
+    .OUTPUT(res_beta, TensorType({DT_FLOAT}))
+    .OP_END_FACTORY_REG(INTrainingUpdateGrad)
+
+/**
+*@brief Performs the backpropagation of InstanceNorm. \n
+
+*@par Inputs:
+* Two inputs, including:
+*@li res_gamma: A Tensor of type float32.
+*@li res_beta: A Tensor of type float32. \n
+
+*@par Outputs:
+*@li pd_gamma: A Tensor of type float32.
+*@li pd_beta: A Tensor of type float32. \n
+
+*/
+REG_OP(INTrainingUpdateGradGammaBeta)
+    .INPUT(res_gamma, TensorType({DT_FLOAT}))
+    .INPUT(res_beta, TensorType({DT_FLOAT}))
+    .OUTPUT(pd_gamma, TensorType({DT_FLOAT}))
+    .OUTPUT(pd_beta, TensorType({DT_FLOAT}))
+    .OP_END_FACTORY_REG(INTrainingUpdateGradGammaBeta)
+
+/**
 *@brief Performs reduced group normalization . \n
 
 *@par Inputs:
@@ -1063,7 +1136,7 @@ REG_OP(INTrainingUpdateV2)
 
 *@par Attributes:
-*@li num_groups: Int, specifying the num of groups. required, same to GNTrainingUpdate . \n
+*num_groups: A required int, specifying the number of groups. Must be the same as in GNTrainingUpdate . \n
 
 *@attention Constraints:
 * This operator is a GroupNorm fusion operator for updating the moving averages for training.
@@ -1081,7 +1154,7 @@ REG_OP(GNTrainingReduce)
 *@brief Performs update group normalization . \n
 
 *@par Inputs:
-* Eight inputs, including: (NCHW NHWC supported)
+* Seven inputs, including: (NCHW NHWC supported)
 *@li x: A Tensor of type float16 or float32.
 *@li sum: A 5D Tensor of type float32,
 shape is [N, G, 1, 1, 1] for NCHW, [N, 1, 1, G, 1] for NHWC
@@ -1145,8 +1218,8 @@ include:
 *@li keep_dims:A bool, An optional bool. Defaults to False. If True, retain reduced dimensions with length 1..
 *@li separator:string.
-*@par output:
-*@li output::A Tensor of type string..
+*@par Outputs:
+*output: A Tensor of type string.
 */
 REG_OP(ReduceJoin)
     .INPUT(input, TensorType({DT_STRING}))
@@ -1160,7 +1233,7 @@ REG_OP(ReduceJoin)
 * @brief Calculates the standard deviation and average value of Tensors.
 
 * @par Inputs:
-* @li x: A Tensor. Must be one of the following types:
+* x: A Tensor. Must be one of the following types:
 *     float16, float32. \n
 
 * @par Attributes:
diff --git a/third_party/fwkacllib/inc/ops/resource_variable_ops.h b/third_party/fwkacllib/inc/ops/resource_variable_ops.h
index 74ac83f8..156f2f34 100644
--- a/third_party/fwkacllib/inc/ops/resource_variable_ops.h
+++ b/third_party/fwkacllib/inc/ops/resource_variable_ops.h
@@ -33,10 +33,12 @@ namespace ge {
 *y: A Tensor of type resource. \n
 
 *@par Attributes:
-* @li container: optional, string.
-* @li shared_name: optional, string.
-* @li dtype: required, type.
-* @li shape: optional, ListInt. \n
+* @li container: optional, string. The container this
+variable is placed in.
+* @li shared_name: optional, string. The name by which
+    this variable is referred to.
+* @li dtype: required, type. The type of the output.
+* @li shape: optional, ListInt. The shape of the output. \n
 
 *@see VarHandleOp.
 */
 REG_OP(VarHandleOp)
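A wiring sketch for the VarHandleOp attributes documented above together with AssignVariableOp (illustrative only, not part of the patch; wrapper classes and the Const helper are assumptions based on the generated ge::op API):

    // Sketch only: create a float32 [2, 2] resource variable and store a
    // constant into it via AssignVariableOp.
    #include "graph/graph.h"
    #include "all_ops.h"  // assumed generated wrappers

    void BuildResourceVariable() {
      auto handle = ge::op::VarHandleOp("var_handle")
                        .set_attr_dtype(ge::DT_FLOAT)      // required
                        .set_attr_shape({2, 2})            // optional ListInt
                        .set_attr_shared_name("my_var");   // name it is referred to by

      auto init = ge::op::Const("init_value");  // float32 [2, 2], setup elided

      auto assign = ge::op::AssignVariableOp("assign")
                        .set_input_resource(handle)
                        .set_input_value(init)
                        .set_attr_dtype(ge::DT_FLOAT);
      (void)assign;
    }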
@@ -53,11 +55,11 @@ REG_OP(VarHandleOp)
 *@brief Assigns a new value to a variable. \n
 
 *@par Inputs:
-*resource:Handle to the resource in which to store the variable.
-*value:The value to set the new tensor to use. \n
+*@li resource: Handle to the resource in which to store the variable.
+*@li value: The value to set the new tensor to use. \n
 
 *@par Attributes:
-* @li dtype: required, type. \n
+* dtype: required, type. \n
 
 *@see AssignVariableOp.
 */
@@ -73,11 +75,11 @@ REG_OP(AssignVariableOp)
 *@brief Adds a value to the current value of a variable. \n
 
 *@par Inputs:
-*resource:Handle to the resource in which to store the variable.
-*value:The value by which the variable will be incremented. \n
+*@li resource: Handle to the resource in which to store the variable.
+*@li value: The value by which the variable will be incremented. \n
 
 *@par Attributes:
-* @li dtype: required, type. \n
+* dtype: required, type. \n
 
 *@see AssignAddVariableOp.
 */
@@ -93,11 +95,11 @@ REG_OP(AssignAddVariableOp)
 *@brief Subtracts a value to the current value of a variable. \n
 
 *@par Inputs:
-*resource:Handle to the resource in which to store the variable.
-*value:The value by which the variable will be incremented. \n
+*@li resource: Handle to the resource in which to store the variable.
+*@li value: The value by which the variable will be decremented. \n
 
 *@par Attributes:
-* @li dtype: required, type. \n
+* dtype: required, type. \n
 
 *@see AssignSubVariableOp.
 */
diff --git a/third_party/fwkacllib/inc/ops/rnn.h b/third_party/fwkacllib/inc/ops/rnn.h
index 80546860..20828a89 100644
--- a/third_party/fwkacllib/inc/ops/rnn.h
+++ b/third_party/fwkacllib/inc/ops/rnn.h
@@ -127,9 +127,7 @@ REG_OP(DynamicLSTM)
 *@li cell_clip:An float identifying the cell clip in the op. Default to -1.
 *@li num_proj:An integer identifying the num projection in the op. Default to 0.
 *@li time_major:An bool identifying the time major in the op. Default to false.
-*@li activation:An string identifying the type of activation function in the op. Default to "tanh". Only tanh is currently supported.
 *@li forget_bias:An float identifying the forget bias in the op. Default to 0.
-*@li is_training:An bool identifying is training in the op. Default to true.
 
 *@par Outputs:
 *eight outputs: \n
@@ -491,7 +489,6 @@ REG_OP(DynamicLSTMV2)
 *ten inputs: \n
 *@li w:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
 *@li init_c:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
-*@li h:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
 *@li c:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
 *@li dy:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
 *@li dh:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
@@ -504,10 +501,11 @@ REG_OP(DynamicLSTMV2)
 
 *@par Outputs:
-*eight outputs: \n
+*four outputs: \n
 *@li dx:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
 *@li dh_prev:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
 *@li dc_prev:A 4D Tensor. Must be one of the following types: float16, float32. The format must be FRACTAL_NZ.
+*@li dgate: A 4D Tensor. Must be one of the following types: float16. The format must be FRACTAL_NZ. 
*/
 REG_OP(LSTMInputGrad)
     .INPUT(w, TensorType({DT_FLOAT16, DT_FLOAT}))
@@ -571,13 +569,13 @@ REG_OP(DynamicLSTMGradCell)
     .INPUT(f, TensorType({DT_FLOAT16, DT_FLOAT}))
     .INPUT(o, TensorType({DT_FLOAT16, DT_FLOAT}))
     .INPUT(tanhct, TensorType({DT_FLOAT16, DT_FLOAT}))
-    .INPUT(mask, TensorType({DT_FLOAT16, DT_FLOAT}))
     .INPUT(t_state, TensorType({DT_INT32, DT_INT32}))
+    .INPUT(mask, TensorType({DT_FLOAT16, DT_FLOAT}))
     .OUTPUT(dgate, TensorType({DT_FLOAT16, DT_FLOAT}))
     .OUTPUT(dct_1, TensorType({DT_FLOAT16, DT_FLOAT}))
-    .ATTR(forget_bias, Float, 1)
-    .ATTR(activation, String, "")
-    .ATTR(direction, String, "Forward")
+    .ATTR(forget_bias, Float, 1.0)
+    .ATTR(activation, String, "tanh")
+    .ATTR(direction, String, "UNIDIRECTIONAL")
     .ATTR(gate_order, String, "ijfo")
     .OP_END_FACTORY_REG(DynamicLSTMGradCell)
@@ -1070,7 +1068,7 @@ REG_OP(GRUV2HiddenGradCell)
 *     If "False", "grad_weight" will not be scale by word_frequency. \n
 
 * @par Outputs:
-* @li grad_weight: A mutable output Tensor of new word grad has the same type as "grads". \n
+* y: A mutable output Tensor of new word grad has the same type as "grads". \n
 
 * @par Third-party framework compatibility
 * Compatible with the Pytorch operator EmbeddingDenseGrad.
@@ -1222,7 +1220,7 @@ REG_OP(CommonGRU)
 * is equivalent to the size of indices. This matches the CSR format.. \n
 
 * @par Outputs:
-* @li grad_weight: A mutable output Tensor of new word grad has the same type as "grads". \n
+* y: A mutable output Tensor of new word grad has the same type as "grads". \n
 
 * @par Third-party framework compatibility
 * Compatible with the Pytorch operator EmbeddingBag.
diff --git a/third_party/fwkacllib/inc/ops/rpn_ops.h b/third_party/fwkacllib/inc/ops/rpn_ops.h
index 089af326..850b3e5a 100644
--- a/third_party/fwkacllib/inc/ops/rpn_ops.h
+++ b/third_party/fwkacllib/inc/ops/rpn_ops.h
@@ -28,12 +28,12 @@ namespace ge {
 * iou_threshold with higher scoring box according to their
 * intersection-over-union (IoU) . \n
 
-*@par Input:
-* @li box_scores: 2-D tensor with shape of [N, 8], including proposal boxes and
+* @par Inputs:
+* box_scores: 2-D tensor with shape of [N, 8], including proposal boxes and
 * corresponding confidence scores . \n
 
 * @par Attributes:
-* @li iou_threshold: An optional float. The threshold for deciding whether boxes
+* iou_threshold: An optional float. The threshold for deciding whether boxes
 * overlap too much with respect to IOU . \n
 
 * @par Outputs:
diff --git a/third_party/fwkacllib/inc/ops/sdca_ops.h b/third_party/fwkacllib/inc/ops/sdca_ops.h
index 34c6a268..601b360b 100644
--- a/third_party/fwkacllib/inc/ops/sdca_ops.h
+++ b/third_party/fwkacllib/inc/ops/sdca_ops.h
@@ -45,7 +45,13 @@ namespace ge {
 *corresponding weights in sparse_weights. This field maybe omitted for the dense approach.It's a dynamic input.
 *@li sparse_weights: a list of vectors where each value is the weight associated with a sparse feature group.
 *@li dense_weights: a list of vectors where the values are the weights associated with a dense feature group.It's a dynamic input.
-*@li example_state_data: a list of vectors containing the example state data.
+*@li example_state_data: a list of vectors containing the example state data. \n
+
+*@par Attributes:
+*@li adaptive: An optional bool. Defaults to "false".
+*@li num_sparse_features: The number of sparse feature groups.
+*@li num_sparse_features_with_values: The number of sparse feature groups with values.
+*@li num_dense_features: The number of dense feature groups.
 *@li loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, squared and hinge losses. 
*@li l1: Symmetric l1 regularization strength.
 *@li l2: Symmetric l2 regularization strength.
 *@li num_loss_partitions: Number of partitions of the global loss function.
 *@li num_inner_iterations: Number of iterations per mini-batch . \n
 
 *@par Outputs:
-*y: A Returns a list of vectors containing the updated example state
-*data.a list of vectors where each value is the delta
-*weights associated with a sparse feature group.a list of vectors where the values are the delta
-*weights associated with a dense feature group . \n
+*@li out_example_state_data: A list of vectors containing the updated example state data.
+*@li out_delta_sparse_weights: A list of vectors where each value is the delta
+*weights associated with a sparse feature group.
+*@li out_delta_dense_weights: A list of vectors where the values are the delta
+*weights associated with a dense feature group . \n
 
 *@par Third-party framework compatibility
 * Compatible with tensorflow SdcaOptimizerV2 operator.
diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h
index 1c26e033..43f72ef3 100644
--- a/third_party/fwkacllib/inc/ops/selection_ops.h
+++ b/third_party/fwkacllib/inc/ops/selection_ops.h
@@ -258,7 +258,7 @@ REG_OP(GatherV2D)
 REG_OP(GatherElements)
     .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64}))
-    .INPUT(index, TensorType({DT_INT64}))
+    .INPUT(index, TensorType({DT_INT32, DT_INT64}))
     .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64}))
     .ATTR(dim, Int, 0)
     .OP_END_FACTORY_REG(GatherElements)
@@ -508,7 +508,7 @@ REG_OP(UnsortedSegmentSum)
 
 *@par Inputs:
 *One inputs, including:
-* @li assist: A tensor. Must be one of the following types:
+* assist: A tensor. Must be one of the following types:
 *     float16, float32. \n
 
 * @par Attributes:
@@ -970,10 +970,11 @@ REG_OP(TopKV2)
 *     for matrices) . \n
 
 * @par Attributes:
-* @li sorted: An optional bool. Defaults to true.
+* @li sorted: Defaults to true.
 * If true, the resulting "k" elements will be sorted by the values in descending
 * order.
-* @li T: Indicator of indices type . \n
+* @li largest: If true, the resulting `k` elements will be sorted by the values in descending order.
+* @li dim: 0-D. Number of top elements to look for along the last dimension (along each row for matrices). \n
 
 * @par Outputs:
 * @li values: A Tensor, specifying the sorted data. Has the same type as
@@ -982,7 +983,7 @@ REG_OP(TopKV2)
 * @see TopK()
 
 * @par Third-party framework compatibility
-* @li Compatible with the TensorFlow operator TopKV2.
+* Compatible with the TensorFlow operator TopKV2.
 */
 REG_OP(TopK)
     .INPUT(x, TensorType::RealNumberType())
@@ -1085,7 +1086,7 @@ REG_OP(InTopKD)
 * @brief Says whether the targets are in the top "k" predictions . \n
 
 * @par Inputs:
-* Two inputs, including:
 * @li x1: A 2D Tensor of type float32. A "batch_size * classes" tensor.
 * @li x2: A 1D Tensor of type IndexNumberType. A batch_size tensor of class ids.
 * @li k: A 1D Tensor of the same type as "x2".
@@ -1618,12 +1618,12 @@ REG_OP(UnsortedSegmentMinD)
 * y: A Tensor of type RealNumberType . \n
 
 * @attention Constraints:
-* @li segment_ids must be non-negative tensor.
+* segment_ids must be non-negative tensor.
 
 * @see UnsortedSegmentSum(), UnsortedSegmentProd(),
 
 * @par Third-party framework compatibility
-* @li Compatible with the TensorFlow operator UnsortedSegmentMax.
+* Compatible with the TensorFlow operator UnsortedSegmentMax.
 */
 REG_OP(UnsortedSegmentMax)
     .INPUT(x, TensorType::RealNumberType())
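A short sketch of the TopK attributes documented above (illustrative only, not part of the patch; wrapper usage assumes the generated ge::op classes):

    // Sketch only: a TopK node. "k" is a 0-D int32 tensor; sorted / largest /
    // dim are the attributes whose docs were revised above.
    #include "graph/graph.h"
    #include "all_ops.h"  // assumed generated wrappers

    ge::op::TopK MakeTop5(const ge::Operator &x, const ge::Operator &k) {
      return ge::op::TopK("top_k")
          .set_input_x(x)
          .set_input_k(k)            // 0-D int32, e.g. 5
          .set_attr_sorted(true)     // values sorted in descending order
          .set_attr_largest(true)    // take the largest k elements
          .set_attr_dim(-1);         // along the last dimension
    }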
+* x: A tensor . Must be one of the following types:
 * float16, float32, int32, uint32, int8, uint8. \n
 *@par Attributes:
-* @li axis: Axis along which to cummin. \n
+* axis: Axis along which to cummin. \n
 *@par Outputs:
-* y: A Tensor with the same type and shape of x's. \n
-* indices: A Tensor with the int32 type and the same shape of x's. \n
+* @li y: A Tensor with the same type and shape of x's.
+* @li indices: A Tensor with the int32 type and the same shape of x's. \n
 *@par Third-party framework compatibility
 *Compatible with the Pytorch operator Cummin. \n
@@ -1968,17 +1968,14 @@ REG_OP(WriteSelect)
     .OP_END_FACTORY_REG(WriteSelect)
 /**
-*@brief Read data by stride . \n
+*@brief Read data by stride.
 *@par Inputs:
-*One input:
-*x: A Tensor. Must be one of the following types: float16, int8 . \n
+*x: A Tensor. Must be one of the following types: float16, int8. \n
 *@par Attributes:
-*@li axis: A required int32, specifying the index of axis to read by stride . \n
-
-*@par Attributes:
-*@li stride: A required int32, specifying the value of reading stride . \n
+*@li axis: A required int32, specifying the index of axis to read by stride. \n
+*@li stride: A required int32, specifying the value of reading stride. \n
 *@par Outputs:
 *y: A Tensor of the same type as "x".
@@ -1991,16 +1988,14 @@ REG_OP(StridedRead)
     .OP_END_FACTORY_REG(StridedRead)
 /**
-*@brief: Write data by stride . \n
+*@brief Write data by stride.
 *@par Inputs:
-*x: A Tensor. Must be one of the following types: float16, int8 . \n
-
-*@par Attributes:
-*@li axis: A required int32, specifying the index of axis to write by stride . \n
+*x: A Tensor. Must be one of the following types: float16, int8. \n
 *@par Attributes:
-*@li stride: A required int32, specifying the value of writing stride . \n
+*@li axis: A required int32, specifying the index of axis to write by stride. \n
+*@li stride: A required int32, specifying the value of writing stride. \n
 *@par Outputs:
 *y: A Tensor. Has the same type as "x".
@@ -2076,10 +2071,10 @@ REG_OP(CumulativeLogsumexpD)
 * @li updates: A Tensor of the same type as "var". \n
 * @par Attributes:
-* @li axis: An required int to specify the axis to perform indices add. \n
+* axis: A required int to specify the axis to perform indices add. \n
 * @par Outputs:
-* @li var: A Tensor. Same as input "var".
+* var: A Tensor. Same as input "var".
 * @par Third-party framework compatibility
 * Compatible with the Pytorch operator index_add_.
@@ -2104,7 +2099,7 @@ REG_OP(InplaceIndexAdd)
 * @li value: A Tensor of dtype float16 or float32 or int64 or int32 or int8.
 * @par Outputs:
-* @li y: A tensor. Must be one of the following dtypes:
+* y: A tensor. Must be one of the following dtypes:
 * float16, float32, int64, int32, int8.
 */
 REG_OP(MaskedFill)
@@ -2123,7 +2118,7 @@ REG_OP(MaskedFill)
 * @li mask: A Tensor of dtype is bool. \n
 * @par Outputs:
-* @li y: A tensor with the same type as x. \n
+* y: A tensor with the same type as x. \n
 * @par Third-party framework compatibility
 * Compatible with the Numpy operator select.
@@ -2134,13 +2129,50 @@ REG_OP(MaskedSelectV2)
     .INPUT(mask, TensorType({DT_BOOL}))
     .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
     .OP_END_FACTORY_REG(MaskedSelectV2)
+
+/**
+* @brief Choose the values of x according to mask.
+
+* @par Inputs:
+* Two inputs, including:
+* @li x: A Tensor of dtype float16, float32, float64, int64, int32, int16, int8 or uint8.
+* @li mask: A Tensor of dtype bool. \n
+
+* @par Outputs:
+* y: A tensor with the same type as x. \n
+
+*/
+REG_OP(MaskedSelect)
+    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64}))
+    .INPUT(mask, TensorType({DT_BOOL}))
+    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64}))
+    .OP_END_FACTORY_REG(MaskedSelect)
+
+/**
+* @brief Update the values of x with updates according to mask.
+
+* @par Inputs:
+* Three inputs, including:
+* @li x: A Tensor of dtype float16, float32, float64, int64, int32, int16, int8 or uint8.
+* @li mask: A Tensor of dtype bool.
+* @li updates: A tensor with the same type as x. \n
+
+* @par Outputs:
+* y: A tensor with the same type as x. \n
+*/
+REG_OP(MaskedScatter)
+    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64}))
+    .INPUT(mask, TensorType({DT_BOOL}))
+    .INPUT(updates, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64}))
+    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64}))
+    .OP_END_FACTORY_REG(MaskedScatter)
 /**
 * @brief Slice a tensor at its last dim, e.x. a[..., begin:end:stride]. \n
 * @par Inputs:
 * One inputs, including:
-* @li x: A Tensor. Must be one of the following types: float16, float32, int16, int32.
+* x: A Tensor. Must be one of the following types: float16, float32, int16, int32.
 * @par Attributes:
 * @li start: An attribute of type Int, start index of last dim. \n
@@ -2148,7 +2180,7 @@ REG_OP(MaskedSelectV2)
 * @li stride: An attribute of type Int, stride of slice. \n
 * @par Outputs:
-* @li y: A Tensor. Has the same type as "x". \n
+* y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * No compatibility
@@ -2162,39 +2194,36 @@ REG_OP(SliceLastDim)
     .OP_END_FACTORY_REG(SliceLastDim)
 /**
-* @brief Extracts a strided slice of a tensor. Roughly speaking, this op \n
-* extracts a slice of size (end-begin)/stride from the given input tensor. \n
-* Starting at the location specified by begin the slice continues by \n
+* @brief Extracts a strided slice of a tensor. Roughly speaking, this op
+* extracts a slice of size (end-begin)/stride from the given input tensor.
+* Starting at the location specified by begin the slice continues by
 * adding stride to the index until all dimensions are not less than end. \n
 *
 * @par Inputs:
-* Four inputs, including:
-* @li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, \n
-* complex64, int64, qint8, quint8, qint32, qint16, quint16, uint16, \n
-* complex128, float16, uint32, uint64, complex64, complex128. \n
+* Five inputs, including:
+* @li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8,
+* complex64, int64, qint8, quint8, qint32, qint16, quint16, uint16,
+* complex128, float16, uint32, uint64, complex64, complex128.
 * @li begin: A Tensor of type int32 or int64, for the index of the first value to select.
-*
 * @li end: A Tensor of type int32 or int64, for the index of the last value to select.
-*
 * @li axes: A Tensor of type int32 or int64, indicate axis to be select.
-*
-* @li strides: A Tensor of type int32 or int64, for the increment.
+* @li strides: A Tensor of type int32 or int64, for the increment. \n
 *
 * @par Attributes:
-* @li begin_mask: A Tensor of type int32. \n
-* A bitmask where a bit "i" being "1" means to ignore the begin \n
+* @li begin_mask: A Tensor of type int32.
+* A bitmask where a bit "i" being "1" means to ignore the begin
 * value and instead use the largest interval possible.
-* @li end_mask: A Tensor of type int32. \n
+* @li end_mask: A Tensor of type int32.
 * Analogous to "begin_mask".
-* @li ellipsis_mask: A Tensor of type int32. \n
-* A bitmask where bit "i" being "1" means the "i"th position \n
+* @li ellipsis_mask: A Tensor of type int32.
+* A bitmask where bit "i" being "1" means the "i"th position
 * is actually an ellipsis.
-* @li new_axis_mask: A Tensor of type int32. \n
-* A bitmask where bit "i" being "1" means the "i"th \n
+* @li new_axis_mask: A Tensor of type int32.
+* A bitmask where bit "i" being "1" means the "i"th
 * specification creates a new shape 1 dimension.
-* @li shrink_axis_mask: A Tensor of type int32. \n
-* A bitmask where bit "i" implies that the "i"th \n
-* specification should shrink the dimensionality.
+* @li shrink_axis_mask: A Tensor of type int32.
+* A bitmask where bit "i" implies that the "i"th
+* specification should shrink the dimensionality. \n
 *
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
@@ -2231,7 +2260,7 @@ REG_OP(StridedSliceV2)
 * float16, float32, int32. \n
 * @par Attributes:
-* @li dim: A required int. Used to select the dimension of this tensor. \n
+* dim: A required int. Used to select the dimension of this tensor. \n
 *@par Outputs:
 *y: A Tensor with the same type and shape of input_x's. \n
@@ -2307,6 +2336,34 @@ REG_OP(MaskedFillRange)
     .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32}))
     .REQUIRED_ATTR(axis, Int)
     .OP_END_FACTORY_REG(MaskedFillRange)
+
+/**
+* @brief After a set of sorted data and a new set of data are re-sorted, get the first k data. \n
+*
+* @par Inputs:
+* Six inputs, including:
+* @li topk_pq_distance: A sorted Tensor, will be updated after calculation. Must be one of the following types: float32, float16.
+* @li topk_pq_index: A Tensor of type int32, index corresponding to topk_pq_distance.
+* @li topk_pq_ivf: A Tensor of type int32, the bucket number corresponding to topk_pq_distance.
+* @li pq_distance: A Tensor of type float32 or float16, the new data set will be reordered with topk_pq_distance and updated to topk_pq_distance.
+* @li pq_index: A Tensor of type int32, index corresponding to pq_distance.
+* @li pq_ivf: A scalar of type int32, the bucket number corresponding to pq_distance. \n
+*
+* @par Attributes:
+* order: A string, indicates the sorting method of topk_pq_distance. \n
+*
+* @par Restrictions:
+* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/
+REG_OP(InplaceTopKDistance)
+    .INPUT(topk_pq_distance, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .INPUT(topk_pq_index, TensorType({DT_INT32}))
+    .INPUT(topk_pq_ivf, TensorType({DT_INT32}))
+    .INPUT(pq_distance, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .INPUT(pq_index, TensorType({DT_INT32}))
+    .INPUT(pq_ivf, TensorType({DT_INT32}))
+    .ATTR(order, String, "asc")
+    .OP_END_FACTORY_REG(InplaceTopKDistance)
 } // namespace ge
 #endif  // OPS_BUILT_IN_OP_PROTO_INC_SELECTION_OPS_H_
diff --git a/third_party/fwkacllib/inc/ops/sparse_ops.h b/third_party/fwkacllib/inc/ops/sparse_ops.h
index a1fc9ee6..8eb7b521 100644
--- a/third_party/fwkacllib/inc/ops/sparse_ops.h
+++ b/third_party/fwkacllib/inc/ops/sparse_ops.h
@@ -281,9 +281,9 @@ REG_OP(SparseSliceGrad)
 * @li size: A 1D Tensor of type int64. The size of the slice . \n
 *@par Outputs:
-*y_indices: A Tensor of type int64.
-*y_values: A Tensor. Has the same type as "values".
-*y_values: A Tensor of type int64 . \n
+*@li y_indices: A Tensor of type int64.
+*@li y_values: A Tensor. Has the same type as "values".
+*@li y_shape: A Tensor of type int64 . \n
 *@par Third-party framework compatibility
 * Compatible with the TensorFlow operator SparseSlice.
@@ -313,8 +313,8 @@ REG_OP(SparseSlice)
 * @li sum_indices: A 2D Tensor of type int64. The indices of the sum SparseTensor, with size [nnz(sum), ndims] . \n
 *@par Outputs:
-*x1_val_grad: A Tensor. Has the same type as "backprop_val_grad".
-*x2_val_grad: A Tensor. Has the same type as "backprop_val_grad" . \n
+*@li x1_val_grad: A Tensor. Has the same type as "backprop_val_grad".
+*@li x2_val_grad: A Tensor. Has the same type as "backprop_val_grad" . \n
 *@par Third-party framework compatibility
 * Compatible with the TensorFlow operator SparseAddGrad.
@@ -363,7 +363,7 @@ REG_OP(SparseFillEmptyRowsGrad)
 *@par Inputs:
 * @li x1_indices: A 2D Tensor of type int32 or int64.
-* @li The indices of the matrix "SparseTensor", with size [nnz, 2].
+*The indices of the matrix "SparseTensor", with size [nnz, 2].
 * @li x1_values: A 1D Tensor. The values of the SparseTensor, with size [nnz].
 * @li x1_shape: A 1D Tensor of type int64. The shape of the SparseTensor, with size [2].
 * @li x2: A dense matrix Tensor of the same type as "x1_values". 2D . \n
@@ -373,9 +373,9 @@ REG_OP(SparseFillEmptyRowsGrad)
 *@par Attributes:
 *@li adjoint_a: An optional bool. Defaults to "False".Use the adjoint of A in the matrix multiply.
-*@li If A is complex, this is transpose(conj(A)). Otherwise it is transpose(A).
+*If A is complex, this is transpose(conj(A)). Otherwise it is transpose(A).
 *@li adjoint_b: An optional bool. Defaults to "False".Use the adjoint of B in the matrix multiply.
-*@li If B is complex, this is transpose(conj(B)). Otherwise it is transpose(B) . \n
+*If B is complex, this is transpose(conj(B)). Otherwise it is transpose(B) . \n
 *@par Third-party framework compatibility
 * Compatible with the TensorFlow operator SparseTensorDenseMatMul.
@@ -400,9 +400,13 @@ REG_OP(SparseTensorDenseMatMul)
 * @li indices: A 0D, 1D, or 2D Tensor of type int32 or int64.
 * @li output_shape: A 1D Tensor of the same type as "sparse_indices". The shape of the dense output tensor.
 * @li values: A 1D Tensor. Values corresponding to each row of "sparse_indices",
-* @li or a scalar value to be used for all sparse indices.
+or a scalar value to be used for all sparse indices.
 * @li default_value: A Tensor of the same type as "sparse_values" . \n
+*@par Attributes:
+*validate_indices: If true, indices are checked to make sure they are sorted in
+lexicographic order and that there are no repeats. \n
+
 *@par Outputs:
 *y: A Tensor. Has the same type as "values" . \n
@@ -427,7 +431,6 @@ REG_OP(SparseToDense)
 *Concatenation is with respect to the dense versions of these sparse tensors . \n
 *@par Inputs:
-*3 or 5 inputs,contains:
 * @li indices:A list of at least 2 `Tensor` objects with type `int64`.2-D.
 *Indices of each input `SparseTensor`.It's a dynamic input.
 * @li values:A list with the same length as `indices` of `Tensor` objects with the same type.
@@ -700,7 +703,6 @@ REG_OP(SparseReduceMaxSparse)
 *@brief Computes the sum of elements across dimensions of a SparseTensor . \n
 *@par Inputs:
-*4 or 5 inputs, including:
 * @li x_indices: A 2D Tensor of type int64.
 *"N x R" matrix with the indices of non-empty values in a
 *SparseTensor, possibly not in canonical ordering.
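The SparseToDense hunk above documents the new validate_indices attribute. For the 1-D case, the semantics it describes boil down to the sketch below (plain C++, independent of the GE headers; all names are illustrative):

```
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of 1-D SparseToDense: scatter "values" into a dense tensor of
// length "output_shape", filling the remaining slots with "default_value".
// With "validate_indices" set, indices must be strictly increasing
// (sorted in lexicographic order, no repeats), as documented above.
std::vector<float> SparseToDense1D(const std::vector<int64_t> &indices,
                                   const std::vector<float> &values,
                                   int64_t output_shape, float default_value,
                                   bool validate_indices) {
  std::vector<float> dense(static_cast<size_t>(output_shape), default_value);
  for (size_t i = 0; i < indices.size(); ++i) {
    if (validate_indices && i > 0 && indices[i] <= indices[i - 1]) {
      throw std::invalid_argument("indices are out of order or repeated");
    }
    // A scalar "values" input would be broadcast here; this sketch assumes
    // one value per index.
    dense[static_cast<size_t>(indices[i])] = values[i];
  }
  return dense;
}
```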
@@ -711,13 +713,11 @@ REG_OP(SparseReduceMaxSparse)
 *A length-"K" vector containing the reduction axes . \n
 *@par Attributes:
-* keep_dims: An optional bool. Defaults to "False".
+*keep_dims: An optional bool. Defaults to "False".
 *If true, retains reduced dimensions with length 1 . \n
 *@par Outputs:
-* @li y_indices: A Tensor of type int64.
-* @li y_values: A Tensor. Has the same type as "input_values".
-* @li y_shape: A Tensor of type int64 . \n
+*y: A Tensor. Has the same type as "x_values". \n
 *@par Third-party framework compatibility
 * Compatible with the TensorFlow operator SparseReduceSum.
@@ -818,7 +818,6 @@ REG_OP(SparseSplit)
 *@brief Generates sparse cross from a list of sparse and dense tensors . \n
 *@par Inputs:
-*8 or 10 inputs, including:
 * @li indices: A list of 2D Tensor objects of type int64.
 * Indices of each input SparseTensor.It's a dynamic input.
 * @li values: A list of 1D Tensor objects of type int64 or string.
@@ -899,9 +898,8 @@ REG_OP(AddManySparseToTensorsMap)
 *@brief Reads SparseTensors from a "SparseTensorsMap" and concatenate them . \n
 *@par Inputs:
-*2 or 4 inputs, including:
 * handles: A 1D Tensor of type int64.
-* The "N" serialized SparseTensor objects . \n
+*The "N" serialized SparseTensor objects . \n
 *@par Attributes:
 * @li dtype: A tf.DType. The "dtype" of the SparseTensor objects stored in the "SparseTensorsMap".
@@ -911,9 +909,9 @@ REG_OP(AddManySparseToTensorsMap)
 *The shared name for the "SparseTensorsMap" read by this op . \n
 *@par Outputs:
-* @li indices: A Tensor of type int64.
-* @li values: A Tensor of type "dtype".
-* @li shape: A Tensor of type int64 . \n
+* @li indices: A Tensor of type int64. 2-D. The `indices` of the minibatch `SparseTensor`.
+* @li values: A Tensor of type "dtype". 1-D. The `values` of the minibatch `SparseTensor`.
+* @li shape: A Tensor of type int64. 1-D. The `shape` of the minibatch `SparseTensor`. \n
 *@par Third-party framework compatibility
 * Compatible with the TensorFlow operator TakeManySparseFromTensorsMap.
@@ -989,8 +987,7 @@ REG_OP(SerializeManySparse)
 *@brief Deserializes SparseTensor objects . \n
 *@par Inputs:
-*Two inputs, including:
-* serialized_sparse: A Tensor. The serialized SparseTensor objects.
+*serialized_sparse: A Tensor. The serialized SparseTensor objects.
 *The last dimension must have 3 columns . \n
 *@par Attributes:
diff --git a/third_party/fwkacllib/inc/ops/spectral_ops.h b/third_party/fwkacllib/inc/ops/spectral_ops.h
index 34ccb398..ab9e1dec 100644
--- a/third_party/fwkacllib/inc/ops/spectral_ops.h
+++ b/third_party/fwkacllib/inc/ops/spectral_ops.h
@@ -31,10 +31,10 @@ namespace ge {
 inner-most dimension of `x`. \n
 *@par Inputs:
-*@li x: A Tensor. Must be the following types: complex64, complex128. \n
+*x: A Tensor. Must be the following types: complex64, complex128. \n
 *@par Outputs:
-*@li y: A complex tensor of the same rank as `x`. \n
+*y: A complex tensor of the same rank as `x`. \n
 *@par Third-party framework compatibility
 * Compatible with TensorFlow IFFT operator.
@@ -52,7 +52,7 @@ REG_OP(IFFT)
 *@li fft_length: An int32 tensor of shape [1]. The FFT length . \n
 *@par Outputs:
-*@li y: A complex64 tensor of the same rank as `input`. The inner-most
+*y: A complex64 tensor of the same rank as `input`. The inner-most
 dimension of `input` is replaced with the `fft_length / 2 + 1` unique
 frequency components of its 1D Fourier transform . \n
@@ -73,7 +73,7 @@ REG_OP(RFFT)
 *@li fft_length: An int32 tensor of shape [1]. The FFT length. \n
 *@par Outputs:
-*@li y: A float32 tensor of the same rank as `input`. The inner-most
+* y: A float32 tensor of the same rank as `input`. The inner-most
 dimension of `input` is replaced with the `fft_length` samples of its inverse
 1D Fourier transform. \n
@@ -91,10 +91,10 @@ REG_OP(IRFFT)
 *@brief 2D fast Fourier transform. \n
 *@par Inputs:
-*@li x: A complex64 tensor.
+*x: A complex64 tensor.
 *@par Outputs:
-*@li y: A complex64 tensor of the same shape as `input`. The inner-most 2
+*y: A complex64 tensor of the same shape as `input`. The inner-most 2
 dimensions of `input` are replaced with their 2D Fourier transform. \n
 *@par Third-party framework compatibility
@@ -110,10 +110,10 @@ REG_OP(FFT2D)
 innermost dimension of the input. \n
 *@par Inputs:
-*@li x: A Tensor. Must be the following types: complex64, complex128. \n
+*x: A Tensor. Must be the following types: complex64, complex128. \n
 *@par Outputs:
-*@li y: A complex tensor with the same shape as input. The innermost dimension
+*y: A complex tensor with the same shape as input. The innermost dimension
 of the input is replaced by its 1-dimensional Fourier transform. \n
 *@par Third-party framework compatibility
@@ -129,10 +129,10 @@ REG_OP(FFT)
 innermost dimension of the input. \n
 *@par Inputs:
-*@li x: A Tensor. Must be the following types: complex64, complex128. \n
+*x: A Tensor. Must be the following types: complex64, complex128. \n
 *@par Outputs:
-*@li y: A complex tensor with the same shape as input. The innermost dimension
+*y: A complex tensor with the same shape as input. The innermost dimension
 of the input is replaced by its inverse two-dimensional Fourier transform. \n
 *@par Third-party framework compatibility
diff --git a/third_party/fwkacllib/inc/ops/split_combination_ops.h b/third_party/fwkacllib/inc/ops/split_combination_ops.h
index fe25a46f..98d4d111 100644
--- a/third_party/fwkacllib/inc/ops/split_combination_ops.h
+++ b/third_party/fwkacllib/inc/ops/split_combination_ops.h
@@ -161,14 +161,11 @@ REG_OP(SplitVD)
 /**
 *@brief Concatenates a list of N tensors along the first dimension.
 *@par Inputs:
-* Two inputs, including:
-* @li values: A list of Tensors. Must be one of the following types: int8, int16, int32,
+* One input, including:
+* values: A list of Tensors. Must be one of the following types: int8, int16, int32,
 * int64, uint8, uint16, uint32, uint64, float16, float32.
 * Tensors to be concatenated. All must have size 1 in the first dimension and same shape.
-* It's a dynamic input.
-* @li shape: A Tensor of the same type as "x".
-* The final shape of the result. Should be equal to the shapes of any input
-* but with the number of input values in the first dimension . \n
+* It's a dynamic input. \n
 *@par Attributes:
 * @li shape: A required list of ints.
diff --git a/third_party/fwkacllib/inc/ops/state_ops.h b/third_party/fwkacllib/inc/ops/state_ops.h
index 3c8e32b6..d1ec00b5 100644
--- a/third_party/fwkacllib/inc/ops/state_ops.h
+++ b/third_party/fwkacllib/inc/ops/state_ops.h
@@ -104,7 +104,7 @@ REG_OP(DestroyTemporaryVariable)
 *@brief Checks whether a tensor has been initialized. Outputs boolean scalar indicating whether the tensor has been initialized . \n
 *@par Inputs:
-*x: A tensor . \n
+*x: A Tensor of type float16, float32, double, bool, int8, uint8, uint16, int16, int32, uint32, uint64, int64.
 *@par Outputs:
 *y: A tensor, indicating whether "x" has been initialized . \n
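The spectral-op comments above pin down the output shapes: RFFT replaces the innermost dimension of the input with `fft_length / 2 + 1` unique frequency components, and IRFFT replaces it with `fft_length` real samples. A small helper capturing that rule (illustrative only, not part of the headers):

```
#include <cstdint>
#include <vector>

// Expected output shape for the RFFT/IRFFT registrations documented above:
// only the innermost dimension changes; all outer dimensions are preserved.
std::vector<int64_t> SpectralOutputShape(std::vector<int64_t> input_shape,
                                         int64_t fft_length, bool inverse) {
  // RFFT: fft_length / 2 + 1 frequency components; IRFFT: fft_length samples.
  input_shape.back() = inverse ? fft_length : fft_length / 2 + 1;
  return input_shape;  // e.g. RFFT on {8, 16} with fft_length 16 -> {8, 9}
}
```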
diff --git a/third_party/fwkacllib/inc/ops/stateful_random_ops.h b/third_party/fwkacllib/inc/ops/stateful_random_ops.h
index c2f65c6a..f4eb763c 100644
--- a/third_party/fwkacllib/inc/ops/stateful_random_ops.h
+++ b/third_party/fwkacllib/inc/ops/stateful_random_ops.h
@@ -32,7 +32,10 @@ namespace ge {
 *@par Inputs:
 *This op may use some OS-provided source of non-determinism (e.g. an RNG),
 *so each execution will give different results. Inputs included:
-*@li shape: The shape of the output tensor . \n
+*shape: The shape of the output tensor . \n
+
+*@par Attributes:
+*dtype: Required. The data type of the output tensor. \n
 *@par Outputs:
 *y:A Returns Non-deterministic integer values with specified shape . \n
@@ -54,13 +57,10 @@ REG_OP(NonDeterministicInts)
 *counter is an unspecified implementation detail . \n
 *@par Inputs:
-*@li resource: The handle of the resource variable that stores the state of the RNG.
+*@li x: The handle of the resource variable that stores the state of the RNG.
 *@li algorithm: The RNG algorithm.
 *@li delta: The amount of advancement . \n
-*@par Outputs:
-*y:A Returns the created operation . \n
-
 *@par Third-party framework compatibility
 * Compatible with tensorflow RngSkip operator.
 */
@@ -81,11 +81,16 @@ power of two. The bias is small for values of
 `maxval - minval` significantly smaller than the range
 of the output (either `2^32` or `2^64`) . \n
 *@par Inputs:
-*@li resource: The handle of the resource variable that stores the state of the RNG.
+*@li x: The handle of the resource variable that stores the state of the RNG.
 *@li algorithm: The RNG algorithm.
 *@li shape: The shape of the output tensor.
-*@li minval: Minimum value (inclusive, scalar).
-*@li maxval: Maximum value (exclusive, scalar) . \n
+*@li counts: A 0/1-D Tensor or Python value. The counts of the binomial
+distribution. Must be broadcastable with the leftmost dimension defined by `shape`.
+*@li probs: A 0/1-D Tensor or Python value. The probability of success for the
+binomial distribution. Must be broadcastable with the leftmost dimension defined by `shape`.\n
+
+*@par Attributes:
+*dtype: Required. The data type of the output tensor. \n
 *@par Outputs:
 *y:A Returns Random values with specified shape . \n
@@ -109,7 +114,7 @@ REG_OP(StatefulRandomBinomial)
 *The generated values will have mean 0 and standard deviation 1 . \n
 *@par Inputs:
-*@li resource: The handle of the resource variable that stores the state of the RNG.
+*@li x: The handle of the resource variable that stores the state of the RNG.
 *@li algorithm: The RNG algorithm.
 *@li shape: The shape of the output tensor . \n
@@ -134,7 +139,7 @@ REG_OP(StatefulStandardNormalV2)
 *deviations from the mean are dropped and re-picked . \n
 *@par Inputs:
-*@li resource: The handle of the resource variable that stores the state of the RNG.
+*@li x: The handle of the resource variable that stores the state of the RNG.
 *@li algorithm: The RNG algorithm.
 *@li shape: The shape of the output tensor . \n
@@ -158,7 +163,7 @@ The generated values follow a uniform distribution in the range `[0, 1)`. The
 lower bound 0 is included in the range, while the upper bound 1 is excluded.
 *@par Inputs:
-*@li resource: The handle of the resource variable that stores the state of the RNG.
+*@li x: The handle of the resource variable that stores the state of the RNG.
 *@li algorithm: The RNG algorithm.
 *@li shape: The shape of the output tensor . \n
@@ -181,7 +186,7 @@ REG_OP(StatefulUniform)
 The generated values are uniform integers covering the whole range of `dtype` . \n
 *@par Inputs:
-*@li resource: The handle of the resource variable that stores the state of the RNG.
+*@li x: The handle of the resource variable that stores the state of the RNG.
 *@li algorithm: The RNG algorithm.
 *@li shape: The shape of the output tensor.
 *@li minval: Minimum value (inclusive, scalar).
diff --git a/third_party/fwkacllib/inc/ops/string_ops.h b/third_party/fwkacllib/inc/ops/string_ops.h
index f9cc2549..a78d63a1 100644
--- a/third_party/fwkacllib/inc/ops/string_ops.h
+++ b/third_party/fwkacllib/inc/ops/string_ops.h
@@ -295,7 +295,7 @@ REG_OP(StringSplit)
 *@par Inputs:
 include:
-*@li input:A Tensor of type string. The text to be processed. \n
+*input: A Tensor of type string. The text to be processed. \n
 *@par Attributes:
 *@li pattern:A string. The regular expression to match the input.
 *@li replace_global:An optional bool. Defaults to True. If True, the replacement
 is global, otherwise the replacement is done only on the first match.
-*@par output:
-*@li output::A Tensor of type string.
+*@par Outputs:
+*output: A Tensor of type string.
 */
 REG_OP(StaticRegexReplace)
     .INPUT(input, TensorType({DT_STRING}))
@@ -322,13 +322,13 @@ REG_OP(StaticRegexReplace)
 *@par Inputs:
 include:
-*@li input:A Tensor of type string. The text to be processed. \n
+*input: A Tensor of type string. The text to be processed. \n
 *@par Attributes:
-*@li pattern:A string. The regular expression to match the input.
+*pattern: A string. The regular expression to match the input.
-*@par output:
-*@li output::A bool tensor with the same shape as `input`.
+*@par Outputs:
+*output: A bool tensor with the same shape as `input`.
 */
 REG_OP(StaticRegexFullMatch)
     .INPUT(input, TensorType({DT_STRING}))
@@ -347,10 +347,10 @@ include:
 *@li num_segments:A Tensor. Must be one of the following types: int32, int64. A scalar.
 *@par Attributes:
-*@li separator:An optional string. Defaults to "". The separator to use when joining.
+*separator: An optional string. Defaults to "". The separator to use when joining.
-*@par output:
-*@li output::A Tensor of type string..
+*@par Outputs:
+*output: A Tensor of type string.
 */
 REG_OP(UnsortedSegmentJoin)
     .INPUT(input, TensorType({DT_STRING}))
@@ -366,13 +366,13 @@ REG_OP(UnsortedSegmentJoin)
 *@par Inputs:
 include:
-*@li input:A Tensor of type string. The text to be processed.
+*input: A Tensor of type string. The text to be processed.
 *@par Attributes:
-*@li encoding:An optional string. Defaults to "".
+*encoding: An optional string. Defaults to "".
-*@par output:
-*@li output::A Tensor of type string..
+*@par Outputs:
+*output: A Tensor of type string.
 */
 REG_OP(StringLower)
     .INPUT(input, TensorType({DT_STRING}))
@@ -386,13 +386,13 @@ REG_OP(StringLower)
 *@par Inputs:
 include:
-*@li input:A Tensor of type string. The text to be processed.
+*input: A Tensor of type string. The text to be processed.
 *@par Attributes:
-*@li encoding:An optional string. Defaults to "".
+*encoding: An optional string. Defaults to "".
-*@par output:
-*@li output::A Tensor of type string..
+*@par Outputs:
+*output: A Tensor of type string.
 */
 REG_OP(StringUpper)
     .INPUT(input, TensorType({DT_STRING}))
@@ -901,10 +901,10 @@ REG_OP(DecodeBase64)
 *@brief StringNormalization performs string operations for basic cleaning . \n
 *@par Inputs:
-*@li input: only accepts [C] or [1, C] UTF-8 strings tensor . \n
+*input: Only accepts [C] or [1, C] UTF-8 strings tensor . \n
 *@par Outputs:
-*@li output: UTF-8 strings tensor after cleaning . \n
+*output: UTF-8 strings tensor after cleaning . \n
 *@par Attributes:
 *@li stopwords : list of strings (default is empty).
 *@li is_case_sensitive : boolean (default is false) indicating whether the comparison against stopwords is
 case-sensitive. Default is false.
 *string enum that cases output to be lowercased/uppercases/unchanged. Valid
 values are "LOWER", "UPPER", "NONE". Default is "NONE".
-*@li local : string (default is "en_US").
+*@li locale: string (default is "C").
 *Environment dependent string that denotes the locale according to which output
-strings needs to be upper/lowercased.Default en_US or platform specific equivalent
-as decided by the implementation . \n
+strings need to be upper/lowercased. Default C or platform specific equivalent
+as decided by the implementation. \n
 *@attention Constraints:
-*@li input can be either a 1-D or 2-D tensor, the shape of 2-D tensor must be [1, C].
+*input can be either a 1-D or 2-D tensor, the shape of 2-D tensor must be [1, C].
 */
 REG_OP(StringNormalizer)
     .INPUT(input, TensorType({DT_STRING}))
     .OUTPUT(output, TensorType({DT_STRING}))
     .ATTR(stopwords, ListString, {})
     .ATTR(is_case_sensitive, Bool, false)
     .ATTR(case_change_action, String, "NONE")
-    .ATTR(local, String, "en_US")
+    .ATTR(locale, String, "C")
     .OP_END_FACTORY_REG(StringNormalizer)
 } // namespace ge
diff --git a/third_party/fwkacllib/inc/ops/transformation_ops.h b/third_party/fwkacllib/inc/ops/transformation_ops.h
index 4a46e35f..f403fe12 100644
--- a/third_party/fwkacllib/inc/ops/transformation_ops.h
+++ b/third_party/fwkacllib/inc/ops/transformation_ops.h
@@ -29,15 +29,15 @@ namespace ge {
 *@par Inputs:
 *The input handle must have the resource type. Inputs include:
-*@li x:A list of Tensor objects. One or more tensors from which
+*x: A list of Tensor objects. One or more tensors from which
 the enqueued tensors should be taken . \n
 *@par Outputs:
-*@li y:A list of Tensor objects. One or more tensors from which
+*y: A list of Tensor objects. One or more tensors from which
 the enqueued tensors should be taken . \n
 *@par Attributes:
-*@li type: An optional ge::DataType. It refers to the target data type of outputs . \n
+*type: An optional ge::DataType. It refers to the target data type of outputs . \n
 *@par Third-party framework compatibility
 *Compatible with tensorflow QueueIsClosed operator.
@@ -723,11 +723,12 @@ REG_OP(CompressFcOp)
 *@brief Performs Col2im for each batch entry. \n
 *@par Inputs:
-*@li input_x: The Col Tensor. 5-D, shape: `(n, c1, kernel_h*kernel_w, ho*wo, c0)`.
-where ho/wo is do = (output_d + 2*padding_d - dilation_d*(kernel_d - 1) - 1)//stride_d + 1 \n
+*@li x: The Col Tensor. 4-D, shape: `(n, c, kernel_h*kernel_w, ho*wo)`.
+where ho/wo is do = (output_d + 2*padding_d - dilation_d*(kernel_d - 1) - 1)//stride_d + 1.
+*@li output_size: The img shape Tensor. 1-D, shape:`(2)`, value: (output_h, output_w). \n
 *@par Outputs:
-*@li output_y: The img Tensor. 5-D, shape: `(n, c1, output_h, output_w, c0)`. \n
+*y: The img Tensor. 4-D, shape: `(n, c, output_h, output_w)`. \n
 *@par Attributes:
 *@li kernel_shape: ListInt, value: `(kernel_h, kernel_w)`, the shape of kernel in convolution.
@@ -909,7 +910,7 @@ output shape would be [max(ngram_indexes) + 1]. If input shape is [N, C], this o
If input shape is [N, C], this o *@li either pool_strings or pool_int64s attributes must be present but not both. */ -REG_OP(TfidVectorizer) +REG_OP(TfIdfVectorizer) .INPUT(input, TensorType({DT_INT32, DT_INT64, DT_STRING})) .OUTPUT(output, TensorType({DT_FLOAT})) .REQUIRED_ATTR(max_gram_length, Int) @@ -921,7 +922,7 @@ REG_OP(TfidVectorizer) .ATTR(pool_int64s, ListInt, {}) .ATTR(pool_strings, ListString, {}) .ATTR(weights, ListFloat, {}) - .OP_END_FACTORY_REG(TfidVectorizer) + .OP_END_FACTORY_REG(TfIdfVectorizer) } // namespace ge #endif // OPS_BUILT_IN_OP_PROTO_INC_TRANSFORMATION_OPS_H_ diff --git a/third_party/fwkacllib/inc/opt_info/opt_info.h b/third_party/fwkacllib/inc/opt_info/opt_info.h new file mode 100644 index 00000000..4dff695b --- /dev/null +++ b/third_party/fwkacllib/inc/opt_info/opt_info.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace gelc { +using Status = uint32_t; +using WorkMode = uint32_t; +const Status SUCCESS = 0x0; +const Status FAILED = 0xFFFFFFFF; +const WorkMode kOffline = 0x0; +const WorkMode kInline = 0x01; + +__attribute__((visibility ("default"))) +Status GetOptInfo(WorkMode mode, const std::string &soc_ver, + std::map &opt_info_map); +} // namespace gelc + diff --git a/third_party/fwkacllib/inc/runtime/base.h b/third_party/fwkacllib/inc/runtime/base.h index 7fc1cdea..70e42dc9 100644 --- a/third_party/fwkacllib/inc/runtime/base.h +++ b/third_party/fwkacllib/inc/runtime/base.h @@ -20,7 +20,7 @@ #include #include "toolchain/prof_callback.h" -#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +#if defined(__cplusplus) extern "C" { #endif @@ -357,7 +357,7 @@ RTS_API rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_ */ RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId); -#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +#if defined(__cplusplus) } #endif diff --git a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h index 66ae36ef..76836e7b 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -19,7 +19,7 @@ #include "base.h" -#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +#if defined(__cplusplus) extern "C" { #endif @@ -43,6 +43,7 @@ typedef enum tagRtChipType { CHIP_LHISI, CHIP_DC, CHIP_CLOUD_V2, + CHIP_NO_DEVICE, CHIP_END, } rtChipType_t; @@ -52,6 +53,14 @@ typedef enum tagRtAicpuScheType { SCHEDULE_HARDWARE, /* HWTS Schedule */ } rtAicpuScheType; +typedef enum tagRtDeviceCapabilityType { + RT_SCHEDULE_SOFTWARE = 0, // Software Schedule + RT_SCHEDULE_SOFTWARE_OPT, + RT_SCHEDULE_HARDWARE, // HWTS Schedule + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, + RT_AICPU_BLOCKING_OP_SUPPORT, // 1910/1980/1951 ts support AICPU blocking operation +} rtDeviceCapabilityType; + typedef enum tagRtVersion { VER_BEGIN = 0, VER_NA = VER_BEGIN, @@ -71,6 +80,7 @@ typedef enum tagRtPlatformType { 
     PLATFORM_LHISI_CS,
     PLATFORM_DC,
     PLATFORM_CLOUD_V2,
+    PLATFORM_LHISI_SD3403,
     PLATFORM_END,
 } rtPlatformType_t;
@@ -226,7 +236,7 @@ RTS_API rtError_t rtSetOpWaitTimeOut(uint32_t timeout);
 */
 RTS_API rtError_t rtSetOpExecuteTimeOut(uint32_t timeout);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/context.h b/third_party/fwkacllib/inc/runtime/context.h
index e95d4c89..c597a657 100644
--- a/third_party/fwkacllib/inc/runtime/context.h
+++ b/third_party/fwkacllib/inc/runtime/context.h
@@ -19,7 +19,7 @@
 #include "base.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -157,7 +157,7 @@ RTS_API rtError_t rtGetGroupCount(uint32_t *count);
 */
 RTS_API rtError_t rtSetCtxINFMode(bool mode);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h
index 2cf6712f..4a9a5817 100644
--- a/third_party/fwkacllib/inc/runtime/dev.h
+++ b/third_party/fwkacllib/inc/runtime/dev.h
@@ -19,7 +19,7 @@
 #include "base.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -65,6 +65,7 @@ typedef enum tagRtFeatureType {
 typedef enum tagRtDeviceFeatureType {
     FEATURE_TYPE_SCHE,
+    FEATURE_TYPE_BLOCKING_OPERATOR,
     FEATURE_TYPE_END,
 } rtDeviceFeatureType_t;
@@ -78,6 +79,17 @@ typedef enum tagMemoryInfo {
     MEMORY_INFO_RSV
 } rtMemoryInfo_t;
+typedef enum tagRtDeviceModuleType {
+    RT_MODULE_TYPE_SYSTEM = 0,  /**< system info*/
+    RT_MODULE_TYPE_AICPU,       /**< aicpu info*/
+    RT_MODULE_TYPE_CCPU,        /**< ccpu_info*/
+    RT_MODULE_TYPE_DCPU,        /**< dcpu info*/
+    RT_MODULE_TYPE_AICORE,      /**< AI CORE info*/
+    RT_MODULE_TYPE_TSCPU,       /**< tscpu info*/
+    RT_MODULE_TYPE_PCIE,        /**< PCIE info*/
+    RT_MODULE_TYPE_VECTOR_CORE, /**< VECTOR CORE info*/
+} rtDeviceModuleType_t;
+
 /**
  * @ingroup dvrt_dev
  * @brief get total device number.
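The new rtDeviceModuleType_t and rtDeviceCapabilityType enums above pair up when querying what a device supports, for example whether the AICPU handles blocking operators. A hedged sketch of such a query; rtGetDeviceCapability and its exact signature are assumptions, since the function itself is not shown in these hunks:

```
#include "runtime/dev.h"  // rtDeviceModuleType_t, rtDeviceFeatureType_t, capability values

// Hedged sketch: ask whether the AICPU on "deviceId" supports blocking
// operators, combining the enums introduced above. rtGetDeviceCapability
// (device, moduleType, featureType, value) is assumed, not shown here.
bool AicpuSupportsBlockingOp(int32_t deviceId) {
  int32_t value = static_cast<int32_t>(RT_AICPU_BLOCKING_OP_NOT_SUPPORT);
  rtError_t ret = rtGetDeviceCapability(deviceId,
                                        static_cast<int32_t>(RT_MODULE_TYPE_AICPU),
                                        static_cast<int32_t>(FEATURE_TYPE_BLOCKING_OPERATOR),
                                        &value);
  return (ret == RT_ERROR_NONE) &&
         (value == static_cast<int32_t>(RT_AICPU_BLOCKING_OP_SUPPORT));
}
```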
@@ -368,7 +380,7 @@ RTS_API rtError_t rtSetDeviceWithoutTsd(int32_t device);
 */
 RTS_API rtError_t rtDeviceResetWithoutTsd(int32_t device);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/dvfsprofile.h b/third_party/fwkacllib/inc/runtime/dvfsprofile.h
index 6e451695..33e2f4c1 100644
--- a/third_party/fwkacllib/inc/runtime/dvfsprofile.h
+++ b/third_party/fwkacllib/inc/runtime/dvfsprofile.h
@@ -19,7 +19,7 @@
 #include "base.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -56,7 +56,7 @@ RTS_API rtError_t rtUnsetDvfsProfile();
 */
 RTS_API rtError_t rtGetDvfsProfile(DvfsProfileMode *pmode);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/event.h b/third_party/fwkacllib/inc/runtime/event.h
index 9e555230..81b635c3 100644
--- a/third_party/fwkacllib/inc/runtime/event.h
+++ b/third_party/fwkacllib/inc/runtime/event.h
@@ -19,16 +19,22 @@
 #include "base.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
+typedef enum rtEventWaitStatus {
+    EVENT_STATUS_COMPLETE = 0,
+    EVENT_STATUS_NOT_READY = 1,
+    EVENT_STATUS_MAX = 2,
+} rtEventWaitStatus_t;
+
 /**
  * @ingroup event_flags
  * @brief event op bit flags
  */
-#define RT_EVENT_DEFAULT (0x00)
-#define RT_EVENT_WITH_FLAG (0x01)
+#define RT_EVENT_DEFAULT (0x0E)
+#define RT_EVENT_WITH_FLAG (0x0B)
 #define RT_EVENT_DDSYNC_NS 0x01U
 #define RT_EVENT_STREAM_MARK 0x02U
@@ -111,6 +117,16 @@ RTS_API rtError_t rtEventQuery(rtEvent_t event);
 /**
  * @ingroup dvrt_event
+ * @brief Queries an event's wait status
+ * @param [in] event event to query
+ * @param [in,out] status event wait status
+ * @return EVENT_STATUS_COMPLETE for complete
+ * @return EVENT_STATUS_NOT_READY for not complete
+ */
+RTS_API rtError_t rtEventQueryWaitStatus(rtEvent_t event, rtEventWaitStatus_t *status);
+
+/**
+ * @ingroup dvrt_event
  * @brief computes the elapsed time between events.
  * @param [in] time time between start and end in ms
  * @param [in] start starting event
@@ -256,7 +272,7 @@ RTS_API rtError_t rtNotifyGetAddrOffset(rtNotify_t notify, uint64_t *devAddrOffs
 */
 RTS_API rtError_t rtSetIpcNotifyPid(const char *name, int32_t pid[], int num);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/kernel.h b/third_party/fwkacllib/inc/runtime/kernel.h
index c79ee7a5..c1b9bd6d 100644
--- a/third_party/fwkacllib/inc/runtime/kernel.h
+++ b/third_party/fwkacllib/inc/runtime/kernel.h
@@ -20,7 +20,7 @@
 #include "base.h"
 #include "stream.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -112,6 +112,16 @@ typedef struct rtKernelInfo {
 } *rtKernelInfo_t;
 /**
+ * @ingroup rt_kernel
+ * @brief op name
+ */
+typedef struct rtKernelLaunchNames {
+    const char *soName;     // defined for so name
+    const char *kernelName; // defined for kernel type name
+    const char *opName;     // defined for operator name
+} rtKernelLaunchNames_t;
+
+/**
  * @ingroup rt_KernelConfigDump
  * @brief device dump type
  */
@@ -365,7 +375,7 @@ RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim
                                          rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags);
 /**
- * @ingroup rt_kernel
+ * @ingroup rt_kernel(abandoned)
  * @brief launch kernel to device
  * @param [in] args argments address for kernel function
 * @param [in] argsSize argements size
@@ -377,7 +387,21 @@ RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream);
 /**
- * @ingroup rt_kernel
+ * @ingroup rt_kernel(in use)
+ * @brief launch kernel to device
+ * @param [in] opName opkernel name
+ * @param [in] args arguments address for kernel function
+ * @param [in] argsSize arguments size
+ * @param [in] flags launch flags
+ * @param [in] stream associated stream
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_VALUE for error input
+ */
+RTS_API rtError_t rtKernelLaunchFwk(const char *opName, void *args, uint32_t argsSize, uint32_t flags,
+                                    rtStream_t rtStream);
+
+/**
+ * @ingroup rt_kernel(abandoned)
  * @brief launch cpu kernel to device
  * @param [in] soName so name
  * @param [in] kernelName kernel name
@@ -393,7 +417,22 @@ RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName,
                                     uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream);
 /**
- * @ingroup rt_kernel
+ * @ingroup rt_kernel(in use)
+ * @brief launch cpu kernel to device
+ * @param [in] launchNames names for kernel launch
+ * @param [in] blockDim block dimensions
+ * @param [in] args arguments address for kernel function
+ * @param [in] argsSize arguments size
+ * @param [in] smDesc shared memory description
+ * @param [in] stream associated stream
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_VALUE for error input
+ */
+RTS_API rtError_t rtAicpuKernelLaunch(const rtKernelLaunchNames_t *launchNames,
+    uint32_t blockDim, const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream);
+
+/**
+ * @ingroup rt_kernel(abandoned)
 * @brief launch cpu kernel to device with dump identifier
 * @param [in] soName so name
 * @param [in] kernelName kernel name
@@ -411,6 +450,22 @@ RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kern
                                             uint32_t flags);
 /**
+ * @ingroup rt_kernel(in use)
+ * @brief launch cpu kernel to device with dump identifier
+ * @param [in] launchNames names for kernel launch
+ * @param [in] blockDim block dimensions
+ * @param [in] args arguments address for kernel function
+ * @param [in] argsSize arguments size
+ * @param [in] smDesc shared memory description
+ * @param [in] stream associated stream
+ * @param [in] flag dump flag or other function flags
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_VALUE for error input
+ */
+RTS_API rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim,
+    const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags);
+
+/**
  * @ingroup rt_kernel
  * @brief L1 fusion dump addr transfered to device
  * @param [in] model handle info
@@ -592,7 +647,7 @@ RTS_API rtError_t rtStartMDCProfiler(void **addr, uint32_t length);
 */
 RTS_API rtError_t rtStopMDCProfiler(void *addr);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/mem.h b/third_party/fwkacllib/inc/runtime/mem.h
index bace4bc6..b049e762 100644
--- a/third_party/fwkacllib/inc/runtime/mem.h
+++ b/third_party/fwkacllib/inc/runtime/mem.h
@@ -24,7 +24,7 @@
 #include "config.h"
 #include "stream.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -547,7 +547,7 @@ RTS_API rtError_t rtSetIpcMemPid(const char *name, int32_t pid[], int num);
 */
 RTS_API rtError_t rtRDMADBSend(uint32_t dbIndex, uint64_t dbInfo, rtStream_t stream);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/rt_ffts.h b/third_party/fwkacllib/inc/runtime/rt_ffts.h
index bae5a54d..f2809218 100644
--- a/third_party/fwkacllib/inc/runtime/rt_ffts.h
+++ b/third_party/fwkacllib/inc/runtime/rt_ffts.h
@@ -8,7 +8,7 @@
 #include "base.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -178,7 +178,7 @@ typedef struct tagFftsTaskInfo {
 RTS_API rtError_t rtFftsTaskLaunch(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stream);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
 #endif  // __CCE_RUNTIME_FFTS_H
diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h
index a7618b45..d4af72c5 100644
--- a/third_party/fwkacllib/inc/runtime/rt_model.h
+++ b/third_party/fwkacllib/inc/runtime/rt_model.h
@@ -19,7 +19,7 @@
 #include "base.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -490,7 +490,7 @@ RTS_API rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *ad
 */
 RTS_API rtError_t rtDebugUnRegister(rtModel_t model);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/runtime/rt_stars.h b/third_party/fwkacllib/inc/runtime/rt_stars.h
index 188656b1..016c352a 100644
--- a/third_party/fwkacllib/inc/runtime/rt_stars.h
+++ b/third_party/fwkacllib/inc/runtime/rt_stars.h
@@ -8,7 +8,7 @@
 #include "base.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -23,6 +23,7 @@ extern "C" {
 */
 RTS_API rtError_t rtStarsTaskLaunch(const void *taskSqe, uint32_t sqeLen, rtStream_t stream);
+
 /**
  * @ingroup rt_stars
  * @brief create cdq instance.
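The kernel.h hunk above replaces the separate so-name and kernel-name arguments of the abandoned launch APIs with the rtKernelLaunchNames_t bundle. A hedged usage sketch of the new rtAicpuKernelLaunch; the library, kernel, and op names, the args buffer, and the nullptr smDesc are all placeholders, not real kernels:

```
#include "runtime/kernel.h"  // rtKernelLaunchNames_t and the launch APIs above

// Hypothetical launch of an AICPU kernel through the new name-based API.
rtError_t LaunchExample(rtStream_t stream) {
  rtKernelLaunchNames_t names;
  names.soName = "libexample_aicpu.so";  // so name (placeholder)
  names.kernelName = "RunCpuKernel";     // kernel type name (placeholder)
  names.opName = "ExampleOp";            // operator name (placeholder)

  uint32_t args[4] = {0, 0, 0, 0};       // kernel arguments (placeholder)
  const uint32_t blockDim = 1;
  // Passing nullptr for smDesc when no shared-memory description is needed
  // is an assumption carried over from typical rtCpuKernelLaunch usage.
  return rtAicpuKernelLaunch(&names, blockDim, args,
                             static_cast<uint32_t>(sizeof(args)),
                             nullptr, stream);
}
```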
@@ -76,10 +77,11 @@ RTS_API rtError_t rtCdqEnQueue(const char *queName, uint32_t cdqeIndex, void *data
 * @param [in] stream launch task on the stream
 * @return RT_ERROR_NONE for ok, others failed
 */
-RTS_API rtError_t rtCdqEnQueuePtrMode(const char *queName, uint32_t cdqeIndex, const void *prtAddr,
+RTS_API rtError_t rtCdqEnQueuePtrMode(const char *queName, uint32_t cdqeIndex, const void *ptrAddr,
     rtStream_t stream);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
+
 }
 #endif
 #endif  // __CCE_RUNTIME_STARS_H
diff --git a/third_party/fwkacllib/inc/runtime/stream.h b/third_party/fwkacllib/inc/runtime/stream.h
index f9981514..3a078e99 100644
--- a/third_party/fwkacllib/inc/runtime/stream.h
+++ b/third_party/fwkacllib/inc/runtime/stream.h
@@ -20,7 +20,7 @@
 #include "base.h"
 #include "event.h"
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 extern "C" {
 #endif
@@ -211,7 +211,7 @@ RTS_API rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, con
 */
 RTS_API rtError_t rtDebugUnRegisterForStream(rtStream_t stream);
-#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE)
+#if defined(__cplusplus)
 }
 #endif
diff --git a/third_party/fwkacllib/inc/toolchain/prof_acl_api.h b/third_party/fwkacllib/inc/toolchain/prof_acl_api.h
index 07b32149..9350f9d4 100644
--- a/third_party/fwkacllib/inc/toolchain/prof_acl_api.h
+++ b/third_party/fwkacllib/inc/toolchain/prof_acl_api.h
@@ -84,6 +84,7 @@
 #endif
 #include <stdint.h>
+#include <stddef.h>
 namespace Msprofiler {
 namespace Api {
@@ -105,6 +106,37 @@ extern "C" {
 MSVP_PROF_API uint64_t ProfGetOpExecutionTime(const void *data, uint32_t len, uint32_t index);
+typedef int Status;
+typedef struct aclprofSubscribeConfig aclprofSubscribeConfig1;
+///
+/// @ingroup AscendCL
+/// @brief subscribe profiling data of graph
+/// @param [in] graphId: the graph id subscribed
+/// @param [in] profSubscribeConfig: pointer to config of model subscribe
+/// @return Status result of function
+///
+Status aclgrphProfGraphSubscribe(const uint32_t graphId,
+                                 const aclprofSubscribeConfig1 *profSubscribeConfig);
+
+///
+/// @ingroup AscendCL
+/// @brief unsubscribe profiling data of graph
+/// @param [in] graphId: the graph id subscribed
+/// @return Status result of function
+///
+Status aclgrphProfGraphUnSubscribe(const uint32_t graphId);
+
+/**
+ * @ingroup AscendCL
+ * @brief get graph id from subscription data
+ *
+ * @param opInfo [IN] pointer to subscription data
+ * @param opInfoLen [IN] memory size of subscription data
+ *
+ * @retval graph id of subscription data
+ * @retval 0 for failed
+ */
+size_t aclprofGetGraphId(const void *opInfo, size_t opInfoLen, uint32_t index);
 #ifdef __cplusplus
 }
 #endif
diff --git a/third_party/fwkacllib/inc/toolchain/prof_callback.h b/third_party/fwkacllib/inc/toolchain/prof_callback.h
index 3666c36c..36b55216 100644
--- a/third_party/fwkacllib/inc/toolchain/prof_callback.h
+++ b/third_party/fwkacllib/inc/toolchain/prof_callback.h
@@ -55,6 +55,17 @@ struct ReporterData {
 };
 /**
+ * @name HashData
+ * @brief struct of data to hash
+ */
+struct HashData {
+    int deviceId;        // the index of device
+    size_t dataLen;      // the length of data
+    unsigned char *data; // the data content
+    uint64_t hashId;     // the id of hashed data
+};
+
+/**
  * @name MsprofReporterModuleId
 * @brief module id of data to report
 */
@@ -75,6 +86,7 @@ enum MsprofReporterCallbackType {
     MSPROF_REPORTER_INIT,         // init reporter
     MSPROF_REPORTER_UNINIT,       // uninit reporter
     MSPROF_REPORTER_DATA_MAX_LEN, // data max length for calling report callback
+    MSPROF_REPORTER_HASH          // hash data to id
 };
 /**
@@ -110,7 +122,8 @@ enum MsprofCtrlCallbackType {
     MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options
     MSPROF_CTRL_FINALIZE,        // stop profiling
     MSPROF_CTRL_REPORT_FUN_P,    // for report callback
-    MSPROF_CTRL_PROF_SWITCH      // for prof switch
+    MSPROF_CTRL_PROF_SWITCH_ON,  // for prof switch on
+    MSPROF_CTRL_PROF_SWITCH_OFF  // for prof switch off
 };
 #define MSPROF_MAX_DEV_NUM (64)
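The new HashData struct and MSPROF_REPORTER_HASH callback type above suggest a hash-to-id round trip through the reporter callback. A hedged sketch of that flow; the callback signature below is an assumption (the reporter typedef is not shown in this hunk), and moduleId is expected to come from MsprofReporterModuleId:

```
#include <cstddef>
#include <cstdint>
#include "toolchain/prof_callback.h"  // HashData, MSPROF_REPORTER_HASH

// Assumed shape of the reporter callback: (moduleId, type, data, len).
typedef int32_t (*ReporterCallback)(uint32_t moduleId, uint32_t type,
                                    void *data, uint32_t len);

// Hedged sketch: fill the new HashData struct and ask the reporter to hash
// the buffer; on success the reporter is expected to write back hashId.
uint64_t HashThroughReporter(ReporterCallback report, uint32_t moduleId,
                             int deviceId, unsigned char *buf, size_t bufLen) {
  HashData hash;
  hash.deviceId = deviceId;  // the index of device
  hash.dataLen = bufLen;     // the length of data
  hash.data = buf;           // the data content
  hash.hashId = 0;           // filled by the reporter on success (assumption)
  (void)report(moduleId, MSPROF_REPORTER_HASH, &hash,
               static_cast<uint32_t>(sizeof(hash)));
  return hash.hashId;
}
```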