Browse Source

!707 update c76 code

Merge pull request !707 from 王涛/r1.2.0
pull/707/MERGE
计晨 Gitee 4 years ago
parent
commit
ca855a5bf7
100 changed files with 1342 additions and 1744 deletions
  1. +2
    -2
      .gitmodules
  2. +6
    -8
      build.sh
  3. +1
    -1
      cmake/FindModule.cmake
  4. +0
    -1
      cmake/external_libs/gflags.cmake
  5. +2
    -6
      cmake/external_libs/gtest.cmake
  6. +7
    -12
      cmake/external_libs/json.cmake
  7. +1
    -5
      cmake/external_libs/onnx.cmake
  8. +0
    -1
      cmake/external_libs/protobuf_shared.cmake
  9. +0
    -1
      cmake/external_libs/protobuf_static.cmake
  10. +0
    -1
      cmake/external_libs/protoc.cmake
  11. +2
    -11
      cmake/external_libs/securec.cmake
  12. +0
    -4
      ge/CMakeLists.txt
  13. +1
    -14
      ge/common/dump/dump_op.cc
  14. +2
    -2
      ge/common/ge/tbe_plugin_manager.cc
  15. +1
    -1
      ge/common/profiling/ge_profiling.cc
  16. +8
    -14
      ge/common/profiling/profiling_manager.cc
  17. +0
    -2
      ge/common/proto/op_mapping_info.proto
  18. +81
    -36
      ge/executor/ge_executor.cc
  19. +0
    -2
      ge/executor/proto/op_mapping_info.proto
  20. +0
    -2
      ge/ge_inference.mk
  21. +3
    -2
      ge/ge_local_engine/engine/host_cpu_engine.cc
  22. +0
    -2
      ge/ge_runner.mk
  23. +21
    -19
      ge/generator/ge_generator.cc
  24. +1
    -1
      ge/graph/build/memory/graph_mem_assigner.cc
  25. +0
    -1
      ge/graph/build/model_builder.cc
  26. +8
    -6
      ge/graph/build/stream_graph_optimizer.cc
  27. +67
    -10
      ge/graph/load/graph_loader.cc
  28. +6
    -0
      ge/graph/load/graph_loader.h
  29. +0
    -6
      ge/graph/load/new_model_manager/data_dumper.cc
  30. +343
    -152
      ge/graph/load/new_model_manager/davinci_model.cc
  31. +53
    -33
      ge/graph/load/new_model_manager/davinci_model.h
  32. +35
    -40
      ge/graph/load/new_model_manager/model_manager.cc
  33. +2
    -1
      ge/graph/load/new_model_manager/model_manager.h
  34. +74
    -58
      ge/graph/load/new_model_manager/task_info/kernel_task_info.cc
  35. +0
    -2
      ge/graph/load/new_model_manager/task_info/kernel_task_info.h
  36. +7
    -3
      ge/graph/load/new_model_manager/zero_copy_offset.cc
  37. +1
    -1
      ge/graph/load/new_model_manager/zero_copy_offset.h
  38. +45
    -2
      ge/graph/load/new_model_manager/zero_copy_task.cc
  39. +7
    -1
      ge/graph/load/new_model_manager/zero_copy_task.h
  40. +22
    -17
      ge/graph/manager/graph_manager.cc
  41. +1
    -2
      ge/graph/manager/graph_manager.h
  42. +3
    -0
      ge/graph/manager/graph_mem_allocator.cc
  43. +6
    -1
      ge/graph/optimize/graph_optimize.cc
  44. +2
    -1
      ge/graph/optimize/graph_optimize.h
  45. +23
    -5
      ge/graph/passes/attach_stream_label_pass.cc
  46. +3
    -1
      ge/graph/passes/attach_stream_label_pass.h
  47. +1
    -1
      ge/graph/passes/base_pass.cc
  48. +0
    -64
      ge/graph/passes/dimension_adjust_pass.cc
  49. +0
    -4
      ge/graph/passes/dimension_adjust_pass.h
  50. +7
    -57
      ge/graph/passes/enter_pass.cc
  51. +1
    -2
      ge/graph/passes/enter_pass.h
  52. +4
    -1
      ge/graph/passes/folding_pass.cc
  53. +10
    -0
      ge/graph/passes/merge_to_stream_merge_pass.cc
  54. +173
    -89
      ge/graph/passes/next_iteration_pass.cc
  55. +13
    -3
      ge/graph/passes/next_iteration_pass.h
  56. +0
    -106
      ge/graph/passes/remove_same_const_pass.cc
  57. +0
    -28
      ge/graph/passes/remove_same_const_pass.h
  58. +8
    -4
      ge/graph/passes/switch_to_stream_switch_pass.cc
  59. +0
    -51
      ge/graph/passes/useless_control_out_remove_pass.cc
  60. +0
    -29
      ge/graph/passes/useless_control_out_remove_pass.h
  61. +59
    -343
      ge/graph/preprocess/multi_batch_copy_graph.cc
  62. +1
    -15
      ge/graph/preprocess/multi_batch_copy_graph.h
  63. +0
    -2
      ge/hybrid/executor/hybrid_execution_context.h
  64. +0
    -1
      ge/hybrid/executor/hybrid_model_executor.cc
  65. +22
    -34
      ge/hybrid/executor/node_state.cc
  66. +1
    -2
      ge/hybrid/executor/node_state.h
  67. +8
    -1
      ge/hybrid/executor/subgraph_executor.cc
  68. +12
    -16
      ge/hybrid/executor/worker/execution_engine.cc
  69. +18
    -103
      ge/hybrid/executor/worker/shape_inference_engine.cc
  70. +0
    -4
      ge/hybrid/executor/worker/shape_inference_engine.h
  71. +0
    -3
      ge/hybrid/executor/worker/task_compile_engine.cc
  72. +1
    -4
      ge/hybrid/model/hybrid_model_builder.cc
  73. +34
    -57
      ge/hybrid/model/node_item.cc
  74. +0
    -6
      ge/hybrid/model/node_item.h
  75. +0
    -10
      ge/hybrid/node_executor/aicore/aicore_node_executor.cc
  76. +0
    -11
      ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc
  77. +0
    -38
      ge/hybrid/node_executor/task_context.cc
  78. +0
    -10
      ge/hybrid/node_executor/task_context.h
  79. +3
    -4
      ge/ir_build/atc_ir_common.cc
  80. +1
    -1
      ge/ir_build/atc_ir_common.h
  81. +1
    -1
      ge/offline/main.cc
  82. +0
    -2
      ge/proto/op_mapping_info.proto
  83. +4
    -8
      ge/single_op/single_op.cc
  84. +4
    -4
      ge/single_op/task/op_task.cc
  85. +14
    -7
      inc/external/ge/ge_api_types.h
  86. +0
    -4
      inc/framework/common/ge_types.h
  87. +13
    -0
      inc/framework/executor/ge_executor.h
  88. +32
    -34
      inc/framework/omg/parser/model_parser.h
  89. +1
    -1
      metadef
  90. +1
    -1
      parser
  91. +0
    -5
      tests/depends/runtime/src/runtime_stub.cc
  92. +35
    -44
      tests/ut/common/graph/CMakeLists.txt
  93. +7
    -29
      tests/ut/ge/CMakeLists.txt
  94. +0
    -2
      tests/ut/ge/graph/build/mem_assigner_unittest.cc
  95. +1
    -0
      tests/ut/ge/graph/passes/folding_kernel/broadcast_args_kernel_unittest.cc
  96. +1
    -0
      tests/ut/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel_unittest.cc
  97. +1
    -0
      tests/ut/ge/graph/passes/folding_kernel/empty_kernel_unittest.cc
  98. +0
    -1
      tests/ut/ge/graph/passes/variable_op_pass_unittest.cc
  99. +2
    -3
      tests/ut/ge/graph_ir/ge_operator_factory_unittest.cc
  100. +1
    -1
      tests/ut/ge/single_op/single_op_model_unittest.cc

+ 2
- 2
.gitmodules View File

@@ -1,8 +1,8 @@
[submodule "parser"] [submodule "parser"]
path = parser path = parser
url = https://gitee.com/ascend/parser.git url = https://gitee.com/ascend/parser.git
branch = development
branch = r1.2.0
[submodule "metadef"] [submodule "metadef"]
path = metadef path = metadef
url = https://gitee.com/ascend/metadef.git url = https://gitee.com/ascend/metadef.git
branch = development
branch = r1.2.0

+ 6
- 8
build.sh View File

@@ -224,14 +224,12 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then
# fi # fi


# if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then # if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then
echo "Generating coverage statistics, please wait..."
cd ${BASEPATH}
rm -rf ${BASEPATH}/cov
mkdir ${BASEPATH}/cov
lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o cov/tmp.info
lcov --remove cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' -o cov/coverage.info
cd ${BASEPATH}/cov
genhtml coverage.info
# echo "Generating coverage statistics, please wait..."
# cd ${BASEPATH}
# rm -rf ${BASEPATH}/cov
# mkdir ${BASEPATH}/cov
# gcovr -r ./ --exclude 'third_party' --exclude 'build' --exclude 'tests' --exclude 'prebuild' --exclude 'inc' --print-summary --html --html-details -d -o cov/index.html
# fi
fi fi


# generate output package in tar form, including ut/st libraries/executables # generate output package in tar form, including ut/st libraries/executables


+ 1
- 1
cmake/FindModule.cmake View File

@@ -21,7 +21,7 @@ function(find_module module name)
if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND")
message(FATAL_ERROR "${name} not found in ${path}") message(FATAL_ERROR "${name} not found in ${path}")
endif() endif()
add_library(${module} SHARED IMPORTED) add_library(${module} SHARED IMPORTED)
set_target_properties(${module} PROPERTIES set_target_properties(${module} PROPERTIES
IMPORTED_LOCATION ${${module}_LIBRARY_DIR} IMPORTED_LOCATION ${${module}_LIBRARY_DIR}


+ 0
- 1
cmake/external_libs/gflags.cmake View File

@@ -23,7 +23,6 @@ ExternalProject_Add(gflags_build
URL ${REQ_URL} URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR>
BUILD_COMMAND $(MAKE) BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install INSTALL_COMMAND $(MAKE) install


+ 2
- 6
cmake/external_libs/gtest.cmake View File

@@ -10,10 +10,7 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif() endif()


if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/ge_gtest/release-1.8.0.tar.gz")
set(MD5 "")
elseif (ENABLE_GITEE)
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz")
set(MD5 "") set(MD5 "")
else() else()
@@ -25,9 +22,8 @@ set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-
set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(gtest_build ExternalProject_Add(gtest_build
URL ${REQ_URL} URL ${REQ_URL}
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR>
-DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON
-DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON
BUILD_COMMAND $(MAKE) BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install INSTALL_COMMAND $(MAKE) install
EXCLUDE_FROM_ALL TRUE EXCLUDE_FROM_ALL TRUE


+ 7
- 12
cmake/external_libs/json.cmake View File

@@ -5,24 +5,19 @@ endif()
include(ExternalProject) include(ExternalProject)


set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include)
if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip")
set(MD5 "0dc903888211db3a0f170304cd9f3a89")
set(JSON_INCLUDE_DIR ${JSON_SRC_DIR})
#elseif (ENABLE_GITEE)
#if (ENABLE_GITEE)
# set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") # set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
# set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") # set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
#set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include")
else()
set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
set(MD5 "0dc903888211db3a0f170304cd9f3a89")
set(JSON_INCLUDE_DIR ${JSON_SRC_DIR})
endif ()
# set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include")
#else()
set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
set(MD5 "0dc903888211db3a0f170304cd9f3a89")
set(JSON_INCLUDE_DIR ${JSON_SRC_DIR})
#endif ()
ExternalProject_Add(json_build ExternalProject_Add(json_build
URL ${REQ_URL} URL ${REQ_URL}
#URL /home/txd/workspace/cloud_code/pkg/include.zip #URL /home/txd/workspace/cloud_code/pkg/include.zip
SOURCE_DIR ${JSON_SRC_DIR} SOURCE_DIR ${JSON_SRC_DIR}
TLS_VERIFY OFF
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
INSTALL_COMMAND "" INSTALL_COMMAND ""


+ 1
- 5
cmake/external_libs/onnx.cmake View File

@@ -6,10 +6,7 @@ set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx)
set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto)
file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) file(MAKE_DIRECTORY ${ONNX_PROTO_DIR})


if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/onnx/onnx-1.6.0.tar.gz")
set(MD5 "512f2779d6215d4a36f366b6b9acdf1e")
elseif (ENABLE_GITEE)
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz")
set(MD5 "1bdbcecdd68ea8392630467646776e02") set(MD5 "1bdbcecdd68ea8392630467646776e02")
else() else()
@@ -22,7 +19,6 @@ ExternalProject_Add(onnx
#URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz
#URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345
#SOURCE_DIR ${ONNX_SRC_DIR} #SOURCE_DIR ${ONNX_SRC_DIR}
TLS_VERIFY OFF
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
#INSTALL_COMMAND "" #INSTALL_COMMAND ""


+ 0
- 1
cmake/external_libs/protobuf_shared.cmake View File

@@ -26,7 +26,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protobuf_build ExternalProject_Add(protobuf_build
URL ${REQ_URL} URL ${REQ_URL}
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} CONFIGURE_COMMAND ${CMAKE_COMMAND}
-Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_WITH_ZLIB=OFF
-DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR}


+ 0
- 1
cmake/external_libs/protobuf_static.cmake View File

@@ -27,7 +27,6 @@ ExternalProject_Add(protobuf_static_build
URL ${REQ_URL} URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} CONFIGURE_COMMAND ${CMAKE_COMMAND}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}


+ 0
- 1
cmake/external_libs/protoc.cmake View File

@@ -30,7 +30,6 @@ ExternalProject_Add(protoc_build
URL ${REQ_URL} URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
BUILD_COMMAND $(MAKE) BUILD_COMMAND $(MAKE)
INSTALL_COMMAND $(MAKE) install INSTALL_COMMAND $(MAKE) install


+ 2
- 11
cmake/external_libs/securec.cmake View File

@@ -10,20 +10,11 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif() endif()


if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/securec/v1.1.10.tar.gz")
set(MD5 "")
else()
set(REQ_URL "https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz")
set(MD5 "")
endif ()

ExternalProject_Add(c_sec_build ExternalProject_Add(c_sec_build
URL ${REQ_URL}
#URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz
URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../libc_sec #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec
PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} CONFIGURE_COMMAND ${CMAKE_COMMAND}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}


+ 0
- 4
ge/CMakeLists.txt View File

@@ -157,8 +157,6 @@ set(TRAIN_SRC_LIST
"graph/passes/compile_nodes_pass.cc" "graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc" "graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc" "graph/passes/constant_fuse_same_pass.cc"
"graph/passes/remove_same_const_pass.cc"
"graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/control_trigger_pass.cc" "graph/passes/control_trigger_pass.cc"
"graph/passes/dimension_adjust_pass.cc" "graph/passes/dimension_adjust_pass.cc"
"graph/passes/dimension_compute_pass.cc" "graph/passes/dimension_compute_pass.cc"
@@ -524,8 +522,6 @@ set(INFER_SRC_LIST
"graph/passes/assign_pass.cc" "graph/passes/assign_pass.cc"
"graph/passes/addn_pass.cc" "graph/passes/addn_pass.cc"
"graph/passes/common_subexpression_elimination_pass.cc" "graph/passes/common_subexpression_elimination_pass.cc"
"graph/passes/remove_same_const_pass.cc"
"graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/transop_symmetry_elimination_pass.cc" "graph/passes/transop_symmetry_elimination_pass.cc"
"graph/passes/save_pass.cc" "graph/passes/save_pass.cc"
"graph/passes/switch_dead_branch_elimination.cc" "graph/passes/switch_dead_branch_elimination.cc"


+ 1
- 14
ge/common/dump/dump_op.cc View File

@@ -94,9 +94,6 @@ Status DumpOp::DumpOutput(aicpu::dump::Task &task) {
for (auto dim : output_descs.at(i).GetShape().GetDims()) { for (auto dim : output_descs.at(i).GetShape().GetDims()) {
output.mutable_shape()->add_dim(dim); output.mutable_shape()->add_dim(dim);
} }
for (auto dim : output_descs.at(i).GetOriginShape().GetDims()) {
output.mutable_origin_shape()->add_dim(dim);
}
int64_t output_size = 0; int64_t output_size = 0;
if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) {
GELOGE(PARAM_INVALID, "Get output size filed"); GELOGE(PARAM_INVALID, "Get output size filed");
@@ -121,9 +118,6 @@ Status DumpOp::DumpInput(aicpu::dump::Task &task) {
for (auto dim : input_descs.at(i).GetShape().GetDims()) { for (auto dim : input_descs.at(i).GetShape().GetDims()) {
input.mutable_shape()->add_dim(dim); input.mutable_shape()->add_dim(dim);
} }
for (auto dim : input_descs.at(i).GetOriginShape().GetDims()) {
input.mutable_origin_shape()->add_dim(dim);
}
int64_t input_size = 0; int64_t input_size = 0;
if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) {
GELOGE(PARAM_INVALID, "Get output size filed"); GELOGE(PARAM_INVALID, "Get output size filed");
@@ -220,15 +214,8 @@ Status DumpOp::LaunchDumpOp() {
SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info);
GELOGI("Dump step is %s ,dump path is %s ,in Launch dump op", dump_properties_.GetDumpStep().c_str(), GELOGI("Dump step is %s ,dump path is %s ,in Launch dump op", dump_properties_.GetDumpStep().c_str(),
dump_path.c_str()); dump_path.c_str());
uint32_t task_id = 0;
uint32_t stream_id = 0;
rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGW("call rtGetTaskIdAndStreamID failed, ret = 0x%X", rt_ret);
}

aicpu::dump::Task task; aicpu::dump::Task task;
task.set_task_id(task_id);
task.set_stream_id(stream_id);
task.mutable_op()->set_op_name(op_desc_->GetName()); task.mutable_op()->set_op_name(op_desc_->GetName());
task.mutable_op()->set_op_type(op_desc_->GetType()); task.mutable_op()->set_op_type(op_desc_->GetType());
if (dump_properties_.GetDumpMode() == kDumpOutput) { if (dump_properties_.GetDumpMode() == kDumpOutput) {


+ 2
- 2
ge/common/ge/tbe_plugin_manager.cc View File

@@ -184,7 +184,7 @@ void TBEPluginManager::LoadCustomOpLib() {
std::string fmk_type = std::to_string(domi::TENSORFLOW); std::string fmk_type = std::to_string(domi::TENSORFLOW);
auto it = options_.find(ge::FRAMEWORK_TYPE); auto it = options_.find(ge::FRAMEWORK_TYPE);
if (it != options_.end()) { if (it != options_.end()) {
fmk_type = it->second;
fmk_type = it->second;
} }
std::vector<OpRegistrationData> registration_datas = domi::OpRegistry::Instance()->registrationDatas; std::vector<OpRegistrationData> registration_datas = domi::OpRegistry::Instance()->registrationDatas;
GELOGI("The size of registration_datas is: %zu", registration_datas.size()); GELOGI("The size of registration_datas is: %zu", registration_datas.size());
@@ -192,7 +192,7 @@ void TBEPluginManager::LoadCustomOpLib() {
if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) { if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) {
GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(), GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(),
TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str());
(void)domi::OpRegistry::Instance()->Register(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
} }
} }
} }


+ 1
- 1
ge/common/profiling/ge_profiling.cc View File

@@ -182,7 +182,7 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le
command.module_index = prof_config_param->profSwitch; command.module_index = prof_config_param->profSwitch;
} }
GELOGI("GE commandhandle execute, Command Type: %s, data type config: 0x%llx", iter->second.c_str(), GELOGI("GE commandhandle execute, Command Type: %s, data type config: 0x%llx", iter->second.c_str(),
command.module_index);
command.module_index);
if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) {
GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str()); GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str());
} }


+ 8
- 14
ge/common/profiling/profiling_manager.cc View File

@@ -38,8 +38,10 @@ const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe";
} // namespace } // namespace


namespace ge { namespace ge {
ProfilingManager::ProfilingManager()
: is_load_profiling_(false), is_execute_profiling_(false), is_training_trace_(false), subscribe_count_(0) {
ProfilingManager::ProfilingManager() : is_load_profiling_(false),
is_execute_profiling_(false),
is_training_trace_(false),
subscribe_count_(0) {
prof_cb_.msprofCtrlCallback = nullptr; prof_cb_.msprofCtrlCallback = nullptr;
prof_cb_.msprofReporterCallback = nullptr; prof_cb_.msprofReporterCallback = nullptr;
} }
@@ -100,8 +102,8 @@ ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOpt
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
is_execute_profiling_ = true; is_execute_profiling_ = true;
GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(), prof_conf.options,
options.profiling_options.c_str());
GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(),
prof_conf.options, options.profiling_options.c_str());
} else { } else {
(void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH); (void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH);
(void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX); (void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX);
@@ -141,9 +143,6 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) {
} }
try { try {
Json prof_options = Json::parse(options); Json prof_options = Json::parse(options);
if (options.find(kTrainingTrace) == std::string::npos) {
return ge::SUCCESS;
}
const std::string training_trace = prof_options[kTrainingTrace]; const std::string training_trace = prof_options[kTrainingTrace];
if (training_trace.empty()) { if (training_trace.empty()) {
GELOGI("Training trace will not take effect."); GELOGI("Training trace will not take effect.");
@@ -212,16 +211,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin
uint32_t block_dim = task.block_dim; uint32_t block_dim = task.block_dim;
uint32_t task_id = task.task_id; uint32_t task_id = task.task_id;
uint32_t stream_id = task.stream_id; uint32_t stream_id = task.stream_id;
std::string shape_type = task.shape_type;
int64_t cur_iter_num = task.cur_iter_num;
data = model_name.append(" ") data = model_name.append(" ")
.append(op_name).append(" ") .append(op_name).append(" ")
.append(std::to_string(block_dim)).append(" ")
.append(std::to_string(block_dim).append(" ")
.append(std::to_string(task_id)).append(" ") .append(std::to_string(task_id)).append(" ")
.append(std::to_string(stream_id)).append(" ") .append(std::to_string(stream_id)).append(" ")
.append(std::to_string(model_id)).append(" ")
.append(shape_type).append(" ")
.append(std::to_string(cur_iter_num)).append("\n");
.append(std::to_string(model_id)).append("\n"));


ReporterData reporter_data{}; ReporterData reporter_data{};
reporter_data.deviceId = device_id; reporter_data.deviceId = device_id;
@@ -846,7 +841,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP
return; return;
} }
} }
return; return;
} }




+ 0
- 2
ge/common/proto/op_mapping_info.proto View File

@@ -15,7 +15,6 @@ message Output {
int32 original_output_data_type = 7; int32 original_output_data_type = 7;
int32 original_output_format = 8; int32 original_output_format = 8;
uint64 size = 9; uint64 size = 9;
Shape origin_shape = 10;
} }


message Input { message Input {
@@ -24,7 +23,6 @@ message Input {
Shape shape = 3; Shape shape = 3;
uint64 address = 4; uint64 address = 4;
uint64 size = 5; uint64 size = 5;
Shape origin_shape = 6;
} }


enum BufferType { enum BufferType {


+ 81
- 36
ge/executor/ge_executor.cc View File

@@ -209,33 +209,19 @@ bool IsDynmaicDimsSizeMatchModel(const vector<uint64_t> cur_dynamic_dims,


namespace ge { namespace ge {
bool GeExecutor::isInit_ = false; bool GeExecutor::isInit_ = false;

static void InitOpsProtoManger() {
string opsproto_path;
const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
string path = path_env;
string file_path = RealPath(path.c_str());
if (file_path.empty()) {
GELOGE(FAILED, "File path %s is invalid.", path.c_str());
return;
class ModelListenerAdapter : public ModelListener {
public:
domi::Status OnComputeDone(uint32_t model_id, uint32_t dataIndex, uint32_t resultCode,
std::vector<ge::OutputTensorInfo> &outputs) {
if (listener == nullptr) {
GELOGE(ge::FAILED, "listener is null.");
return FAILED;
} }
opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
GELOGI("Get opsproto so path from env : %s", path.c_str());
} else {
string path_base = PluginManager::GetPath();
GELOGI("path_base is %s", path_base.c_str());
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
}

GELOGI("Get opsproto path is %s", opsproto_path.c_str());
OpsProtoManager *manager = OpsProtoManager::Instance();
map<string, string> option_tmp;
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
(void)manager->Initialize(option_tmp);
}
return listener->OnComputeDone(model_id, dataIndex, resultCode, outputs);
}

std::shared_ptr<ge::ModelListener> listener;
};


GeExecutor::GeExecutor() {} GeExecutor::GeExecutor() {}


@@ -246,16 +232,6 @@ Status GeExecutor::Initialize() {
return ge::SUCCESS; return ge::SUCCESS;
} }


OpTilingManager::GetInstance().LoadSo();

Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize();
if (initHostCpuEngineStatus != SUCCESS) {
GELOGE(initHostCpuEngineStatus, "Failed to initialize HostCpuEngine");
return initHostCpuEngineStatus;
}

InitOpsProtoManger();

std::vector<rtMemType_t> mem_type(1, RT_MEMORY_HBM); std::vector<rtMemType_t> mem_type(1, RT_MEMORY_HBM);
mem_type.push_back(RT_MEMORY_P2P_DDR); mem_type.push_back(RT_MEMORY_P2P_DDR);
auto ret = MemManager::Instance().Initialize(mem_type); auto ret = MemManager::Instance().Initialize(mem_type);
@@ -560,6 +536,60 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add
return SUCCESS; return SUCCESS;
} }


// Load model
Status GeExecutor::LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key,
int32_t priority, std::shared_ptr<ge::ModelListener> listener) {
GELOGI("load model offline begin.");
if (!isInit_) {
GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!");
return ACL_ERROR_GE_EXEC_NOT_INIT;
}

string filePath = RealPath(path.c_str());
if (filePath.empty()) {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID,
"File path is invalid. please check your text file '%s'.", path.c_str());
return ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID;
}

std::shared_ptr<ModelListenerAdapter> listener_adapter = MakeShared<ModelListenerAdapter>();
if (listener_adapter == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelListenerAdapter make shared failed!");
return ACL_ERROR_GE_MEMORY_ALLOCATION;
}
listener_adapter->listener = listener;

Status ret = GraphLoader::LoadModelFromFile(path, key, priority, listener_adapter, model_id);
if (ret != SUCCESS) {
GELOGE(ret, "[GeExecutor] LoadModelFromFile failed");
return ACL_ERROR_GE_LOAD_MODEL;
}
return SUCCESS;
}

Status GeExecutor::LoadModel(uint32_t &model_id, const ModelData &model_data,
std::shared_ptr<ge::ModelListener> listener) {
GELOGI("Load model begin.");
if (!isInit_) {
GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!");
return ACL_ERROR_GE_EXEC_NOT_INIT;
}

std::shared_ptr<ModelListenerAdapter> listener_adapter = MakeShared<ModelListenerAdapter>();
if (listener_adapter == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelListenerAdapter make shared failed!");
return ACL_ERROR_GE_MEMORY_ALLOCATION;
}
listener_adapter->listener = listener;

Status ret = GraphLoader::LoadModel(model_data, listener_adapter, model_id);
if (ret != SUCCESS) {
GELOGE(ret, "[GeExecutor] LoadModel failed.");
return ACL_ERROR_GE_LOAD_MODEL;
}
return ret;
}

Status GeExecutor::UnloadModel(uint32_t model_id) { Status GeExecutor::UnloadModel(uint32_t model_id) {
GELOGD("unload model %u begin.", model_id); GELOGD("unload model %u begin.", model_id);
if (!isInit_) { if (!isInit_) {
@@ -592,6 +622,21 @@ Status GeExecutor::UnloadModel(uint32_t model_id) {
return SUCCESS; return SUCCESS;
} }


///
/// @brief Feed one batch of inputs into a loaded model and collect its outputs.
/// @param [in]  input_data   public-API input buffers.
/// @param [out] output_data  public-API output buffers, filled by execution.
/// @return result of GraphExecutor::DataInput, or ACL_ERROR_GE_EXEC_NOT_INIT.
///
Status GeExecutor::RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data) {
  GELOGI("run model begin.");
  if (!isInit_) {
    GELOGE(ACL_ERROR_GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!");
    return ACL_ERROR_GE_EXEC_NOT_INIT;
  }

  // Translate the public RunModelData structures into the internal domi formats.
  InputData domi_inputs;
  OutputData domi_outputs;
  GetDomiInputData(input_data, domi_inputs);
  GetDomiOutputData(output_data, domi_outputs);

  return GraphExecutor::DataInput(domi_inputs, domi_outputs);
}

// Get input and output descriptor // Get input and output descriptor
Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc,
std::vector<ge::TensorDesc> &output_desc, bool new_model_desc) { std::vector<ge::TensorDesc> &output_desc, bool new_model_desc) {


+ 0
- 2
ge/executor/proto/op_mapping_info.proto View File

@@ -15,7 +15,6 @@ message Output {
int32 original_output_data_type = 7; int32 original_output_data_type = 7;
int32 original_output_format = 8; int32 original_output_format = 8;
uint64 size = 9; uint64 size = 9;
Shape origin_shape = 10;
} }


message Input { message Input {
@@ -24,7 +23,6 @@ message Input {
Shape shape = 3; Shape shape = 3;
uint64 address = 4; uint64 address = 4;
uint64 size = 5; uint64 size = 5;
Shape origin_shape = 6;
} }


enum BufferType { enum BufferType {


+ 0
- 2
ge/ge_inference.mk View File

@@ -191,8 +191,6 @@ OMG_HOST_SRC_FILES := \
graph/passes/control_trigger_pass.cc \ graph/passes/control_trigger_pass.cc \
graph/passes/cond_pass.cc \ graph/passes/cond_pass.cc \
graph/passes/cond_remove_pass.cc \ graph/passes/cond_remove_pass.cc \
graph/passes/remove_same_const_pass.cc \
graph/passes/useless_control_out_remove_pass.cc \
graph/passes/for_pass.cc \ graph/passes/for_pass.cc \
graph/passes/enter_pass.cc \ graph/passes/enter_pass.cc \
graph/passes/assign_pass.cc \ graph/passes/assign_pass.cc \


+ 3
- 2
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -39,7 +39,7 @@ namespace {
} \ } \
ge_tensor = MakeShared<GeTensor>(out_desc); \ ge_tensor = MakeShared<GeTensor>(out_desc); \
GE_CHECK_NOTNULL(ge_tensor); \ GE_CHECK_NOTNULL(ge_tensor); \
GELOGD("node:%s allocate output %zu success, size=%lld", op_desc->GetName().c_str(), i, data_num * sizeof(TYPE));\
GELOGI("node:%s allocate output %zu success, size=%lld", op_desc->GetName().c_str(), i, data_num * sizeof(TYPE));\
if (ge_tensor->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(TYPE)) != GRAPH_SUCCESS) { \ if (ge_tensor->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(TYPE)) != GRAPH_SUCCESS) { \
GELOGE(MEMALLOC_FAILED, "Set data for output %zu of node %s failed.", i, op_desc->GetName().c_str()); \ GELOGE(MEMALLOC_FAILED, "Set data for output %zu of node %s failed.", i, op_desc->GetName().c_str()); \
return MEMALLOC_FAILED; \ return MEMALLOC_FAILED; \
@@ -50,7 +50,8 @@ namespace {
} else { \ } else { \
ge_tensor = outputs[i]; \ ge_tensor = outputs[i]; \
GE_CHECK_NOTNULL(ge_tensor); \ GE_CHECK_NOTNULL(ge_tensor); \
GELOGD("node:%s existed output %zu", op_desc->GetName().c_str(), i); \
GELOGI("node:%s existed output %zu, addr=%p, size=%lld", op_desc->GetName().c_str(), i, \
reinterpret_cast<const uint8_t *>(ge_tensor->GetData().data()), ge_tensor->GetData().size()); \
} \ } \
auto tensor = TensorAdapter::AsTensor(*ge_tensor); \ auto tensor = TensorAdapter::AsTensor(*ge_tensor); \
auto tensor_name = op_desc->GetOutputNameByIndex(i); \ auto tensor_name = op_desc->GetOutputNameByIndex(i); \


+ 0
- 2
ge/ge_runner.mk View File

@@ -126,8 +126,6 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/compile_nodes_pass.cc \ graph/passes/compile_nodes_pass.cc \
graph/passes/constant_folding_pass.cc \ graph/passes/constant_folding_pass.cc \
graph/passes/constant_fuse_same_pass.cc \ graph/passes/constant_fuse_same_pass.cc \
graph/passes/remove_same_const_pass.cc \
graph/passes/useless_control_out_remove_pass.cc \
graph/passes/control_trigger_pass.cc \ graph/passes/control_trigger_pass.cc \
graph/passes/dimension_adjust_pass.cc \ graph/passes/dimension_adjust_pass.cc \
graph/passes/dimension_compute_pass.cc \ graph/passes/dimension_compute_pass.cc \


+ 21
- 19
ge/generator/ge_generator.cc View File

@@ -272,7 +272,6 @@ static void ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor>


std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue}; std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
GeShape dynamic_shape(dynamic_shape_dims); GeShape dynamic_shape(dynamic_shape_dims);
std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;


ge::GeTensor inputTensor; ge::GeTensor inputTensor;
ge::GeTensorDesc desc(input_desc); ge::GeTensorDesc desc(input_desc);
@@ -281,7 +280,6 @@ static void ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor>
(void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
if (!is_const && shape_ori.GetDims().size() > 0) { if (!is_const && shape_ori.GetDims().size() > 0) {
desc.SetShape(dynamic_shape); desc.SetShape(dynamic_shape);
desc.SetShapeRange(dynamic_shape_range);
} }


inputTensor.SetTensorDesc(desc); inputTensor.SetTensorDesc(desc);
@@ -530,6 +528,24 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
return true; return true;
} }


static Status SetModelNameForDump(GeRootModelPtr ge_root_model) {
ModelHelper model_helper;
string model_name = "";
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
model_name);
if (name_ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
GELOGE(FAILED, "Get model_name failed. Param --output is invalid.");
return PARAM_INVALID;
}
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null");
ge_model->SetName(model_name);
return SUCCESS;
}

Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
ModelBufferData &model, bool is_offline) { ModelBufferData &model, bool is_offline) {
rtContext_t ctx = nullptr; rtContext_t ctx = nullptr;
@@ -538,7 +554,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
GELOGD("Current ctx is null."); GELOGD("Current ctx is null.");
ctx = nullptr; ctx = nullptr;
} }

GeRootModelPtr ge_root_model = nullptr; GeRootModelPtr ge_root_model = nullptr;
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
impl_->is_offline_ = is_offline; impl_->is_offline_ = is_offline;
@@ -562,22 +577,11 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
impl_->build_step_.c_str()); impl_->build_step_.c_str());
return SUCCESS; return SUCCESS;
} }

GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model);
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
ModelHelper model_helper;
string model_name = "";
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
model_name);
if (name_ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
GELOGE(FAILED, "Get model_name failed. Param --output is invalid.");
return PARAM_INVALID;
ret = SetModelNameForDump(ge_root_model);
if (ret != SUCCESS) {
return ret;
} }
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null");
ge_model->SetName(model_name);
ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "Save model failed"); GELOGE(ret, "Save model failed");
@@ -586,11 +590,9 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
} }
return ret; return ret;
} }

if (ctx != nullptr) { if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx); (void)rtCtxSetCurrent(ctx);
} }

return SUCCESS; return SUCCESS;
} }




+ 1
- 1
ge/graph/build/memory/graph_mem_assigner.cc View File

@@ -99,7 +99,7 @@ Status GraphMemoryAssigner::AssignMemory() {
MemoryOffset memory_offset(RT_MEMORY_HBM, mem_assigner->GetMemOffset()); MemoryOffset memory_offset(RT_MEMORY_HBM, mem_assigner->GetMemOffset());
memory_offset_.emplace(RT_MEMORY_HBM, memory_offset); memory_offset_.emplace(RT_MEMORY_HBM, memory_offset);


if (mem_assigner->GetP2PMemOffset() >= 0) {
if (mem_assigner->GetP2PMemOffset() > 0) {
MemoryOffset p2p_memory_offset(RT_MEMORY_P2P_DDR, mem_assigner->GetP2PMemOffset()); MemoryOffset p2p_memory_offset(RT_MEMORY_P2P_DDR, mem_assigner->GetP2PMemOffset());
memory_offset_.emplace(RT_MEMORY_P2P_DDR, p2p_memory_offset); memory_offset_.emplace(RT_MEMORY_P2P_DDR, p2p_memory_offset);
} }


+ 0
- 1
ge/graph/build/model_builder.cc View File

@@ -224,7 +224,6 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_
GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); GeTensorDesc &tensor_desc = weight->MutableTensorDesc();
size_t output_size = weight->GetData().size(); size_t output_size = weight->GetData().size();
TensorUtils::SetDataOffset(tensor_desc, mem_offset); TensorUtils::SetDataOffset(tensor_desc, mem_offset);
GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size);
mem_offset += output_size; mem_offset += output_size;
} }
return SUCCESS; return SUCCESS;


+ 8
- 6
ge/graph/build/stream_graph_optimizer.cc View File

@@ -66,13 +66,13 @@ bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &com
if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
label_set.insert(batch_label); label_set.insert(batch_label);
} else { } else {
GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(),
GELOGD("Node %s[%s] has no batch_label, subgraph %s, stream id: %ld ", cur_node->GetName().c_str(),
cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id); cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id);
continue; continue;
} }


GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(),
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize());
GELOGD("Node %s in subgraph %s stream id: %ld, batch_label: %s, node num: %zu", cur_node->GetName().c_str(),
comp_graph->GetName().c_str(), stream_id, batch_label.c_str(), comp_graph->GetDirectNodesSize());
} }
if (stream_set.size() > 1 || label_set.size() > 1) { if (stream_set.size() > 1 || label_set.size() > 1) {
GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.", GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.",
@@ -126,12 +126,14 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com
run_context.graphStreamList.size()); run_context.graphStreamList.size());
return FAILED; return FAILED;
} }

run_context.stream = run_context.graphStreamList[stream_id]; run_context.stream = run_context.graphStreamList[stream_id];
std::string batch_label;
(void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label);
std::string batch_label;
(void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label);
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, " GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, "
"batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
"batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str()); static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str());

for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) {
GE_CHECK_NOTNULL(*iter); GE_CHECK_NOTNULL(*iter);
Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context);


+ 67
- 10
ge/graph/load/graph_loader.cc View File

@@ -122,14 +122,14 @@ Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string
ModelData &model_data) { ModelData &model_data) {
Status ret; Status ret;
if (!CheckInputPathValid(path)) { if (!CheckInputPathValid(path)) {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str());
return ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID;
GELOGE(GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str());
return GE_EXEC_MODEL_PATH_INVALID;
} }


GELOGI("Load model begin, model path is: %s", path.c_str()); GELOGI("Load model begin, model path is: %s", path.c_str());
if (!key_path.empty() && !CheckInputPathValid(key_path)) { if (!key_path.empty() && !CheckInputPathValid(key_path)) {
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "decrypt_key path is invalid: %s", key_path.c_str());
return ACL_ERROR_GE_PARAM_INVALID;
GELOGE(GE_EXEC_MODEL_KEY_PATH_INVALID, "decrypt_key path is invalid: %s", key_path.c_str());
return GE_EXEC_MODEL_KEY_PATH_INVALID;
} }


ret = DavinciModelParser::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); ret = DavinciModelParser::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data);
@@ -144,6 +144,63 @@ Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string
return SUCCESS; return SUCCESS;
} }


Status GraphLoader::LoadModelFromFile(const std::string &path, const std::string &key_path, int32_t priority,
const std::shared_ptr<ModelListener> &listener, uint32_t &model_id) {
Status ret;
ModelData model_data;
ret = LoadDataFromFile(path, key_path, priority, model_data);
if (ret != SUCCESS) {
GELOGE(ret, "LoadModelFromFile: Load failed. ret = %u", ret);
if (model_data.model_data != nullptr) {
delete[] static_cast<char *>(model_data.model_data);
model_data.model_data = nullptr;
}
return ret;
}

ret = LoadModel(model_data, listener, model_id);
if (ret != SUCCESS) {
GELOGE(ret, "LoadModel: Load failed. ret = %u", ret);
if (model_data.model_data != nullptr) {
delete[] static_cast<char *>(model_data.model_data);
model_data.model_data = nullptr;
}
}

if (model_data.model_data != nullptr) {
delete[] static_cast<char *>(model_data.model_data);
model_data.model_data = nullptr;
}

return ret;
}

Status GraphLoader::LoadModel(const ModelData &model_data, const std::shared_ptr<ModelListener> &listener,
uint32_t &model_id) {
GELOGI("Load model begin, model_id:%u.", model_id);

// For GeOp, Open Device 0 here.
GE_CHK_RT_RET(rtSetDevice(0));
auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
Status ret = model_manager->LoadModelOffline(model_id, model_data, listener);
if (ret != SUCCESS) {
GE_CHK_RT(rtDeviceReset(0));
GELOGE(ret, "LoadModel: Load failed.");
return ret;
}
ret = model_manager->Start(model_id);
if (ret != SUCCESS) {
if (model_manager->Unload(model_id) != SUCCESS) {
GELOGE(FAILED, "LoadModel: Unload failed while trying to unload after a failed start.");
}
GELOGE(ret, "LoadModel: Start failed.");
return ret;
}
GELOGI("LoadModel: Start model success, model_id:%u.", model_id);
return SUCCESS;
}

Status GraphLoader::CommandHandle(const Command &command) { Status GraphLoader::CommandHandle(const Command &command) {
try { try {
auto model_manager = ModelManager::GetInstance(); auto model_manager = ModelManager::GetInstance();
@@ -168,16 +225,16 @@ Status GraphLoader::CommandHandle(const Command &command) {
} }


Status GraphLoader::LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr, Status GraphLoader::LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr,
size_t mem_size, void *weight_ptr, size_t weight_size) {
size_t memsize, void *weight_ptr, size_t weightsize) {
GELOGI("Load model begin, model_id:%u.", model_id); GELOGI("Load model begin, model_id:%u.", model_id);
// For ACL, Open Device from App. // For ACL, Open Device from App.
auto model_manager = ModelManager::GetInstance(); auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager); GE_CHECK_NOTNULL(model_manager);
Status ret = model_manager->LoadModelOffline( Status ret = model_manager->LoadModelOffline(
model_id, model_data, nullptr, dev_ptr, mem_size, weight_ptr, weight_size);
model_id, model_data, nullptr, dev_ptr, memsize, weight_ptr, weightsize);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ACL_ERROR_GE_LOAD_MODEL, "Load model failed, model_id:%u.", model_id);
return ACL_ERROR_GE_LOAD_MODEL;
GELOGE(ret, "Load model failed, model_id:%u.", model_id);
return ret;
} }
GELOGI("Load model success, model_id:%u.", model_id); GELOGI("Load model success, model_id:%u.", model_id);
return SUCCESS; return SUCCESS;
@@ -202,8 +259,8 @@ Status GraphLoader::LoadModelWithQ(uint32_t &model_id, const ModelData &model_da
GE_CHECK_NOTNULL(model_manager); GE_CHECK_NOTNULL(model_manager);
Status ret = model_manager->LoadModelWithQ(model_id, model_data, input_queue_ids, output_queue_ids); Status ret = model_manager->LoadModelWithQ(model_id, model_data, input_queue_ids, output_queue_ids);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ACL_ERROR_GE_LOAD_MODEL, "Load model with queue failed, model_id:%u.", model_id);
return ACL_ERROR_GE_LOAD_MODEL;
GELOGE(ret, "Load model with queue failed, model_id:%u.", model_id);
return ret;
} }


GELOGI("Load model with queue success, model_id:%u.", model_id); GELOGI("Load model with queue success, model_id:%u.", model_id);


+ 6
- 0
ge/graph/load/graph_loader.h View File

@@ -44,6 +44,12 @@ class GraphLoader {


static Status GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size); static Status GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size);


static Status LoadModel(const ModelData &model_data, const std::shared_ptr<ModelListener> &listener,
uint32_t &model_id);

static Status LoadModelFromFile(const std::string &path, const std::string &key_path, int32_t priority,
const std::shared_ptr<ModelListener> &listener, uint32_t &model_id);

static Status CommandHandle(const Command &command); static Status CommandHandle(const Command &command);


static Status GetMemoryInfo(int64_t &free); static Status GetMemoryInfo(int64_t &free);


+ 0
- 6
ge/graph/load/new_model_manager/data_dumper.cc View File

@@ -319,9 +319,6 @@ Status DataDumper::GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vis
for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { for (auto dim : tensor_descs.at(index).GetShape().GetDims()) {
output.mutable_shape()->add_dim(dim); output.mutable_shape()->add_dim(dim);
} }
for (auto dim : tensor_descs.at(index).GetOriginShape().GetDims()) {
output.mutable_origin_shape()->add_dim(dim);
}
int64_t output_size = 0; int64_t output_size = 0;
if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), output_size) != SUCCESS) { if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), output_size) != SUCCESS) {
GELOGE(PARAM_INVALID, "Get output size filed"); GELOGE(PARAM_INVALID, "Get output size filed");
@@ -479,9 +476,6 @@ Status DataDumper::GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor
for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { for (auto dim : tensor_descs.at(index).GetShape().GetDims()) {
input.mutable_shape()->add_dim(dim); input.mutable_shape()->add_dim(dim);
} }
for (auto dim : tensor_descs.at(index).GetOriginShape().GetDims()) {
input.mutable_origin_shape()->add_dim(dim);
}
int64_t input_size = 0; int64_t input_size = 0;
if (AttrUtils::GetInt(tensor_descs.at(index), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { if (AttrUtils::GetInt(tensor_descs.at(index), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) {
GELOGI("Get aipp input size according to attr is %ld", input_size); GELOGI("Get aipp input size according to attr is %ld", input_size);


+ 343
- 152
ge/graph/load/new_model_manager/davinci_model.cc View File

@@ -289,8 +289,8 @@ Status DavinciModel::InitWeightMem(void *dev_ptr, void *weight_ptr, size_t weigh
if (weight_ptr == nullptr) { if (weight_ptr == nullptr) {
weights_mem_base_ = MallocWeightsMem(weights_size); weights_mem_base_ = MallocWeightsMem(weights_size);
if (weights_mem_base_ == nullptr) { if (weights_mem_base_ == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc weight memory failed. size: %zu", weights_size);
return ACL_ERROR_GE_MEMORY_ALLOCATION;
GELOGE(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, "Alloc weight memory failed. size: %zu", weights_size);
return GE_EXEC_ALLOC_WEIGHT_MEM_FAILED;
} }
is_inner_weight_base_ = true; is_inner_weight_base_ = true;
} }
@@ -307,8 +307,8 @@ Status DavinciModel::InitWeightMem(void *dev_ptr, void *weight_ptr, size_t weigh


Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) {
if (is_feature_map_mem_has_inited_) { if (is_feature_map_mem_has_inited_) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "call InitFeatureMapMem more than once .");
return ACL_ERROR_GE_MEMORY_ALLOCATION;
GELOGE(FAILED, "call InitFeatureMapMem more than once .");
return FAILED;
} }
is_feature_map_mem_has_inited_ = true; is_feature_map_mem_has_inited_ = true;


@@ -316,8 +316,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) {
std::size_t p2p_data_size = P2PMemInfos().at(RT_MEMORY_P2P_DDR).memory_size; std::size_t p2p_data_size = P2PMemInfos().at(RT_MEMORY_P2P_DDR).memory_size;


if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Invalid mem param: mem_size=%zu totalsize=%zu.", mem_size, TotalMemSize());
return ACL_ERROR_GE_MEMORY_ALLOCATION;
GELOGE(FAILED, "Invalid mem param: mem_size=%zu totalsize=%zu.", mem_size, TotalMemSize());
return FAILED;
} }


mem_base_ = static_cast<uint8_t *>(dev_ptr); mem_base_ = static_cast<uint8_t *>(dev_ptr);
@@ -327,8 +327,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) {
if (TotalMemSize() && mem_base_ == nullptr) { if (TotalMemSize() && mem_base_ == nullptr) {
mem_base_ = MallocFeatureMapMem(data_size); mem_base_ = MallocFeatureMapMem(data_size);
if (mem_base_ == nullptr) { if (mem_base_ == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc feature map memory failed. size: %zu", data_size);
return ACL_ERROR_GE_MEMORY_ALLOCATION;
GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size);
return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED;
} }
GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]",
runtime_param_.graph_id, mem_base_, data_size); runtime_param_.graph_id, mem_base_, data_size);
@@ -343,8 +343,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) {
if (p2p_data_size != 0) { if (p2p_data_size != 0) {
p2p_mem_base_ = MallocP2PMem(p2p_data_size); p2p_mem_base_ = MallocP2PMem(p2p_data_size);
if (p2p_mem_base_ == nullptr) { if (p2p_mem_base_ == nullptr) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc p2p memory failed,size: %zu", p2p_data_size);
return ACL_ERROR_GE_MEMORY_ALLOCATION;
GELOGE(GE_EXEC_ALLOC_P2P_MEM_FAILED, "Alloc p2p memory failed,size: %zu", p2p_data_size);
return GE_EXEC_ALLOC_P2P_MEM_FAILED;
} }
GELOGI("InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, GELOGI("InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id,
p2p_mem_base_, p2p_data_size); p2p_mem_base_, p2p_data_size);
@@ -710,7 +710,6 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size
} }


// collect profiling for ge // collect profiling for ge
GE_CHK_STATUS_RET(InitModelProfile(), "Init model profile failed");
auto &profiling_manager = ProfilingManager::Instance(); auto &profiling_manager = ProfilingManager::Instance();
if (profiling_manager.ProfilingModelLoadOn()) { if (profiling_manager.ProfilingModelLoadOn()) {
Status p_ret = ReportProfilingData(); Status p_ret = ReportProfilingData();
@@ -971,7 +970,7 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, ma
uint32_t parent_index = 0; // Ignore subgraph Data Node. uint32_t parent_index = 0; // Ignore subgraph Data Node.
if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
GELOGI("Init zero copy by subgraph Data node: %s.", op_desc->GetName().c_str()); GELOGI("Init zero copy by subgraph Data node: %s.", op_desc->GetName().c_str());
return SUCCESS;
return InitInputBatchLabel(node);
} }


data_op_list_.push_back(op_desc); data_op_list_.push_back(op_desc);
@@ -1012,6 +1011,10 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, ma
} }


data_op_index++; data_op_index++;
if (InitInputZeroCopyNodes(node) != SUCCESS) {
GELOGE(PARAM_INVALID, "Input zero copy nodes init failed!");
return PARAM_INVALID;
}
return SUCCESS; return SUCCESS;
} }


@@ -1033,6 +1036,39 @@ void DavinciModel::AdjustDataOpList(const map<uint32_t, OpDescPtr> &data_by_inde
} }
} }


///
/// @ingroup ge
/// @brief input zero copy node Initialize.
/// @param [in] NodePtr: Data Op.
/// @return Status
///
Status DavinciModel::InitInputZeroCopyNodes(const NodePtr &node) {
auto out_data_anchor = node->GetOutDataAnchor(kDataIndex);
if (out_data_anchor == nullptr) {
GELOGE(FAILED, "Out data anchor is nullptr");
return FAILED;
}
for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
auto node = peer_in_data_anchor->GetOwnerNode();
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(FAILED, "Op desc is nullptr");
return FAILED;
}
string batch_label;
(void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (batch_label.empty()) {
batch_label = kDefaultBatchLable;
}
if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) {
zero_copy_op_id_batch_label_.emplace(pair<int64_t, string>(op_desc->GetId(), batch_label));
GELOGD("Init input zero copy nodes success, op name:%s, op id: %ld, batch label: %s.", op_desc->GetName().c_str(),
op_desc->GetId(), batch_label.c_str());
}
}
return SUCCESS;
}

bool DavinciModel::IsGetNextSinkDynamic(const OpDescPtr &op_desc) { bool DavinciModel::IsGetNextSinkDynamic(const OpDescPtr &op_desc) {
bool getnext_sink_dynamic = false; bool getnext_sink_dynamic = false;
if (ge::AttrUtils::GetBool(op_desc, ATTR_GETNEXT_SINK_DYNMAIC, getnext_sink_dynamic) && getnext_sink_dynamic) { if (ge::AttrUtils::GetBool(op_desc, ATTR_GETNEXT_SINK_DYNMAIC, getnext_sink_dynamic) && getnext_sink_dynamic) {
@@ -1058,7 +1094,7 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) {
if (owner_graph->GetParentGraph() != nullptr) { if (owner_graph->GetParentGraph() != nullptr) {
GELOGI("Init zero copy by subgraph NetOutput node: %s.", op_desc->GetName().c_str()); GELOGI("Init zero copy by subgraph NetOutput node: %s.", op_desc->GetName().c_str());
op_list_.erase(op_desc->GetId()); op_list_.erase(op_desc->GetId());
return SUCCESS;
return InitOutputBatchLabel(node);
} }


output_op_list_.push_back(op_desc); output_op_list_.push_back(op_desc);
@@ -1110,6 +1146,8 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) {
} }
} }


GE_IF_BOOL_EXEC(InitOutputZeroCopyNodes(node) != SUCCESS,
GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); return PARAM_INVALID;);
GetAllGearsInfo(node); GetAllGearsInfo(node);
if (is_getnext_sink_dynamic_) { if (is_getnext_sink_dynamic_) {
GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS, GE_IF_BOOL_EXEC(GetGetDynamicDimsNodeInfo(node) != SUCCESS,
@@ -1305,6 +1343,121 @@ void DavinciModel::ParseDynamicOutShape(const std::vector<std::string> &str_info
} }
} }


///
/// @ingroup ge
/// @brief output zero copy node Initialize.
/// @param [in] NodePtr: netoutput Op.
/// @return Status
///
Status DavinciModel::InitOutputZeroCopyNodes(const NodePtr &node) {
set<NodePtr> nodes_need_record;
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_data_anchor == nullptr) {
continue;
}
auto peer_node = peer_out_data_anchor->GetOwnerNode();
nodes_need_record.emplace(peer_node);

// Merge node output multiplexed input, upstream nodes need to be considered in multiple batch scenarios
if (peer_node->GetType() == MERGE) {
for (const auto &merge_peer_in_data_anchor : peer_node->GetAllInDataAnchors()) {
auto merge_peer_out_data_anchor = merge_peer_in_data_anchor->GetPeerOutAnchor();
if (merge_peer_out_data_anchor == nullptr) {
continue;
}
auto merge_peer_node = merge_peer_out_data_anchor->GetOwnerNode();
nodes_need_record.emplace(merge_peer_node);
}
} else {
for (const auto &other_in_data_anchor : peer_out_data_anchor->GetPeerInDataAnchors()) {
auto other_in_node = other_in_data_anchor->GetOwnerNode();
if (other_in_node->GetType() != NETOUTPUT) {
nodes_need_record.emplace(other_in_node);
}
}
}
}

for (const auto &node_need_record : nodes_need_record) {
auto op_desc = node_need_record->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
string batch_label;
(void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (batch_label.empty()) {
batch_label = kDefaultBatchLable;
}
if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) {
zero_copy_op_id_batch_label_.emplace(pair<int64_t, string>(op_desc->GetId(), batch_label));
GELOGD("Init Output zero copy nodes success, op name:%s, op id: %ld, batch label: %s.",
op_desc->GetName().c_str(), op_desc->GetId(), batch_label.c_str());
}
}
return SUCCESS;
}

///
/// @ingroup ge
/// @brief input zero copy node Initialize.
/// @param [in] NodePtr: Data Op.
/// @return Status
///
Status DavinciModel::InitInputBatchLabel(const NodePtr &node) {
string batch_label;
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
return SUCCESS; // Not Multi-batch.
}

const auto &out_data_anchor = node->GetOutDataAnchor(kDataIndex);
GE_CHECK_NOTNULL(out_data_anchor);

for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
const auto &node = peer_in_data_anchor->GetOwnerNode();
const auto &op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);

if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) {
zero_copy_op_id_batch_label_[op_desc->GetId()] = batch_label;
GELOGD("Init input zero copy nodes success, op name: %s, op id: %ld, batch label: %s", op_desc->GetName().c_str(),
op_desc->GetId(), batch_label.c_str());
}
}

return SUCCESS;
}

///
/// @ingroup ge
/// @brief output zero copy node Initialize for Case.
/// @param [in] node: netoutput Op.
/// @return Status
///
Status DavinciModel::InitOutputBatchLabel(const NodePtr &node) {
  // A netoutput without ATTR_NAME_BATCH_LABEL is not part of a multi-batch
  // graph, so there is no label mapping to build.
  string batch_label;
  if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
    return SUCCESS;  // Not Multi-batch.
  }

  // Walk each input of the netoutput and tag the producing op with this label.
  for (const auto &in_anchor : node->GetAllInDataAnchors()) {
    const auto &src_out_anchor = in_anchor->GetPeerOutAnchor();
    if (src_out_anchor == nullptr) {
      continue;  // Unconnected input: nothing to record.
    }

    const auto &src_node = src_out_anchor->GetOwnerNode();
    const auto &src_op_desc = src_node->GetOpDesc();
    GE_CHECK_NOTNULL(src_op_desc);

    // Only the first label seen for a given op id is kept.
    const int64_t op_id = src_op_desc->GetId();
    if (zero_copy_op_id_batch_label_.count(op_id) == 0) {
      zero_copy_op_id_batch_label_.emplace(op_id, batch_label);
      GELOGD("Init Output zero copy nodes success, op name: %s, op id: %ld, batch label: %s",
             src_op_desc->GetName().c_str(), op_id, batch_label.c_str());
    }
  }

  return SUCCESS;
}

/// @ingroup ge /// @ingroup ge
/// @brief LabelSet Op Initialize. /// @brief LabelSet Op Initialize.
/// @param [in] op_desc: LabelSet Op descriptor. /// @param [in] op_desc: LabelSet Op descriptor.
@@ -2087,61 +2240,12 @@ Status DavinciModel::SyncVarData() {
return ret; return ret;
} }


Status DavinciModel::InitModelProfile() {
for (const auto &task : task_list_) {
GE_CHECK_NOTNULL(task);
const FusionOpInfo *fusion_op_info = task->GetFusionOpInfo();
// when type is RT_MODEL_TASK_KERNEL, ctx is not null
if ((fusion_op_info == nullptr) || fusion_op_info->original_op_names.empty()) {
continue;
}

GELOGI("task.id = %u, opNum = %zu", task->GetTaskID(), fusion_op_info->original_op_names.size());
op_id_map_.insert(std::make_pair(fusion_op_info->op_index, task->GetTaskID()));
}

std::set<uint32_t> task_id_set;
using CIT = std::multimap<uint32_t, uint32_t>::const_iterator;
using Range = std::pair<CIT, CIT>;
for (const auto &task : task_list_) {
GE_CHECK_NOTNULL(task);
const FusionOpInfo *fusion_op_info = task->GetFusionOpInfo();
if ((fusion_op_info == nullptr) || fusion_op_info->original_op_names.empty()) {
continue;
}

if (task_id_set.count(task->GetTaskID()) > 0) {
continue;
}

const auto &op_desc = GetOpByIndex(fusion_op_info->op_index);
GE_CHK_BOOL_EXEC(op_desc != nullptr, return FAILED, "index: %u out of range", fusion_op_info->op_index);

ProfileInfo profile;
profile.fusion_info = *fusion_op_info;
Range range = op_id_map_.equal_range(fusion_op_info->op_index);
for (CIT range_idx = range.first; range_idx != range.second; ++range_idx) {
profile.task_count++;
task_id_set.insert(range_idx->second);
}

// memory info
TaskMemInfo &mem_info = profile.memory_info;
const auto input_size = ModelUtils::GetInputSize(op_desc);
const auto output_size = ModelUtils::GetOutputSize(op_desc);
const auto workspace_size = ModelUtils::GetWorkspaceSize(op_desc);
const auto weight_size = ModelUtils::GetWeightSize(op_desc);
mem_info.input_size = std::accumulate(input_size.begin(), input_size.end(), 0);
mem_info.output_size = std::accumulate(output_size.begin(), output_size.end(), 0);
mem_info.workspace_size = std::accumulate(workspace_size.begin(), workspace_size.end(), 0);
mem_info.weight_size = std::accumulate(weight_size.begin(), weight_size.end(), 0);
mem_info.total_size = mem_info.weight_size + mem_info.input_size + mem_info.output_size + mem_info.workspace_size;

profile_list_.emplace_back(profile);
inline int64_t SumSize(const vector<int64_t> &size_list) {
int64_t sum_size = 0;
for (const int64_t &size : size_list) {
sum_size += size;
} }

GELOGI("fusion task size: %zu, profile info size: %zu", op_id_map_.size(), profile_list_.size());
return SUCCESS;
return sum_size;
} }


Status DavinciModel::SinkModelProfile() { Status DavinciModel::SinkModelProfile() {
@@ -2149,12 +2253,18 @@ Status DavinciModel::SinkModelProfile() {
auto &prof_mgr = ProfilingManager::Instance(); auto &prof_mgr = ProfilingManager::Instance();
ReporterData reporter_data{}; ReporterData reporter_data{};
// report model data tag name // report model data tag name
std::string tag_name("model_load_info_" + std::to_string(this->Id()));
std::string tag_name;
tag_name.append("model_load_info_").append(std::to_string(this->Id()));
GE_CHK_BOOL_EXEC(memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN, tag_name.c_str(), tag_name.size()) == EOK, GE_CHK_BOOL_EXEC(memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN, tag_name.c_str(), tag_name.size()) == EOK,
return FAILED, "Sink model tag memcpy error."); return FAILED, "Sink model tag memcpy error.");


// Model Header // Model Header
std::string name = om_name_.empty() ? name_ : om_name_;
string name;
if (!om_name_.empty()) {
name = om_name_;
} else {
name = name_;
}
size_t name_len = name.size(); size_t name_len = name.size();
reporter_data.deviceId = device_id_; reporter_data.deviceId = device_id_;
reporter_data.data = (unsigned char *)&name_len; reporter_data.data = (unsigned char *)&name_len;
@@ -2186,71 +2296,128 @@ Status DavinciModel::SinkModelProfile() {
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id()); "Reporter data fail, model id:%u.", this->Id());


int32_t task_num = task_list_.size();
std::multimap<uint32_t, uint32_t> op_id_map;
std::set<uint32_t> task_id_set;
for (int32_t i = 0; i < task_num; i++) {
auto task = task_list_[i];
GE_CHECK_NOTNULL(task);
auto fusion_op_info = task->GetFusionOpInfo();
// when type is RT_MODEL_TASK_KERNEL, ctx is not null
if (fusion_op_info != nullptr) {
uint32_t op_num = fusion_op_info->original_op_names.size();
uint32_t task_id = task->GetTaskID();
if (op_num > 0) {
GELOGI("task.id = %u, opNum = %u", task_id, op_num);
op_id_map.insert(std::make_pair(fusion_op_info->op_index, task_id));
}
}
}

struct memoryInfo {
int64_t input_size;
int64_t output_size;
int64_t weight_size;
int64_t workspace_size;
int64_t total_size;

memoryInfo() : input_size(0), output_size(0), weight_size(0), workspace_size(0), total_size(0) {}
};

using CIT = std::multimap<uint32_t, uint32_t>::const_iterator; using CIT = std::multimap<uint32_t, uint32_t>::const_iterator;
using Range = std::pair<CIT, CIT>; using Range = std::pair<CIT, CIT>;
for (const ProfileInfo &profile : profile_list_) {
// op name after fusion
string fusion_op_name = profile.fusion_info.op_name;
int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size();
reporter_data.data = (unsigned char *)&fusion_op_name_len;
reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

reporter_data.data = (unsigned char *)fusion_op_name.c_str();
reporter_data.dataLen = fusion_op_name_len;
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

// original op name before fusion
uint32_t op_num = profile.fusion_info.original_op_names.size();
reporter_data.data = (unsigned char *)&op_num;
reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

for (uint32_t k = 0; k < op_num; k++) {
std::string op_name = profile.fusion_info.original_op_names[k];
int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size();
reporter_data.data = (unsigned char *)&op_name_len;
for (int32_t i = 0; i < task_num; i++) {
auto task = task_list_[i];
GE_CHECK_NOTNULL(task);
auto fusion_op_info = task->GetFusionOpInfo();
if (fusion_op_info != nullptr && fusion_op_info->original_op_names.size() > 0) {
uint32_t task_id = task->GetTaskID();
uint32_t op_num = fusion_op_info->original_op_names.size();
uint32_t task_count = 0;
if (task_id_set.count(task_id) != 0) {
continue;
}

uint32_t op_id = fusion_op_info->op_index;
Range range = op_id_map.equal_range(op_id);
for (CIT range_idx = range.first; range_idx != range.second; ++range_idx) {
task_count++;
uint32_t task_id = range_idx->second;
task_id_set.insert(task_id);
}

// op name after fusion
string fusion_op_name = fusion_op_info->op_name;
int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size();
reporter_data.data = (unsigned char *)&fusion_op_name_len;
reporter_data.dataLen = sizeof(int32_t); reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id()); "Reporter data fail, model id:%u.", this->Id());
reporter_data.data = (unsigned char *)op_name.c_str();
reporter_data.dataLen = op_name_len;

reporter_data.data = (unsigned char *)fusion_op_name.c_str();
reporter_data.dataLen = fusion_op_name_len;
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

// original op name before fusion
reporter_data.data = (unsigned char *)&op_num;
reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id()); "Reporter data fail, model id:%u.", this->Id());
}


// stream id info
uint32_t streamId = profile.fusion_info.stream_id;
reporter_data.data = (unsigned char *)&streamId;
reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

// memory info
reporter_data.data = (unsigned char *)&profile.memory_info;
reporter_data.dataLen = sizeof(profile.memory_info);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

// task info
reporter_data.data = (unsigned char *)&profile.task_count;
reporter_data.dataLen = sizeof(uint32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

Range task_range = op_id_map_.equal_range(profile.fusion_info.op_index);
for (CIT idx = task_range.first; idx != task_range.second; ++idx) {
uint32_t task_id = idx->second;
reporter_data.data = (unsigned char *)&task_id;
for (uint32_t k = 0; k < op_num; k++) {
std::string op_name = fusion_op_info->original_op_names[k];
int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size();
reporter_data.data = (unsigned char *)&op_name_len;
reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());
reporter_data.data = (unsigned char *)op_name.c_str();
reporter_data.dataLen = op_name_len;
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());
}

// stream id info
uint32_t streamId = task->GetStreamId();
reporter_data.data = (unsigned char *)&streamId;
reporter_data.dataLen = sizeof(int32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

// memory info
struct memoryInfo memory_info;
uint32_t op_index = fusion_op_info->op_index;
auto iter = op_list_.find(op_index);
GE_CHK_BOOL_EXEC(iter != op_list_.end(), return FAILED, "index is out of range, index: %u", op_index);
auto op_desc = iter->second;
memory_info.input_size = SumSize(ModelUtils::GetInputSize(op_desc));
memory_info.output_size = SumSize(ModelUtils::GetOutputSize(op_desc));
memory_info.workspace_size = SumSize(ModelUtils::GetWorkspaceSize(op_desc));
memory_info.weight_size = SumSize(ModelUtils::GetWeightSize(op_desc));
memory_info.total_size =
memory_info.weight_size + memory_info.input_size + memory_info.output_size + memory_info.workspace_size;
reporter_data.data = (unsigned char *)&memory_info;
reporter_data.dataLen = sizeof(struct memoryInfo);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());

// task info
reporter_data.data = (unsigned char *)&task_count;
reporter_data.dataLen = sizeof(uint32_t); reporter_data.dataLen = sizeof(uint32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id()); "Reporter data fail, model id:%u.", this->Id());

Range task_range = op_id_map.equal_range(op_id);
for (CIT idx = task_range.first; idx != task_range.second; ++idx) {
uint32_t task_id = idx->second;
reporter_data.data = (unsigned char *)&task_id;
reporter_data.dataLen = sizeof(uint32_t);
GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED,
"Reporter data fail, model id:%u.", this->Id());
}
} }
} }

return SUCCESS; return SUCCESS;
} }


@@ -2824,19 +2991,19 @@ Status DavinciModel::CreateKnownZeroCopyMap(const vector<void *> &inputs, const
return SUCCESS; return SUCCESS;
} }


Status DavinciModel::UpdateKnownZeroCopyAddr(vector<void *> &total_io_addrs) {
for (size_t i = 0; i < total_io_addrs.size(); ++i) {
auto it_in = knonw_input_data_info_.find(total_io_addrs[i]);
Status DavinciModel::UpdateKnownZeroCopyAddr() {
for (size_t i = 0; i < total_io_addrs_.size(); ++i) {
auto it_in = knonw_input_data_info_.find(total_io_addrs_[i]);
if (it_in != knonw_input_data_info_.end()) { if (it_in != knonw_input_data_info_.end()) {
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %zu,v addr %p,p addr %p .", i, total_io_addrs[i],
knonw_input_data_info_.at(total_io_addrs[i]));
total_io_addrs[i] = knonw_input_data_info_.at(total_io_addrs[i]);
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %zu,v addr %p,p addr %p .", i, total_io_addrs_[i],
knonw_input_data_info_.at(total_io_addrs_[i]));
total_io_addrs_[i] = knonw_input_data_info_.at(total_io_addrs_[i]);
} }
auto it_out = knonw_output_data_info_.find(total_io_addrs[i]);
auto it_out = knonw_output_data_info_.find(total_io_addrs_[i]);
if (it_out != knonw_output_data_info_.end()) { if (it_out != knonw_output_data_info_.end()) {
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %zu,v addr %p,p addr %p .", i, total_io_addrs[i],
knonw_output_data_info_.at(total_io_addrs[i]));
total_io_addrs[i] = knonw_output_data_info_.at(total_io_addrs[i]);
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %zu,v addr %p,p addr %p .", i, total_io_addrs_[i],
knonw_output_data_info_.at(total_io_addrs_[i]));
total_io_addrs_[i] = knonw_output_data_info_.at(total_io_addrs_[i]);
} }
} }
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr success."); GELOGI("DavinciModel::UpdateKnownZeroCopyAddr success.");
@@ -2865,7 +3032,7 @@ Status DavinciModel::UpdateKnownNodeArgs(const vector<void *> &inputs, const vec
} else { } else {
total_io_addrs_ = orig_total_io_addrs_; total_io_addrs_ = orig_total_io_addrs_;
} }
GE_CHK_STATUS_RET(UpdateKnownZeroCopyAddr(total_io_addrs_), "DavinciModel::UpdateKnownZeroCopyAddr failed.");
GE_CHK_STATUS_RET(UpdateKnownZeroCopyAddr(), "DavinciModel::UpdateKnownZeroCopyAddr failed.");


if (total_args_size_ == 0) { if (total_args_size_ == 0) {
GELOGW("DavinciModel::UpdateKnownNodeArgs device args %p, dst size %u, pass rtMemcpy.", args_, total_args_size_); GELOGW("DavinciModel::UpdateKnownNodeArgs device args %p, dst size %u, pass rtMemcpy.", args_, total_args_size_);
@@ -2932,14 +3099,7 @@ Status DavinciModel::MallocKnownArgs() {
GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret);
} }
// malloc dynamic and static hybrid memory
if (total_hybrid_args_size_ != 0) {
rt_ret = rtMalloc(&hybrid_addrs_, total_hybrid_args_size_, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}
}

// malloc fixed addr memory, eg: rts op // malloc fixed addr memory, eg: rts op
if (total_fixed_addr_size_ != 0) { if (total_fixed_addr_size_ != 0) {
GELOGI("Begin to allocate fixed addr."); GELOGI("Begin to allocate fixed addr.");
@@ -2993,7 +3153,9 @@ Status DavinciModel::DistributeTask() {
} }


auto task_type = static_cast<rtModelTaskType_t>(task_def.type()); auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
bool no_need_profiling = (task_type != RT_MODEL_TASK_KERNEL) && (task_type != RT_MODEL_TASK_KERNEL_EX);
bool no_need_profiling = (task_type != RT_MODEL_TASK_KERNEL)
&& (task_type != RT_MODEL_TASK_KERNEL_EX)
&& (task_type != RT_MODEL_TASK_HCCL);
GE_IF_BOOL_EXEC(no_need_profiling, continue); GE_IF_BOOL_EXEC(no_need_profiling, continue);


SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId());
@@ -3008,8 +3170,6 @@ Status DavinciModel::DistributeTask() {
task_desc_info.block_dim = task_def.kernel().block_dim(); task_desc_info.block_dim = task_def.kernel().block_dim();
task_desc_info.task_id = task->GetTaskID(); task_desc_info.task_id = task->GetTaskID();
task_desc_info.stream_id = task->GetStreamId(); task_desc_info.stream_id = task->GetStreamId();
task_desc_info.shape_type = "static";
task_desc_info.cur_iter_num = 0;
task_desc_info_.emplace_back(task_desc_info); task_desc_info_.emplace_back(task_desc_info);
if (flag) { if (flag) {
if (task->GetSktTaskID() != 0xFFFFFFFF) { if (task->GetSktTaskID() != 0xFFFFFFFF) {
@@ -3097,20 +3257,27 @@ void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector<v


for (auto &input_outside_addrs : new_input_outside_addrs_) { for (auto &input_outside_addrs : new_input_outside_addrs_) {
ZeroCopyOffset &input_outside = input_outside_addrs.second; ZeroCopyOffset &input_outside = input_outside_addrs.second;
input_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen);
bool ret = input_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen);
if (ret) {
void *args_val = static_cast<uint8_t *>(args) + offset + i * kAddrLen;
SetBatchLabelAddr(op_desc, reinterpret_cast<uintptr_t>(args_val));
}
} }


for (auto &output_outside_addrs : new_output_outside_addrs_) { for (auto &output_outside_addrs : new_output_outside_addrs_) {
ZeroCopyOffset &output_outside = output_outside_addrs.second; ZeroCopyOffset &output_outside = output_outside_addrs.second;
output_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen);
bool ret = output_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen);
if (ret) {
void *args_val = static_cast<uint8_t *>(args) + offset + i * kAddrLen;
SetBatchLabelAddr(op_desc, reinterpret_cast<uintptr_t>(args_val));
}
} }
} }

string batch_label;
if (!AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label) || batch_label.empty()) {
auto it = zero_copy_op_id_batch_label_.find(op_desc->GetId());
if (it == zero_copy_op_id_batch_label_.end()) {
zero_copy_task.SetBatchLabel(kDefaultBatchLable); zero_copy_task.SetBatchLabel(kDefaultBatchLable);
} else { } else {
zero_copy_task.SetBatchLabel(batch_label);
zero_copy_task.SetBatchLabel(it->second);
} }


std::lock_guard<std::mutex> lock(outside_addrs_mutex_); std::lock_guard<std::mutex> lock(outside_addrs_mutex_);
@@ -3120,6 +3287,27 @@ void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector<v
} }
} }


void DavinciModel::SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr) {
// Establish a mapping between batch label and zero copy address for multi-batch scenes
auto it = zero_copy_op_id_batch_label_.find(op_desc->GetId());
if (it == zero_copy_op_id_batch_label_.end()) {
return;
}

const string &batch_label = it->second;
auto iter = zero_copy_batch_label_addrs_.find(batch_label);
if (iter != zero_copy_batch_label_addrs_.end()) {
iter->second.insert(addr);
GELOGD("[ZCPY] Set zero copy batch label and addrs success, batch label: %s, op name:%s.", batch_label.c_str(),
op_desc->GetName().c_str());
} else {
set<uintptr_t> addrs = {addr};
zero_copy_batch_label_addrs_.emplace(pair<string, set<uintptr_t>>(batch_label, addrs));
GELOGD("[ZCPY] New added zero copy batch label and addrs success, batch label: %s, op name:%s.",
batch_label.c_str(), op_desc->GetName().c_str());
}
}

/// ///
/// @ingroup ge /// @ingroup ge
/// @brief Copy Check input size and model op size. /// @brief Copy Check input size and model op size.
@@ -3253,15 +3441,15 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> &
void *addr = data.second.GetDataInfo().at(count).second; void *addr = data.second.GetDataInfo().at(count).second;
void *buffer_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data) + void *buffer_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data) +
data.second.GetRelativeOffset().at(count)); data.second.GetRelativeOffset().at(count));
GELOGI("[ZCPY] Copy %s blobs_index %u, virtual_addr: %p, size: %ld, user_data_addr: %p, batch_label: %s",
input_or_output.c_str(), data.first, addr, size, buffer_addr, batch_label.c_str());
GELOGI("[ZCPY] Copy %s blobs_index %u, virtual_addr: %p, size: %ld, user_data_addr: %p", input_or_output.c_str(),
data.first, addr, size, buffer_addr);
// For input data, just copy for rts task. // For input data, just copy for rts task.
for (ZeroCopyTask &task : zero_copy_tasks_) { for (ZeroCopyTask &task : zero_copy_tasks_) {
if (task.GetBatchLabel() != kDefaultBatchLable && task.GetBatchLabel() != batch_label) { if (task.GetBatchLabel() != kDefaultBatchLable && task.GetBatchLabel() != batch_label) {
continue; continue;
} }
uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr); uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr);
if (task.UpdateTaskParam(addr_val, buffer_addr) != SUCCESS) {
if (task.UpdateTaskParam(addr_val, buffer_addr, zero_copy_batch_label_addrs_, batch_label) != SUCCESS) {
return FAILED; return FAILED;
} }
} }
@@ -3623,6 +3811,9 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa
GELOGD("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); GELOGD("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_);
GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed.");
is_dynamic_ = input_data.is_dynamic_batch; is_dynamic_ = input_data.is_dynamic_batch;
if (!is_dynamic_) {
zero_copy_batch_label_addrs_.clear();
}


GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_PRE_PROC_START)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_PRE_PROC_START));
Status ret = CopyModelData(input_data, output_data, is_dynamic_); Status ret = CopyModelData(input_data, output_data, is_dynamic_);


+ 53
- 33
ge/graph/load/new_model_manager/davinci_model.h View File

@@ -76,20 +76,6 @@ struct timeInfo {
int64_t dumpEndTime; int64_t dumpEndTime;
}; };


struct TaskMemInfo {
int64_t input_size{0};
int64_t output_size{0};
int64_t weight_size{0};
int64_t workspace_size{0};
int64_t total_size{0};
};

struct ProfileInfo {
FusionOpInfo fusion_info;
TaskMemInfo memory_info;
uint32_t task_count{0};
};

enum ExecuteMode { enum ExecuteMode {
INITIALIZATION, INITIALIZATION,
SYNCHRONIZATION, SYNCHRONIZATION,
@@ -240,6 +226,8 @@ class DavinciModel {
const vector<OpDescPtr> &GetDataList() const { return data_op_list_; } const vector<OpDescPtr> &GetDataList() const { return data_op_list_; }


// get Op // get Op
const map<uint32_t, OpDescPtr> &GetOpList() const { return op_list_; }

OpDescPtr GetOpByIndex(uint32_t index) const { OpDescPtr GetOpByIndex(uint32_t index) const {
if (op_list_.find(index) == op_list_.end()) { if (op_list_.find(index) == op_list_.end()) {
return nullptr; return nullptr;
@@ -448,6 +436,10 @@ class DavinciModel {


int64_t GetLoadEndTime() { return load_end_time_; } int64_t GetLoadEndTime() { return load_end_time_; }


Status SinkModelProfile();

Status SinkTimeProfile(const InputData &current_data);

Status ReportProfilingData(); Status ReportProfilingData();


void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) {
@@ -484,14 +476,6 @@ class DavinciModel {
void SetTotalIOAddrs(vector<void *> &io_addrs) { void SetTotalIOAddrs(vector<void *> &io_addrs) {
total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end()); total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end());
} }
void SetHybridArgsSize(uint32_t args_size) { total_hybrid_args_size_ += args_size; }
uint32_t GetHybridArgsSize() {
return total_hybrid_args_size_;
}
void *GetCurrentHybridArgsAddr(uint32_t offset) {
void *cur_args = static_cast<char *>(hybrid_addrs_) + offset;
return cur_args;
}
void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size); void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size);
int64_t GetFixedAddrsSize(string tensor_name); int64_t GetFixedAddrsSize(string tensor_name);
void *GetCurrentFixedAddr(int64_t offset) const { void *GetCurrentFixedAddr(int64_t offset) const {
@@ -510,7 +494,7 @@ class DavinciModel {
Status MallocKnownArgs(); Status MallocKnownArgs();
Status UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs); Status UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs);
Status CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs); Status CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs);
Status UpdateKnownZeroCopyAddr(vector<void *> &total_io_addrs);
Status UpdateKnownZeroCopyAddr();
void SetKnownNodeAddrNotChanged(bool base_addr_not_changed) { base_addr_not_changed_ = base_addr_not_changed; } void SetKnownNodeAddrNotChanged(bool base_addr_not_changed) { base_addr_not_changed_ = base_addr_not_changed; }


Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info); Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info);
@@ -547,6 +531,15 @@ class DavinciModel {


/// ///
/// @ingroup ge /// @ingroup ge
/// @brief Save Batch label Info.
/// @param [in] const OpDescPtr &op_desc
/// @param [in] uintptr_t addr: address value in args block.
/// @return None.
///
void SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr);

///
/// @ingroup ge
/// @brief Copy Check input size and model op size. /// @brief Copy Check input size and model op size.
/// @param [in] const int64_t &input_size: input size. /// @param [in] const int64_t &input_size: input size.
/// @param [in] const int64_t &op_size: model op size. /// @param [in] const int64_t &op_size: model op size.
@@ -658,6 +651,14 @@ class DavinciModel {


/// ///
/// @ingroup ge /// @ingroup ge
/// @brief input zero copy node Initialize.
/// @param [in] NodePtr: Data Op.
/// @return Status
///
Status InitInputZeroCopyNodes(const NodePtr &node);

///
/// @ingroup ge
/// @brief NetOutput Op Initialize. /// @brief NetOutput Op Initialize.
/// @param [in] NodePtr: NetOutput Op. /// @param [in] NodePtr: NetOutput Op.
/// @return Status /// @return Status
@@ -666,6 +667,30 @@ class DavinciModel {


/// ///
/// @ingroup ge /// @ingroup ge
/// @brief output zero copy node Initialize.
/// @param [in] NodePtr: Data Op.
/// @return Status
///
Status InitOutputZeroCopyNodes(const NodePtr &node);

///
/// @ingroup ge
/// @brief input zero copy node Initialize for Case.
/// @param [in] NodePtr: Data Op.
/// @return Status
///
Status InitInputBatchLabel(const NodePtr &node);

///
/// @ingroup ge
/// @brief output zero copy node Initialize for Case.
/// @param [in] NodePtr: netoutput Op.
/// @return Status
///
Status InitOutputBatchLabel(const NodePtr &node);

///
/// @ingroup ge
/// @brief Constant Op Init. /// @brief Constant Op Init.
/// @return Status /// @return Status
/// ///
@@ -812,11 +837,6 @@ class DavinciModel {


void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); void SetDataDumperArgs(const ComputeGraphPtr &compute_graph);


Status InitModelProfile();
Status SinkModelProfile();

Status SinkTimeProfile(const InputData &current_data);

Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data,
std::vector<ge::OutputTensorInfo> &outputs); std::vector<ge::OutputTensorInfo> &outputs);


@@ -894,6 +914,11 @@ class DavinciModel {
std::vector<ZeroCopyTask> zero_copy_tasks_; // Task used Data or NetOutput addr. std::vector<ZeroCopyTask> zero_copy_tasks_; // Task used Data or NetOutput addr.
std::set<const void *> copy_only_addrs_; // Address need copy to original place. std::set<const void *> copy_only_addrs_; // Address need copy to original place.


// {op_id, batch_label}
std::map<int64_t, std::string> zero_copy_op_id_batch_label_;
// {batch_label, addrs}
std::map<std::string, std::set<uintptr_t>> zero_copy_batch_label_addrs_;

std::vector<TaskInfoPtr> task_list_; std::vector<TaskInfoPtr> task_list_;
// rt_moodel_handle // rt_moodel_handle
rtModel_t rt_model_handle_; rtModel_t rt_model_handle_;
@@ -952,8 +977,6 @@ class DavinciModel {
void *args_ = nullptr; void *args_ = nullptr;
void *args_host_ = nullptr; void *args_host_ = nullptr;
void *fixed_addrs_ = nullptr; void *fixed_addrs_ = nullptr;
void *hybrid_addrs_ = nullptr;
uint32_t total_hybrid_args_size_ = 0;
int64_t total_fixed_addr_size_ = 0; int64_t total_fixed_addr_size_ = 0;
std::map<const void *, void *> knonw_input_data_info_; std::map<const void *, void *> knonw_input_data_info_;
std::map<const void *, void *> knonw_output_data_info_; std::map<const void *, void *> knonw_output_data_info_;
@@ -993,9 +1016,6 @@ class DavinciModel {
// key: input_index: input is merge node; value: each gear info and each output shape // key: input_index: input is merge node; value: each gear info and each output shape
std::map<size_t, std::map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_; std::map<size_t, std::map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_;
std::vector<std::vector<int64_t>> all_gears_info_; std::vector<std::vector<int64_t>> all_gears_info_;

std::multimap<uint32_t, uint32_t> op_id_map_;
std::vector<ProfileInfo> profile_list_;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_

+ 35
- 40
ge/graph/load/new_model_manager/model_manager.cc View File

@@ -89,7 +89,6 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u
if (op_type == aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY) { if (op_type == aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY) {
std::vector<uint64_t> v_aicpu_kernel; std::vector<uint64_t> v_aicpu_kernel;
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id);
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
auto iter = model_aicpu_kernel_.find(model_key); auto iter = model_aicpu_kernel_.find(model_key);
if (iter != model_aicpu_kernel_.end()) { if (iter != model_aicpu_kernel_.end()) {
GELOGD("kernel destroy session_id %lu, model_id %u.", session_id, model_id); GELOGD("kernel destroy session_id %lu, model_id %u.", session_id, model_id);
@@ -177,7 +176,7 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u
} }


void ModelManager::DestroyAicpuSession(uint64_t session_id) { void ModelManager::DestroyAicpuSession(uint64_t session_id) {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(sess_ids_mutex_);
auto it = sess_ids_.find(session_id); auto it = sess_ids_.find(session_id);
if (it == sess_ids_.end()) { if (it == sess_ids_.end()) {
GELOGI("The session: %lu not created.", session_id); GELOGI("The session: %lu not created.", session_id);
@@ -206,7 +205,7 @@ void ModelManager::DestroyAicpuSession(uint64_t session_id) {
} }


ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
auto hybrid_davinci_model = hybrid_model_map_.find(model_id); auto hybrid_davinci_model = hybrid_model_map_.find(model_id);
if (hybrid_davinci_model != hybrid_model_map_.end()) { if (hybrid_davinci_model != hybrid_model_map_.end()) {
uint64_t session_id = hybrid_davinci_model->second->GetSessionId(); uint64_t session_id = hybrid_davinci_model->second->GetSessionId();
@@ -216,8 +215,8 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) {


auto it = model_map_.find(model_id); auto it = model_map_.find(model_id);
if (it == model_map_.end()) { if (it == model_map_.end()) {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", model_id);
return ACL_ERROR_GE_EXEC_MODEL_ID_INVALID;
GELOGE(GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", model_id);
return GE_EXEC_MODEL_ID_INVALID;
} }
uint64_t session_id = it->second->GetSessionId(); uint64_t session_id = it->second->GetSessionId();
DestroyAicpuSession(session_id); DestroyAicpuSession(session_id);
@@ -226,7 +225,7 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) {


ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) {
GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id);
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id);
if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) {
Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id);
@@ -239,7 +238,7 @@ ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_
} }


ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
std::vector<uint64_t> v_aicpu_kernel; std::vector<uint64_t> v_aicpu_kernel;
std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id);
if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) {
@@ -251,7 +250,7 @@ ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_i
} }


ModelManager::~ModelManager() { ModelManager::~ModelManager() {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
model_map_.clear(); model_map_.clear();
model_aicpu_kernel_.clear(); model_aicpu_kernel_.clear();
cust_aicpu_so_.clear(); cust_aicpu_so_.clear();
@@ -359,18 +358,18 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge


void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) { void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model) {
GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id); GE_CHK_BOOL_EXEC(davinci_model != nullptr, return, "davinci_model ptr is null, id: %u", id);
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
model_map_[id] = davinci_model; model_map_[id] = davinci_model;
} }


void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) {
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id);
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);
hybrid_model_map_[id] = hybrid_model; hybrid_model_map_[id] = hybrid_model;
} }


Status ModelManager::DeleteModel(uint32_t id) { Status ModelManager::DeleteModel(uint32_t id) {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);


auto it = model_map_.find(id); auto it = model_map_.find(id);
auto hybrid_model_it = hybrid_model_map_.find(id); auto hybrid_model_it = hybrid_model_map_.find(id);
@@ -385,22 +384,22 @@ Status ModelManager::DeleteModel(uint32_t id) {
} else if (hybrid_model_it != hybrid_model_map_.end()) { } else if (hybrid_model_it != hybrid_model_map_.end()) {
(void)hybrid_model_map_.erase(hybrid_model_it); (void)hybrid_model_map_.erase(hybrid_model_it);
} else { } else {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", id);
return ACL_ERROR_GE_EXEC_MODEL_ID_INVALID;
GELOGE(GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", id);
return GE_EXEC_MODEL_ID_INVALID;
} }


return SUCCESS; return SUCCESS;
} }


std::shared_ptr<DavinciModel> ModelManager::GetModel(uint32_t id) { std::shared_ptr<DavinciModel> ModelManager::GetModel(uint32_t id) {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);


auto it = model_map_.find(id); auto it = model_map_.find(id);
return (it == model_map_.end()) ? nullptr : it->second; return (it == model_map_.end()) ? nullptr : it->second;
} }


std::shared_ptr<hybrid::HybridDavinciModel> ModelManager::GetHybridModel(uint32_t id) { std::shared_ptr<hybrid::HybridDavinciModel> ModelManager::GetHybridModel(uint32_t id) {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(map_mutex_);


auto it = hybrid_model_map_.find(id); auto it = hybrid_model_map_.find(id);
return (it == hybrid_model_map_.end()) ? nullptr : it->second; return (it == hybrid_model_map_.end()) ? nullptr : it->second;
@@ -903,7 +902,7 @@ Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector<Inpu
} }


std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, GE_EXEC_MODEL_ID_INVALID,
"GetInputOutputDescInfo Failed, Invalid model id %u!", model_id); "GetInputOutputDescInfo Failed, Invalid model id %u!", model_id);


davinci_model->SetModelDescVersion(new_model_desc); davinci_model->SetModelDescVersion(new_model_desc);
@@ -971,9 +970,8 @@ Status ModelManager::GetUserDesignateShapeOrder(const uint32_t model_id,
} }


Status ModelManager::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) { Status ModelManager::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) {
auto davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetCurShape Failed, Invalid Model ID %u!", model_id);
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHECK_NOTNULL(davinci_model);
davinci_model->GetCurShape(batch_info, dynamic_type); davinci_model->GetCurShape(batch_info, dynamic_type);
return SUCCESS; return SUCCESS;
} }
@@ -986,8 +984,7 @@ Status ModelManager::GetModelAttr(uint32_t model_id, std::vector<string> &dynami
} }


std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetModelAttr Failed, Invalid Model ID %u!", model_id);
GE_CHECK_NOTNULL(davinci_model);
davinci_model->GetModelAttr(dynamic_output_shape_info); davinci_model->GetModelAttr(dynamic_output_shape_info);
return SUCCESS; return SUCCESS;
} }
@@ -997,8 +994,9 @@ Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id,
std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &inputFormats,
std::vector<uint32_t> &outputFormats) { std::vector<uint32_t> &outputFormats) {
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetInputOutputDescInfo Failed, Invalid model id %u!", model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetInputOutputDescInfo Failed, Invalid model id %u!",
model_id);

return davinci_model->GetInputOutputDescInfoForZeroCopy(input_desc, output_desc, inputFormats, outputFormats); return davinci_model->GetInputOutputDescInfoForZeroCopy(input_desc, output_desc, inputFormats, outputFormats);
} }


@@ -1013,14 +1011,18 @@ Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id,
Status ModelManager::GetAIPPInfo(const uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) { Status ModelManager::GetAIPPInfo(const uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) {
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetAIPPInfo failed, invalid model_id is %u.", model_id);
"GetAIPPInfo failed, invalid model_id is %u.",
model_id);

return davinci_model->GetAIPPInfo(index, aipp_info); return davinci_model->GetAIPPInfo(index, aipp_info);
} }


Status ModelManager::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) { Status ModelManager::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) {
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"GetAIPPInfo failed, invalid model_id is %u.", model_id);
"GetAIPPInfo failed, invalid model_id is %u.",
model_id);

return davinci_model->GetAippType(index, type, aipp_index); return davinci_model->GetAippType(index, type, aipp_index);
} }


@@ -1053,15 +1055,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model
mmTimespec timespec = mmGetTickCount(); mmTimespec timespec = mmGetTickCount();


ModelHelper model_helper; ModelHelper model_helper;
Status ret = model_helper.LoadRootModel(model);
if (model_helper.GetModelType()) {
bool is_shape_unknown = false;
GE_CHK_STATUS_RET(model_helper.GetGeRootModel()->CheckIsUnknownShape(is_shape_unknown),
"CheckIsUnknownShape failed, model id:%u", model_id);
if (is_shape_unknown || GetContext().GetHostExecFlag()) {
return DoLoadHybridModelOnline(model_id, model_helper.GetGeRootModel(), listener);
}
}
Status ret = model_helper.LoadModel(model);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "load model failed."); GELOGE(ret, "load model failed.");
return ret; return ret;
@@ -1075,8 +1069,8 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed"); GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed");
return ACL_ERROR_GE_MEMORY_ALLOCATION; return ACL_ERROR_GE_MEMORY_ALLOCATION;
} catch (...) { } catch (...) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed since other exception raise");
return ACL_ERROR_GE_MEMORY_ALLOCATION;
GELOGE(INTERNAL_ERROR, "Make shared failed since other exception raise");
return INTERNAL_ERROR;
} }
ret = davinci_model->Assign(ge_model); ret = davinci_model->Assign(ge_model);
if (ret != SUCCESS) { if (ret != SUCCESS) {
@@ -1088,7 +1082,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model
int32_t device_id = 0; int32_t device_id = 0;
rtError_t rt_ret = rtGetDevice(&device_id); rtError_t rt_ret = rtGetDevice(&device_id);
if (rt_ret != RT_ERROR_NONE || device_id < 0) { if (rt_ret != RT_ERROR_NONE || device_id < 0) {
GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id);
GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id);
return RT_ERROR_TO_GE_STATUS(rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret);
} }
davinci_model->SetDeviceId(device_id); davinci_model->SetDeviceId(device_id);
@@ -1220,7 +1214,7 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy


std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"Invalid model id %u, check whether model has been loaded or not.", model_id);
"Invalid model id %u, check weather model has been loaded or not.", model_id);


if (davinci_model->NeedDestroyAicpuKernel()) { if (davinci_model->NeedDestroyAicpuKernel()) {
GELOGI("Start to destroy specified aicpu kernel."); GELOGI("Start to destroy specified aicpu kernel.");
@@ -1243,7 +1237,7 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy
} }


Status ModelManager::CreateAicpuSession(uint64_t session_id) { Status ModelManager::CreateAicpuSession(uint64_t session_id) {
std::lock_guard<std::recursive_mutex> lock(map_mutex_);
std::lock_guard<std::mutex> lock(sess_ids_mutex_);
auto it = sess_ids_.find(session_id); auto it = sess_ids_.find(session_id);
// never been created by any model // never been created by any model
if (it == sess_ids_.end()) { if (it == sess_ids_.end()) {
@@ -1462,7 +1456,8 @@ void ModelManager::GenModelId(uint32_t *id) {
if (id == nullptr) { if (id == nullptr) {
return; return;
} }
std::lock_guard<std::recursive_mutex> lock(map_mutex_);

std::lock_guard<std::mutex> lock(map_mutex_);
*id = ++max_model_id_; *id = ++max_model_id_;
} }




+ 2
- 1
ge/graph/load/new_model_manager/model_manager.h View File

@@ -353,7 +353,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {
std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_;
std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_;
uint32_t max_model_id_; uint32_t max_model_id_;
std::recursive_mutex map_mutex_;
std::mutex map_mutex_;
std::mutex sess_ids_mutex_;
std::mutex session_id_create_mutex_; std::mutex session_id_create_mutex_;
static::std::mutex exeception_infos_mutex_; static::std::mutex exeception_infos_mutex_;
uint64_t session_id_bias_; uint64_t session_id_bias_;


+ 74
- 58
ge/graph/load/new_model_manager/task_info/kernel_task_info.cc View File

@@ -90,18 +90,20 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci
fusion_op_info_.op_index = context.op_index(); fusion_op_info_.original_op_names = original_op_names; fusion_op_info_.op_index = context.op_index(); fusion_op_info_.original_op_names = original_op_names;
fusion_op_info_.op_name = op_desc_->GetName()); fusion_op_info_.op_name = op_desc_->GetName());


string session_graph_model_id;
davinci_model_->GetUniqueId(op_desc_, session_graph_model_id);
// get bin_file_key
const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id);
// new aicpu kernel(rtCpuKernelLaunch) no need to check function // new aicpu kernel(rtCpuKernelLaunch) no need to check function
if (kernel_type_ == ccKernelType::CCE_AI_CORE) { if (kernel_type_ == ccKernelType::CCE_AI_CORE) {
rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_);
rtError_t rt_ret;
rt_ret = rtGetFunctionByName(const_cast<char *>(kernel_def.stub_func().c_str()), &stub_func_);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s",
kernel_def.stub_func().c_str()); kernel_def.stub_func().c_str());
return RT_ERROR_TO_GE_STATUS(rt_ret);); return RT_ERROR_TO_GE_STATUS(rt_ret););
} else if (kernel_type_ == ccKernelType::TE) { } else if (kernel_type_ == ccKernelType::TE) {
// get bin_file_key
string session_graph_model_id;
davinci_model_->GetUniqueId(op_desc_, session_graph_model_id);
const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id);
rtError_t rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_);
rtError_t rt_ret;
rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE,
GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. bin_file_key: %s", bin_file_key); GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. bin_file_key: %s", bin_file_key);
return RT_ERROR_TO_GE_STATUS(rt_ret);); return RT_ERROR_TO_GE_STATUS(rt_ret););
@@ -370,11 +372,7 @@ Status KernelTaskInfo::SuperKernelDistribute() {
Status KernelTaskInfo::Distribute() { Status KernelTaskInfo::Distribute() {
GELOGD("KernelTaskInfo Distribute Start."); GELOGD("KernelTaskInfo Distribute Start.");
if (davinci_model_->IsKnownNode()) { if (davinci_model_->IsKnownNode()) {
if (kernel_type_ == ccKernelType::TE) {
args_ = davinci_model_->GetCurrentArgsAddr(args_offset_);
} else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
args_ = davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_);
}
args_ = davinci_model_->GetCurrentArgsAddr(args_offset_);
GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_);
} }
rtError_t rt_ret = RT_ERROR_NONE; rtError_t rt_ret = RT_ERROR_NONE;
@@ -430,31 +428,36 @@ Status KernelTaskInfo::UpdateArgs() {
const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam();
vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_);
vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_);
vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_);


vector<void *> io_addrs; vector<void *> io_addrs;
io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end());
io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end());
if (kernel_type_ == ccKernelType::TE) {
vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_);
if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) {
io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end());
io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end());
io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end());
davinci_model_->SetTotalIOAddrs(io_addrs);
} else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
davinci_model_->UpdateKnownZeroCopyAddr(io_addrs);
uintptr_t io_addr = reinterpret_cast<uintptr_t>(args_addr.get()) + sizeof(aicpu::AicpuParamHead);
auto addrs_size = sizeof(uint64_t) * io_addrs.size();
errno_t sec_ret = memcpy_s(reinterpret_cast<void *>(io_addr), addrs_size, io_addrs.data(), addrs_size);
if (sec_ret != EOK) {
GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret);
return FAILED;
}
// copy args to device
rtError_t rt_ret = rtMemcpy(args_, args_size_, args_addr.get(), args_size_, RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
} else {
string peer_input_name;
if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) {
uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name);
if (output_index > output_data_addrs.size()) {
GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.",
output_data_addrs.size(), output_index);
return FAILED;
}
io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end());
for (size_t i = 0; i < output_data_addrs.size(); ++i) {
if (i == output_index) {
void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_);
io_addrs.emplace_back(fixed_addr);
continue;
}
io_addrs.emplace_back(output_data_addrs[i]);
}
io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end());
} }
} }


davinci_model_->SetTotalIOAddrs(io_addrs);
GELOGI("KernelTaskInfo::UpdateArgs success."); GELOGI("KernelTaskInfo::UpdateArgs success.");
return SUCCESS; return SUCCESS;
} }
@@ -530,18 +533,33 @@ Status KernelTaskInfo::UpdateL2Data(const domi::KernelDef &kernel_def) {
} }


Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
const domi::KernelDef &kernel_def = task_def.kernel();
domi::KernelDef kernel_def = task_def.kernel();
uint32_t args_size = kernel_def.args_size();
args_offset_ = davinci_model->GetTotalArgsSize();
davinci_model->SetTotalArgsSize(args_size);
GELOGI("kernel task name , args_size %u, args_offset %u", args_size, args_offset_);

// get opcontext stored in model
const domi::KernelContext &context = kernel_def.context(); const domi::KernelContext &context = kernel_def.context();
kernel_type_ = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type_ == ccKernelType::TE) {
uint32_t args_size = kernel_def.args_size();
args_offset_ = davinci_model->GetTotalArgsSize();
davinci_model->SetTotalArgsSize(args_size);
GELOGI("kernel task name , args_size %u, args_offset %u", args_size, args_offset_);
} else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {
hybrid_args_offset_ = davinci_model->GetHybridArgsSize();
davinci_model->SetHybridArgsSize(kernel_def.args_size());
GELOGI("aicpu kernel task name , args_size %u, args_offset %u", kernel_def.args_size(), hybrid_args_offset_);
// get opdesc
op_desc_ = davinci_model->GetOpByIndex(context.op_index());
GE_CHECK_NOTNULL(op_desc_);
// alloc fixed addr
string peer_input_name;
if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) {
uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name);
if (output_index > op_desc_->GetOutputsSize()) {
GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", op_desc_->GetOutputsSize(),
output_index);
return FAILED;
}
fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name);
auto tensor_desc = op_desc_->GetOutputDesc(output_index);
int64_t tensor_size = 0;
GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size));
davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size);
GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr offset %ld", tensor_size,
fixed_addr_offset_);
} }
return SUCCESS; return SUCCESS;
} }
@@ -870,7 +888,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k
} }


// copy args to new host memory // copy args to new host memory
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
std::unique_ptr<uint8_t[]> args_addr(new (std::nothrow) uint8_t[args_size_]);
GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_)
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) { if (sec_ret != EOK) {
@@ -878,23 +896,8 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k
return FAILED; return FAILED;
} }


auto aicpu_param_head = reinterpret_cast<aicpu::AicpuParamHead *>(args_addr.get());
const auto &ext_info = kernel_def.kernel_ext_info();
auto init_ret = InitAicpuTaskExtInfo(ext_info);
if (init_ret != SUCCESS) {
GELOGE(init_ret, "Init aicpu task ext info failed, ext_info size=%zu", ext_info.size());
return init_ret;
}
GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, aicpu_ext_info_addr_=%p", op_desc_->GetName().c_str(),
op_desc_->GetType().c_str(), ext_info.size(), aicpu_ext_info_addr_);

aicpu_param_head->extInfoAddr = reinterpret_cast<uintptr_t>(aicpu_ext_info_addr_);
aicpu_param_head->extInfoLength = static_cast<uintptr_t>(ext_info.size());

if (davinci_model_->IsKnownNode()) {
return SUCCESS;
}
const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam();

vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc);
vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc);
vector<void *> io_addrs; vector<void *> io_addrs;
@@ -911,6 +914,19 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k
} }
} }


auto aicpu_param_head = reinterpret_cast<aicpu::AicpuParamHead *>(args_addr.get());
const auto &ext_info = kernel_def.kernel_ext_info();
auto init_ret = InitAicpuTaskExtInfo(ext_info);
if (init_ret != SUCCESS) {
GELOGE(init_ret, "Init aicpu task ext info failed, ext_info size=%zu", ext_info.size());
return init_ret;
}
GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, aicpu_ext_info_addr_=%p", op_desc_->GetName().c_str(),
op_desc_->GetType().c_str(), ext_info.size(), aicpu_ext_info_addr_);

aicpu_param_head->extInfoAddr = reinterpret_cast<uintptr_t>(aicpu_ext_info_addr_);
aicpu_param_head->extInfoLength = static_cast<uintptr_t>(ext_info.size());

// malloc device memory for args // malloc device memory for args
rtError_t rt_ret = rtMalloc(static_cast<void **>(&args_), args_size_, RT_MEMORY_HBM); rtError_t rt_ret = rtMalloc(static_cast<void **>(&args_), args_size_, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {


+ 0
- 2
ge/graph/load/new_model_manager/task_info/kernel_task_info.h View File

@@ -159,9 +159,7 @@ class KernelTaskInfo : public TaskInfo {
OpDescPtr op_desc_; OpDescPtr op_desc_;
DavinciModel *davinci_model_; DavinciModel *davinci_model_;
uint32_t args_offset_ = 0; uint32_t args_offset_ = 0;
uint32_t hybrid_args_offset_ = 0;
int64_t fixed_addr_offset_ = 0; int64_t fixed_addr_offset_ = 0;
std::unique_ptr<uint8_t[]> args_addr = nullptr;
bool call_save_dump_ = false; bool call_save_dump_ = false;


// aicpu ext_info device mem // aicpu ext_info device mem


+ 7
- 3
ge/graph/load/new_model_manager/zero_copy_offset.cc View File

@@ -183,18 +183,22 @@ void ZeroCopyOffset::SetOutputOutsideAddrs(const int64_t &input_offset, const bo
addr_count_ = out_count; addr_count_ = out_count;
} }


void ZeroCopyOffset::SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset) {
bool ZeroCopyOffset::SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset) {
const auto addr_val = reinterpret_cast<uintptr_t>(outside_addr); const auto addr_val = reinterpret_cast<uintptr_t>(outside_addr);
bool set_batch_label_flag = false;
for (uint32_t out_count = 0; out_count < GetAddrCount(); ++out_count) { for (uint32_t out_count = 0; out_count < GetAddrCount(); ++out_count) {
auto args_addrs = outside_addrs_[out_count].find(outside_addr);
if (args_addrs != outside_addrs_[out_count].end()) {
auto &addrs_mapping_list = GetOutsideAddrs();
auto args_addrs = addrs_mapping_list[out_count].find(outside_addr);
if (args_addrs != addrs_mapping_list[out_count].end()) {
GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset), "Input args invalid."); GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset), "Input args invalid.");
void *args_val = static_cast<uint8_t *>(args) + offset; void *args_val = static_cast<uint8_t *>(args) + offset;
args_addrs->second.push_back(args_val); args_addrs->second.push_back(args_val);
GELOGD("[ZCPY] set copy input: virtual_addr: 0x%lx, task_addr: %p, args: %p, offset: %zu.", addr_val, args_val, GELOGD("[ZCPY] set copy input: virtual_addr: 0x%lx, task_addr: %p, args: %p, offset: %zu.", addr_val, args_val,
args, offset); args, offset);
set_batch_label_flag = true;
} }
} }
return set_batch_label_flag;
} }


} // namespace ge } // namespace ge

+ 1
- 1
ge/graph/load/new_model_manager/zero_copy_offset.h View File

@@ -51,7 +51,7 @@ class ZeroCopyOffset {
const OpDescPtr &op_desc, const size_t &idx, bool &fusion_flag); const OpDescPtr &op_desc, const size_t &idx, bool &fusion_flag);
void SetOutputOutsideAddrs(const int64_t &input_offset, const bool &fusion_flag, void *addr, void SetOutputOutsideAddrs(const int64_t &input_offset, const bool &fusion_flag, void *addr,
std::vector<void *> &tensor_addrs); std::vector<void *> &tensor_addrs);
void SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset);
bool SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset);


// basic_addr of l2-fusion // basic_addr of l2-fusion
void *GetBasicAddr() const { return basic_addr_; } void *GetBasicAddr() const { return basic_addr_; }


+ 45
- 2
ge/graph/load/new_model_manager/zero_copy_task.cc View File

@@ -22,6 +22,8 @@
#include "common/ge_compiler_options.h" #include "common/ge_compiler_options.h"


namespace ge { namespace ge {
const char *const kDefaultBatchLable = "Batch_default";

ZeroCopyTask::ZeroCopyTask(const string &name, uint8_t *args, size_t size) ZeroCopyTask::ZeroCopyTask(const string &name, uint8_t *args, size_t size)
: name_(name), args_addr_(args), args_size_(size), is_updated_(false) {} : name_(name), args_addr_(args), args_size_(size), is_updated_(false) {}


@@ -64,18 +66,59 @@ void ZeroCopyTask::SetOriginalArgs(const void *info, size_t size) {
const uint8_t *data = static_cast<const uint8_t *>(info); const uint8_t *data = static_cast<const uint8_t *>(info);
args_info_.assign(data, data + size); args_info_.assign(data, data + size);


GELOGI("[ZCPY] %s set original args info: %p, args_addr: %p, args size: %zu, info size: %zu", name_.c_str(), info,
GELOGI("[ZCPY] %s set info from virtual_addr: %p, args_addr: %p, args size: %zu, info size: %zu", name_.c_str(), info,
args_addr_, args_size_, size); args_addr_, args_size_, size);
} }


/** /**
* @ingroup ge * @ingroup ge
* @brief Check is dynamic batch node.
* @param [in] addr: virtual address value from Op.
* @param [in] data: data buffer from user.
* @param [in] batch_addrs: dynamic batch addr info.
* @param [in] batch_label: batch label.
* @return: true / false
*/
bool ZeroCopyTask::CheckDynamicBatch(const map<string, set<uintptr_t>> &batch_addrs, const string &batch_label,
uintptr_t addr) {
// Used for dynamic batch / resolution scene
set<uintptr_t> dynamic_input_addrs;
auto dynamic_input_iter = batch_addrs.find(batch_label);
if (dynamic_input_iter != batch_addrs.end()) {
dynamic_input_addrs = dynamic_input_iter->second;
}

set<uintptr_t> fix_input_addrs;
auto fix_input_iter = batch_addrs.find(kDefaultBatchLable);
if (fix_input_iter != batch_addrs.end()) {
fix_input_addrs = fix_input_iter->second;
}

if (fix_input_addrs.empty()) {
if (!dynamic_input_addrs.empty() && dynamic_input_addrs.find(addr) == dynamic_input_addrs.end()) {
return false;
}
} else {
if (!dynamic_input_addrs.empty() && dynamic_input_addrs.find(addr) == dynamic_input_addrs.end() &&
fix_input_addrs.find(addr) == fix_input_addrs.end()) {
return false;
}
}

return true;
}

/**
* @ingroup ge
* @brief Set user data addr to Task param. * @brief Set user data addr to Task param.
* @param [in] addr: virtual address value from Op. * @param [in] addr: virtual address value from Op.
* @param [in] buffer_addr: real_data_buffer_addr from user. * @param [in] buffer_addr: real_data_buffer_addr from user.
* @param [in] batch_addrs: dynamic batch addr info.
* @param [in] batch_label: batch label.
* @return: void * @return: void
*/ */
Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr) {
Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const map<string, set<uintptr_t>> &batch_addrs,
const string &batch_label) {
auto iter = task_addr_offset_.find(addr); auto iter = task_addr_offset_.find(addr);
if (iter != task_addr_offset_.end()) { if (iter != task_addr_offset_.end()) {
auto &cur_pair = *iter; auto &cur_pair = *iter;


+ 7
- 1
ge/graph/load/new_model_manager/zero_copy_task.h View File

@@ -67,9 +67,12 @@ class ZeroCopyTask {
* @brief Set user data addr to Task param. * @brief Set user data addr to Task param.
* @param [in] addr: virtual address value from Op. * @param [in] addr: virtual address value from Op.
* @param [in] buffer_addr: data buffer_addr from user. * @param [in] buffer_addr: data buffer_addr from user.
* @param [in] batch_addrs: dynamic batch addr info.
* @param [in] batch_label: batch label.
* @return: 0 SUCCESS / others FAILED * @return: 0 SUCCESS / others FAILED
*/ */
ge::Status UpdateTaskParam(uintptr_t addr, void *buffer_addr);
ge::Status UpdateTaskParam(uintptr_t addr, void *buffer_addr, const map<string, set<uintptr_t>> &batch_addrs,
const string &batch_label);


/** /**
* @ingroup ge * @ingroup ge
@@ -88,6 +91,9 @@ class ZeroCopyTask {
return batch_label_; return batch_label_;
} }


protected:
bool CheckDynamicBatch(const map<string, set<uintptr_t>> &batch_addrs, const string &batch_label, uintptr_t addr);

private: private:
const string name_; const string name_;




+ 22
- 17
ge/graph/manager/graph_manager.cc View File

@@ -23,15 +23,25 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <thread> #include <thread>
#include <utility>


#include "common/ge/ge_util.h"
#include "common/math/math_util.h" #include "common/math/math_util.h"
#include "common/thread_pool.h" #include "common/thread_pool.h"
#include "common/util.h"
#include "external/graph/types.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/ge_types.h"
#include "analyzer/analyzer.h" #include "analyzer/analyzer.h"
#include "graph/common/ge_call_wrapper.h" #include "graph/common/ge_call_wrapper.h"
#include "graph/common/local_context.h" #include "graph/common/local_context.h"
#include "graph/common/transop_util.h" #include "graph/common/transop_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h" #include "graph/ge_context.h"
#include "graph/ge_global_options.h" #include "graph/ge_global_options.h"
#include "graph/ge_local_context.h"
#include "graph/manager/graph_mem_allocator.h"
#include "graph/manager/util/rt_context_util.h" #include "graph/manager/util/rt_context_util.h"
#include "graph/partition/dynamic_shape_partition.h" #include "graph/partition/dynamic_shape_partition.h"
#include "graph/passes/enter_pass.h" #include "graph/passes/enter_pass.h"
@@ -51,6 +61,8 @@
#include "graph/passes/dimension_adjust_pass.h" #include "graph/passes/dimension_adjust_pass.h"
#include "graph/passes/dimension_compute_pass.h" #include "graph/passes/dimension_compute_pass.h"
#include "graph/passes/flow_ctrl_pass.h" #include "graph/passes/flow_ctrl_pass.h"
#include "graph/passes/hccl_group_pass.h"
#include "graph/passes/hccl_memcpy_pass.h"
#include "graph/passes/identity_pass.h" #include "graph/passes/identity_pass.h"
#include "graph/passes/input_output_connection_identify_pass.h" #include "graph/passes/input_output_connection_identify_pass.h"
#include "graph/passes/iterator_op_pass.h" #include "graph/passes/iterator_op_pass.h"
@@ -65,7 +77,7 @@
#include "graph/passes/permute_pass.h" #include "graph/passes/permute_pass.h"
#include "graph/passes/prune_pass.h" #include "graph/passes/prune_pass.h"
#include "graph/passes/ref_identity_delete_op_pass.h" #include "graph/passes/ref_identity_delete_op_pass.h"
#include "graph/passes/remove_same_const_pass.h"
#include "graph/passes/replace_with_empty_const_pass.h"
#include "graph/passes/reshape_recovery_pass.h" #include "graph/passes/reshape_recovery_pass.h"
#include "graph/passes/reshape_remove_pass.h" #include "graph/passes/reshape_remove_pass.h"
#include "graph/passes/same_transdata_breadth_fusion_pass.h" #include "graph/passes/same_transdata_breadth_fusion_pass.h"
@@ -75,12 +87,13 @@
#include "graph/passes/switch_logic_remove_pass.h" #include "graph/passes/switch_logic_remove_pass.h"
#include "graph/passes/switch_to_stream_switch_pass.h" #include "graph/passes/switch_to_stream_switch_pass.h"
#include "graph/passes/transop_breadth_fusion_pass.h" #include "graph/passes/transop_breadth_fusion_pass.h"
#include "graph/passes/transop_depth_fusion_pass.h"
#include "graph/passes/transop_nearby_allreduce_fusion_pass.h" #include "graph/passes/transop_nearby_allreduce_fusion_pass.h"
#include "graph/passes/transop_symmetry_elimination_pass.h" #include "graph/passes/transop_symmetry_elimination_pass.h"
#include "graph/passes/transop_without_reshape_fusion_pass.h" #include "graph/passes/transop_without_reshape_fusion_pass.h"
#include "graph/passes/transpose_transdata_pass.h" #include "graph/passes/transpose_transdata_pass.h"
#include "graph/passes/useless_control_out_remove_pass.h"
#include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_op_pass.h"
#include "graph/passes/variable_prepare_op_pass.h"
#include "graph/passes/variable_ref_delete_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h"
#include "graph/passes/variable_ref_useless_control_out_delete_pass.h" #include "graph/passes/variable_ref_useless_control_out_delete_pass.h"
#include "graph/passes/end_of_sequence_add_control_pass.h" #include "graph/passes/end_of_sequence_add_control_pass.h"
@@ -91,6 +104,9 @@
#include "graph/passes/memcpy_addr_async_pass.h" #include "graph/passes/memcpy_addr_async_pass.h"
#include "graph/build/label_allocator.h" #include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
#include "graph/utils/type_utils.h"
#include "graph/graph_util.h"
#include "graph/types.h"
#include "inc/pass_manager.h" #include "inc/pass_manager.h"
#include "init/gelib.h" #include "init/gelib.h"
#include "ir_build/atc_ir_common.h" #include "ir_build/atc_ir_common.h"
@@ -534,8 +550,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr
(void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy);
} }
std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this,
compute_graph->GetGraphID(), subgraph,
compute_graph->GetName(), session_id,
compute_graph->GetGraphID(), subgraph, compute_graph, session_id,
GetThreadLocalContext()); GetThreadLocalContext());
if (!f.valid()) { if (!f.valid()) {
GELOGE(FAILED, "Future is invalid"); GELOGE(FAILED, "Future is invalid");
@@ -550,8 +565,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr
(void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy);
} }
std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this,
compute_graph->GetGraphID(), subgraph,
compute_graph->GetName(), session_id,
compute_graph->GetGraphID(), subgraph, compute_graph, session_id,
GetThreadLocalContext()); GetThreadLocalContext());
if (!f.valid()) { if (!f.valid()) {
GELOGE(FAILED, "Future is invalid"); GELOGE(FAILED, "Future is invalid");
@@ -2134,7 +2148,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
TransposeTransDataPass transpose_transdata_pass; TransposeTransDataPass transpose_transdata_pass;
TransOpSymmetryEliminationPass symmetry_elimination_pass; TransOpSymmetryEliminationPass symmetry_elimination_pass;
DimensionComputePass dimension_compute_pass; DimensionComputePass dimension_compute_pass;
UselessControlOutRemovePass useless_control_out_remove_pass;
names_to_passes.emplace_back("EnterPass", &enter_pass); names_to_passes.emplace_back("EnterPass", &enter_pass);
names_to_passes.emplace_back("AddNPass", &addn_pass); names_to_passes.emplace_back("AddNPass", &addn_pass);
names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination);
@@ -2148,7 +2161,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass);
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass);
names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass);
names_to_passes.emplace_back("UselessControlOutRemovePass", &useless_control_out_remove_pass);
GE_TIMESTAMP_START(names_to_passes); GE_TIMESTAMP_START(names_to_passes);
ret = GEPass(compute_graph).Run(names_to_passes); ret = GEPass(compute_graph).Run(names_to_passes);
GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2");
@@ -2189,8 +2201,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass",
new (std::nothrow) VariableRefUselessControlOutDeletePass)) new (std::nothrow) VariableRefUselessControlOutDeletePass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass))
GE_CHK_STATUS_RET(
graph_pass.AddPass("OptimizeStage1_3::RemoveSameConstPass", new (std::nothrow) RemoveSameConstPass))
if (options_.train_graph_flag) { if (options_.train_graph_flag) {
// Priority: The GlobalStepInsertPass should work before graph partitioner. // Priority: The GlobalStepInsertPass should work before graph partitioner.
// Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory // Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory
@@ -2461,8 +2471,7 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra


Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id,
const SubGraphInfoPtr &sub_graph_info_ptr, const SubGraphInfoPtr &sub_graph_info_ptr,
const std::string &root_graph_name,
uint64_t session_id,
const ComputeGraphPtr &compute_graph, uint64_t session_id,
const GEThreadLocalContext &ge_context) { const GEThreadLocalContext &ge_context) {
if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) { if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) {
GetContext().SetSessionId(session_id); GetContext().SetSessionId(session_id);
@@ -2479,13 +2488,9 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager
GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id); GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id);
return FAILED; return FAILED;
} }
if (!AttrUtils::SetStr(*compute_graph_tmp, ATTR_NAME_ROOT_GRAPH_NAME, root_graph_name)) {
GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_NAME for subgraph, \
root_graph_name: %s.", root_graph_name.c_str());
return FAILED;
}
compute_graph_tmp->SetSessionID(session_id); compute_graph_tmp->SetSessionID(session_id);
Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp,
compute_graph,
engine_name); engine_name);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "SubGraph optimize Failed %s", engine_name.c_str()); GELOGE(ret, "SubGraph optimize Failed %s", engine_name.c_str());


+ 1
- 2
ge/graph/manager/graph_manager.h View File

@@ -219,8 +219,7 @@ class GraphManager {


static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id, static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id,
const SubGraphInfoPtr &sub_graph_info_ptr, const SubGraphInfoPtr &sub_graph_info_ptr,
const std::string &root_graph_name,
uint64_t session_id,
const ComputeGraphPtr &compute_graph, uint64_t session_id,
const GEThreadLocalContext &ge_context); const GEThreadLocalContext &ge_context);
Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor); Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor);
void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor); void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor);


+ 3
- 0
ge/graph/manager/graph_mem_allocator.cc View File

@@ -16,7 +16,10 @@


#include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_mem_allocator.h"


#include <set>
#include <string> #include <string>

#include "framework/common/debug/ge_log.h"
#include "graph/manager/graph_caching_allocator.h" #include "graph/manager/graph_caching_allocator.h"
#include "graph/manager/rdma_pool_allocator.h" #include "graph/manager/rdma_pool_allocator.h"




+ 6
- 1
ge/graph/optimize/graph_optimize.cc View File

@@ -76,7 +76,8 @@ void AddNodeInputProperty(ComputeGraphPtr &compute_graph) {
} }
} }


Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name) {
Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph,
const std::string &engine_name) {
if (compute_graph == nullptr) { if (compute_graph == nullptr) {
GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeSubGraph]: compute_graph is nullptr."); GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeSubGraph]: compute_graph is nullptr.");
return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL;
@@ -105,6 +106,10 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std
for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) {
Status ret = (*iter)->OptimizeFusedGraphAfterGraphSlice(*(compute_graph)); Status ret = (*iter)->OptimizeFusedGraphAfterGraphSlice(*(compute_graph));
if (ret != SUCCESS) { if (ret != SUCCESS) {
auto root_graph = ge::GraphUtils::FindRootGraph(parent_graph);
if (root_graph != nullptr) {
ErrorManager::GetInstance().SaveMstuneCompileFailedMsg(root_graph->GetName());
}
GELOGE(ret, "[OptimizeSubGraph][OptimizeFusedGraphAfterGraphSlice]: graph optimize failed, ret:%d", ret); GELOGE(ret, "[OptimizeSubGraph][OptimizeFusedGraphAfterGraphSlice]: graph optimize failed, ret:%d", ret);
return ret; return ret;
} }


+ 2
- 1
ge/graph/optimize/graph_optimize.h View File

@@ -42,7 +42,8 @@ class GraphOptimize {
~GraphOptimize() = default; ~GraphOptimize() = default;


// subgraph optimize // subgraph optimize
Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name);
Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph,
const std::string &engine_name);


// original graph optimize // original graph optimize
Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph);


+ 23
- 5
ge/graph/passes/attach_stream_label_pass.cc View File

@@ -18,8 +18,6 @@
#include "ge/ge_api_types.h" #include "ge/ge_api_types.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"


using std::string;

namespace ge { namespace ge {
Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) {
GELOGD("AttachStreamLabelPass Enter."); GELOGD("AttachStreamLabelPass Enter.");
@@ -189,10 +187,21 @@ Status AttachStreamLabelPass::UpdateEnterNode() {
} }


std::stack<NodePtr> enter_nodes; std::stack<NodePtr> enter_nodes;
std::string batch_label;
for (const auto &enter_node : pair.second) { for (const auto &enter_node : pair.second) {
enter_nodes.emplace(enter_node); enter_nodes.emplace(enter_node);
std::string tmp_label;
(void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty()) {
if (batch_label.empty()) {
batch_label = tmp_label;
} else if (batch_label != tmp_label) {
GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str());
return FAILED;
}
}
} }
if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) {
if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) {
GELOGE(FAILED, "Update stream_label for loop_branch failed."); GELOGE(FAILED, "Update stream_label for loop_branch failed.");
return FAILED; return FAILED;
} }
@@ -217,7 +226,10 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
} }


for (const auto &enter_node : enter_nodes) { for (const auto &enter_node : enter_nodes) {
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
GE_CHECK_NOTNULL(enter_node->GetOpDesc());
if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) {
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
}
} }
return SUCCESS; return SUCCESS;
} }
@@ -229,7 +241,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
/// @param [in] batch_label /// @param [in] batch_label
/// @return Status /// @return Status
/// ///
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const string &stream_label) {
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label) {
std::stack<NodePtr> nodes(enter_nodes); std::stack<NodePtr> nodes(enter_nodes);
NodePtr cur_node = nullptr; NodePtr cur_node = nullptr;
while (!nodes.empty()) { while (!nodes.empty()) {
@@ -238,6 +251,11 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_
for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { for (const NodePtr &out_node : cur_node->GetOutAllNodes()) {
OpDescPtr out_desc = out_node->GetOpDesc(); OpDescPtr out_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(out_desc); GE_CHECK_NOTNULL(out_desc);
std::string tmp_label;
(void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty() && (tmp_label != batch_label)) {
continue;
}
std::string out_type = out_desc->GetType(); std::string out_type = out_desc->GetType();
bool need_skip = bool need_skip =
out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) ||


+ 3
- 1
ge/graph/passes/attach_stream_label_pass.h View File

@@ -58,9 +58,11 @@ class AttachStreamLabelPass : public GraphPass {
/// @brief Update stream_label for loop_branch /// @brief Update stream_label for loop_branch
/// @param [in] enter_nodes /// @param [in] enter_nodes
/// @param [in] stream_label /// @param [in] stream_label
/// @param [in] batch_label
/// @return Status /// @return Status
/// ///
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label);
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label);


/// ///
/// @brief Update stream_label start with enter nodes /// @brief Update stream_label start with enter nodes


+ 1
- 1
ge/graph/passes/base_pass.cc View File

@@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder
node->GetName().c_str(), node->GetType().c_str()); node->GetName().c_str(), node->GetType().c_str());
continue; continue;
} }
if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str());
nodes_re_pass.insert(node_to_re_pass); nodes_re_pass.insert(node_to_re_pass);
} else { } else {


+ 0
- 64
ge/graph/passes/dimension_adjust_pass.cc View File

@@ -80,71 +80,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) {
} }
} }


ret = DealWithInNodes(node);
if (ret != SUCCESS) {
GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str());
return ret;
}

std::vector<int> data_relink_io_map = {kDataInputIndex}; std::vector<int> data_relink_io_map = {kDataInputIndex};
return IsolateAndDeleteNode(node, data_relink_io_map); return IsolateAndDeleteNode(node, data_relink_io_map);
} }

Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
auto graph = node->GetOwnerComputeGraph();
auto in_data_anchors = node->GetAllInDataAnchors();
for (auto &in_data_anchor : in_data_anchors) {
if (in_data_anchor == nullptr) {
continue;
}
auto in_node_anchor = in_data_anchor->GetPeerOutAnchor();
if (in_node_anchor == nullptr) {
continue;
}
auto in_node = in_node_anchor->GetOwnerNode();
if (in_node->GetType() == SWITCHN) {
auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx());
auto identity =
AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph);
GE_CHECK_NOTNULL(identity);
GELOGI("Create new identity node[%s] after node %s[type: %s] success.", identity->GetName().c_str(),
in_node->GetName().c_str(), in_node->GetType().c_str());
GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0)))
GE_CHECK_NOTNULL(identity->GetOutControlAnchor());
if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) {
continue;
}
GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor()))
}
}

return SUCCESS;
}

NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor,
ComputeGraphPtr &graph) {
if (graph == nullptr) {
GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node.");
return nullptr;
}

OpDescPtr desc = MakeShared<OpDesc>("", "");
if (desc == nullptr) {
GELOGE(MEMALLOC_FAILED, "Failed to create op desc.");
return nullptr;
}

desc->SetName(name);
desc->SetType(IDENTITY);
auto ret = desc->AddInputDesc(tensor);
auto ret2 = desc->AddOutputDesc(tensor);
if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity.");
return nullptr;
}

return graph->AddNodeFront(desc);
}
} // namespace ge } // namespace ge

+ 0
- 4
ge/graph/passes/dimension_adjust_pass.h View File

@@ -34,10 +34,6 @@ namespace ge {
class DimensionAdjustPass : public BaseNodePass { class DimensionAdjustPass : public BaseNodePass {
public: public:
Status Run(ge::NodePtr &node) override; Status Run(ge::NodePtr &node) override;

private:
Status DealWithInNodes(ge::NodePtr &node);
NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph);
}; };
} // namespace ge } // namespace ge




+ 7
- 57
ge/graph/passes/enter_pass.cc View File

@@ -23,7 +23,6 @@


namespace { namespace {
const size_t kOutNodesNum = 1; const size_t kOutNodesNum = 1;
const size_t kInCtrlNodesNum = 1;
} }


namespace ge { namespace ge {
@@ -56,7 +55,6 @@ Status EnterPass::Run(NodePtr &node) {
if (out_ctrl_node == nullptr) { if (out_ctrl_node == nullptr) {
continue; continue;
} }
GELOGI("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str());
if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(),
out_ctrl_node->GetName().c_str()); out_ctrl_node->GetName().c_str());
@@ -64,12 +62,8 @@ Status EnterPass::Run(NodePtr &node) {
} }
} }
} else { } else {
if (OptimizeEnterWithOnlyDataOut(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str());
return FAILED;
}
if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) {
GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str());
if (OptimizeEnter(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str());
return FAILED; return FAILED;
} }
} }
@@ -78,7 +72,7 @@ Status EnterPass::Run(NodePtr &node) {
return SUCCESS; return SUCCESS;
} }


Status EnterPass::OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node) {
Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) {
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) {
return SUCCESS; return SUCCESS;
} }
@@ -89,61 +83,17 @@ Status EnterPass::OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node)
} }


GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)))
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)));
const auto &out_data_anchor = node->GetOutDataAnchor(0); const auto &out_data_anchor = node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(out_data_anchor); GE_CHECK_NOTNULL(out_data_anchor);
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor))
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor))
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor));
} }
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node))
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node));
AddNodeDeleted(node); AddNodeDeleted(node);
AddRePassNodesWithInOut(in_node); AddRePassNodesWithInOut(in_node);


return SUCCESS; return SUCCESS;
} }

Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) {
auto out_ctrl_nodes = node->GetOutControlNodes();
if (out_ctrl_nodes.empty()) {
return SUCCESS;
}
auto out_ctrl_anchor = node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor);

for (auto &out_ctrl_node : out_ctrl_nodes) {
GE_CHECK_NOTNULL(out_ctrl_node);
if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) {
continue;
}
auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes();
if (in_ctrl_nodes.size() != kInCtrlNodesNum) {
continue;
}

// Skip when has merge out
bool has_merge_out = false;
auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes();
for (const auto &out_node_of_const : out_nodes_of_const) {
GE_CHECK_NOTNULL(out_node_of_const);
if (out_node_of_const->GetType() == MERGE || out_node_of_const->GetType() == REFMERGE) {
has_merge_out = true;
break;
}
}
if (has_merge_out) {
continue;
}

GELOGI("Unlink control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str());
GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor()))
for (auto &out_node_of_const : out_nodes_of_const) {
if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) {
GELOGI("Link control edge from %s to %s.", node->GetName().c_str(), out_node_of_const->GetName().c_str());
GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor()))
}
}
}
return SUCCESS;
}
} // namespace ge } // namespace ge

+ 1
- 2
ge/graph/passes/enter_pass.h View File

@@ -25,8 +25,7 @@ class EnterPass : public BaseNodePass {
Status Run(NodePtr &node) override; Status Run(NodePtr &node) override;


private: private:
Status OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node);
Status UnlinkCtrlEdgeBeforeConst(NodePtr &node);
Status OptimizeEnter(NodePtr &node, NodePtr &in_node);
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_ENTER_PASS_H_ #endif // GE_GRAPH_PASSES_ENTER_PASS_H_

+ 4
- 1
ge/graph/passes/folding_pass.cc View File

@@ -173,7 +173,10 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) {
continue; continue;
} }
auto in_node = in_node_anchor->GetOwnerNode(); auto in_node = in_node_anchor->GetOwnerNode();
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) {
if (in_node == nullptr) {
continue;
}
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) {
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str());
auto ret = in_node_anchor->Unlink(in_data_anchor); auto ret = in_node_anchor->Unlink(in_data_anchor);
if (ret != SUCCESS) { if (ret != SUCCESS) {


+ 10
- 0
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co
GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed");
} }


if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
string batch_label;
(void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (!batch_label.empty()) {
auto stream_merge_desc = stream_merge->GetOpDesc();
GE_CHECK_NOTNULL(stream_merge_desc);
(void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label);
}
}

return AddActiveNodes(graph, stream_merge); return AddActiveNodes(graph, stream_merge);
} }




+ 173
- 89
ge/graph/passes/next_iteration_pass.cc View File

@@ -19,8 +19,6 @@
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"


using std::string;

namespace ge { namespace ge {
Status NextIterationPass::Run(ComputeGraphPtr graph) { Status NextIterationPass::Run(ComputeGraphPtr graph) {
GELOGD("NextIterationPass Enter"); GELOGD("NextIterationPass Enter");
@@ -37,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) {
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
} }
if (GroupWithNoBatch(graph) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr.");
return INTERNAL_ERROR;
}


if (FindWhileGroups() != SUCCESS) { if (FindWhileGroups() != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Find while groups failed."); GELOGE(INTERNAL_ERROR, "Find while groups failed.");
@@ -71,22 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
return FAILED; return FAILED;
} }


string batch_label;
if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
frame_name += batch_label;
std::string batch_label;
(void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (batch_label.empty()) {
auto frame_iter = frame_enter_map_.find(frame_name);
if (frame_iter == frame_enter_map_.end()) {
std::vector<NodePtr> enter_nodes;
enter_nodes.emplace_back(enter_node);
frame_enter_map_[frame_name] = enter_nodes;
} else {
frame_iter->second.emplace_back(enter_node);
}
return SUCCESS;
} }


auto iter = loop_group_map_.find(frame_name);
if (iter == loop_group_map_.end()) {
auto group_iter = loop_group_map_.find(frame_name);
if (group_iter == loop_group_map_.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) { if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED; return FAILED;
} }
loop_group->enter_nodes.emplace_back(enter_node); loop_group->enter_nodes.emplace_back(enter_node);
loop_group_map_[frame_name] = loop_group;
loop_group_map_[frame_name][batch_label] = loop_group;
} else { } else {
iter->second->enter_nodes.emplace_back(enter_node);
auto batch_iter = group_iter->second.find(batch_label);
if (batch_iter == group_iter->second.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes.emplace_back(enter_node);
group_iter->second[batch_label] = loop_group;
} else {
batch_iter->second->enter_nodes.emplace_back(enter_node);
}
}

return SUCCESS;
}

///
/// @brief Group Enter nodes without batch_label attr
/// @param [in] compute_graph
/// @return Status
///
Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) {
if (frame_enter_map_.empty()) {
GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str());
return SUCCESS;
}
for (const auto &item : frame_enter_map_) {
const std::string &frame_name = item.first;
auto iter = loop_group_map_.find(frame_name);
if (iter == loop_group_map_.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes = item.second;
loop_group_map_[frame_name][""] = loop_group;
} else {
for (auto &batch_item : iter->second) {
for (const auto &enter_node : item.second) {
batch_item.second->enter_nodes.emplace_back(enter_node);
}
}
}
} }


return SUCCESS; return SUCCESS;
@@ -99,39 +154,55 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
Status NextIterationPass::FindWhileGroups() { Status NextIterationPass::FindWhileGroups() {
for (const auto &loop_group_iter : loop_group_map_) { for (const auto &loop_group_iter : loop_group_map_) {
const std::string &frame_name = loop_group_iter.first; const std::string &frame_name = loop_group_iter.first;
for (const auto &enter_node : loop_group_iter.second->enter_nodes) {
for (const auto &out_node : enter_node->GetOutAllNodes()) {
const string &type = out_node->GetType();
if ((type != MERGE) && (type != REFMERGE)) {
continue;
}

NodePtr next_node = nullptr;
if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str());
return INTERNAL_ERROR;
}
loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));

NodePtr switch_node = nullptr;
if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str());
return INTERNAL_ERROR;
}
if (switch_node == nullptr) {
continue;
}

NodePtr loop_cond = nullptr;
if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());
return INTERNAL_ERROR;
}
if (loop_group_iter.second->loop_cond == nullptr) {
loop_group_iter.second->loop_cond = loop_cond;
} else if (loop_group_iter.second->loop_cond != loop_cond) {
GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str());
return FAILED;
for (const auto &batch_iter : loop_group_iter.second) {
const std::string &batch_label = batch_iter.first;
for (const auto &enter_node : batch_iter.second->enter_nodes) {
for (const auto &out_node : enter_node->GetOutAllNodes()) {
GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(),
frame_name.c_str(), batch_label.c_str());
if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) {
continue;
}
std::string tmp_label;
GE_CHECK_NOTNULL(out_node->GetOpDesc());
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
if (need_skip) {
continue;
}

NodePtr next_node = nullptr;
if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR,
"Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s",
out_node->GetName().c_str());
return INTERNAL_ERROR;
}
batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));

NodePtr switch_node = nullptr;
if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s",
out_node->GetName().c_str());
return INTERNAL_ERROR;
}
if (switch_node == nullptr) {
continue;
}

NodePtr loop_cond = nullptr;
if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) {
GELOGE(INTERNAL_ERROR,
"Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s",
switch_node->GetName().c_str());
return INTERNAL_ERROR;
}
if (batch_iter.second->loop_cond == nullptr) {
batch_iter.second->loop_cond = loop_cond;
} else if (batch_iter.second->loop_cond != loop_cond) {
GELOGE(FAILED, "Multi LoopCond nodes exist.");
return FAILED;
}
} }
} }
} }
@@ -152,17 +223,19 @@ bool NextIterationPass::VerifyWhileGroup() {
GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty.");
return false; return false;
} }
if (loop_group_iter.second->loop_cond == nullptr) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
return false;
}

for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) {
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
frame_name.c_str());
for (const auto &batch_iter : loop_group_iter.second) {
if (batch_iter.second->loop_cond == nullptr) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
return false; return false;
} }

for (const auto &pair_iter : batch_iter.second->merge_next_pairs) {
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
frame_name.c_str());
return false;
}
}
} }
} }


@@ -176,53 +249,56 @@ bool NextIterationPass::VerifyWhileGroup() {
/// ///
Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
for (const auto &loop_cond_iter : loop_group_map_) { for (const auto &loop_cond_iter : loop_group_map_) {
const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
if ((enter_active == nullptr) || (next_active == nullptr)) {
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
return INTERNAL_ERROR;
}

for (const auto &enter_node : loop_cond_iter.second->enter_nodes) {
// Enter --> Active
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(),
enter_active->GetName().c_str());
for (const auto &batch_iter : loop_cond_iter.second) {
const std::string &cond_name = batch_iter.second->loop_cond->GetName();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
if ((enter_active == nullptr) || (next_active == nullptr)) {
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
}


for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
NodePtr merge_node = pair.first;
NodePtr next_node = pair.second;
// Active --> Merge
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
for (const auto &enter_node : batch_iter.second->enter_nodes) {
// Enter --> Active
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) !=
GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}
} }


// NextIteration --> Active
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
for (const auto &pair : batch_iter.second->merge_next_pairs) {
NodePtr merge_node = pair.first;
NodePtr next_node = pair.second;
// Active --> Merge
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) !=
GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}

// NextIteration --> Active
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}

// break link between NextIteration and Merge
if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
return INTERNAL_ERROR;
}
} }


// break link between NextIteration and Merge
if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
} }

if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
return INTERNAL_ERROR;
}
} }


return SUCCESS; return SUCCESS;
@@ -289,11 +365,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &
/// @param [in] node /// @param [in] node
/// @param [in] target_type /// @param [in] target_type
/// @param [in] is_input /// @param [in] is_input
/// @param [in] batch_label
/// @param [out] target_node /// @param [out] target_node
/// @return Status /// @return Status
/// ///
Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
NodePtr &target_node) {
const std::string &batch_label, NodePtr &target_node) {
if (node == nullptr) { if (node == nullptr) {
GELOGE(PARAM_INVALID, "node is null."); GELOGE(PARAM_INVALID, "node is null.");
return PARAM_INVALID; return PARAM_INVALID;
@@ -310,6 +387,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string
} }


for (const auto &tmp_node : nodes) { for (const auto &tmp_node : nodes) {
std::string tmp_label;
(void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
if (need_skip) {
continue;
}
const std::string type = tmp_node->GetType(); const std::string type = tmp_node->GetType();
if ((target_type == LOOPCOND) && (type == target_type)) { if ((target_type == LOOPCOND) && (type == target_type)) {
target_node = tmp_node; target_node = tmp_node;
@@ -332,6 +415,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string
/// @return SUCCESS /// @return SUCCESS
/// ///
Status NextIterationPass::ClearStatus() { Status NextIterationPass::ClearStatus() {
frame_enter_map_.clear();
loop_group_map_.clear(); loop_group_map_.clear();
return SUCCESS; return SUCCESS;
} }


+ 13
- 3
ge/graph/passes/next_iteration_pass.h View File

@@ -47,6 +47,13 @@ class NextIterationPass : public GraphPass {
Status GroupEnterNode(const NodePtr &enter_node); Status GroupEnterNode(const NodePtr &enter_node);


/// ///
/// @brief Group Enter nodes without batch_label attr
/// @param [in] compute_graph
/// @return Status
///
Status GroupWithNoBatch(const ComputeGraphPtr &graph);

///
/// @brief Find while groups /// @brief Find while groups
/// @return Status /// @return Status
/// ///
@@ -90,10 +97,13 @@ class NextIterationPass : public GraphPass {
/// @param [out] target_node /// @param [out] target_node
/// @return Status /// @return Status
/// ///
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node);
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
const std::string &batch_label, NodePtr &target_node);


// map<frame_name, LoopCondGroup>
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_;
// map<frame_name, vector<enter_node>>
std::unordered_map<std::string, std::vector<NodePtr>> frame_enter_map_;
// map<frame_name, map<batch_label, LoopCondGroup>>
std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroupPtr>> loop_group_map_;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_

+ 0
- 106
ge/graph/passes/remove_same_const_pass.cc View File

@@ -1,106 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "remove_same_const_pass.h"

#include <sstream>
#include <string>
#include <set>

#include "common/base64.h"
#include "ge_local_engine/engine/host_cpu_engine.h"
#include "graph/utils/node_utils.h"

namespace ge {
namespace {
// Build the CSE (common-subexpression-elimination) key for a const node:
// node type, the sorted names of its control-input nodes, and the string
// form of all its attributes. Two consts with equal keys are duplicates.
std::string GetCseKey(const NodePtr &node) {
  std::set<std::string> ctrl_src_names;
  for (const auto &in_ctrl_node : node->GetInControlNodes()) {
    ctrl_src_names.insert(in_ctrl_node->GetName());
  }

  std::stringstream key;
  key << node->GetType() << "control-inputs-";
  // std::set iterates in sorted order, so the key is stable regardless of
  // the order the control inputs were linked in.
  for (const auto &src_name : ctrl_src_names) {
    key << src_name << "-";
  }
  key << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc());
  return key.str();
}

// True when the node is a Const or ConstantOp.
bool IsConstType(const NodePtr &node) {
  const auto &node_type = node->GetType();
  return (node_type == CONSTANT) || (node_type == CONSTANTOP);
}
} // namespace
/// @brief De-duplicate identical const nodes in the graph.
/// Consts sharing the same CSE key (type + control inputs + attrs) are merged:
/// the first one encountered is kept and later duplicates are removed after
/// redirecting their anchors to the kept node.
/// @param [in] graph  compute graph to process
/// @return SUCCESS, or INTERNAL_ERROR when anchor replacement / node removal fails
Status RemoveSameConstPass::Run(ComputeGraphPtr graph) {
  GELOGD("Begin to run RemoveSameConstPass on the graph");
  GE_CHECK_NOTNULL(graph);
  std::map<std::string, NodePtr> keys_to_node;
  for (const auto &node : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    if (!IsConstType(node)) {
      continue;
    }
    // Skip unknown-shape consts: best-effort — a failed status query only logs
    // a warning and leaves the node alone rather than aborting the pass.
    bool is_unknown = false;
    auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown);
    if (ret != GRAPH_SUCCESS) {
      GELOGW("Get node unknown status failed, node name:%s, type:%s.",
             node->GetName().c_str(), node->GetType().c_str());
      continue;
    }
    if (is_unknown) {
      GELOGI("Current node %s, type %s is unknown shape which should be skip.",
             node->GetName().c_str(), node->GetType().c_str());
      continue;
    }
    auto key = GetCseKey(node);
    GELOGD("The const node %s cse key %s", node->GetName().c_str(), ge::base64::EncodeToBase64(key).c_str());
    auto iter = keys_to_node.find(key);
    if (iter == keys_to_node.end()) {
      // First const with this key: keep it as the canonical node.
      keys_to_node[key] = node;
      continue;
    }

    // Anchor replacement needs a 1:1 output mapping, so both nodes must expose
    // the same number of output anchors.
    if (node->GetAllOutDataAnchorsSize() != iter->second->GetAllOutDataAnchorsSize()) {
      GELOGW("The const node %s and %s have the same CSE key, but different output anchor count, skip to fusion them",
             iter->second->GetName().c_str(), node->GetName().c_str());
      continue;
    }

    // Identity mapping: output i of the duplicate is served by output i of the
    // kept node. static_cast avoids implicit size_t -> int narrowing.
    std::vector<int> output_map(node->GetAllOutDataAnchorsSize());
    for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
      output_map[i] = static_cast<int>(i);
    }

    ret = GraphUtils::ReplaceNodeAnchors(iter->second, node, {}, output_map);
    if (ret != GRAPH_SUCCESS) {
      // Fix: the original format string took a trailing 'ret' argument with no
      // matching conversion specifier; log the status code explicitly.
      GELOGE(INTERNAL_ERROR, "Failed to replace node %s by node %s, ret:%u", node->GetName().c_str(),
             iter->second->GetName().c_str(), ret);
      return INTERNAL_ERROR;
    }

    NodeUtils::UnlinkAll(*node);

    ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Failed to remove node %s from graph", node->GetName().c_str());
      return INTERNAL_ERROR;
    }

    GELOGI("Remove const node %s by RemoveSameConstPass, replace it with node %s", node->GetName().c_str(),
           iter->second->GetName().c_str());
  }
  return SUCCESS;
}
} // namespace ge

+ 0
- 28
ge/graph/passes/remove_same_const_pass.h View File

@@ -1,28 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_
#define GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_

#include "graph/types.h"
#include "inc/graph_pass.h"

namespace ge {
// Graph-level pass that removes duplicated const nodes: consts whose type,
// control inputs and attributes are identical (see the pass's .cc, GetCseKey)
// are merged into a single node and the duplicates are deleted.
class RemoveSameConstPass : public GraphPass {
 public:
  // Runs the pass on the whole graph. Returns SUCCESS, or INTERNAL_ERROR
  // when a duplicate const cannot be replaced/removed.
  Status Run(ge::ComputeGraphPtr graph) override ;
};
} // namespace ge
#endif //GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_

+ 8
- 4
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -17,8 +17,13 @@
#include "graph/passes/switch_to_stream_switch_pass.h" #include "graph/passes/switch_to_stream_switch_pass.h"
#include <stack> #include <stack>
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/types.h"
#include "ge/ge_api_types.h" #include "ge/ge_api_types.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h" #include "graph/ge_context.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"


@@ -120,13 +125,12 @@ void SwitchToStreamSwitchPass::MarkCycleDependence(
if (visited.count(tmp_node) > 0) { if (visited.count(tmp_node) > 0) {
continue; continue;
} }
GELOGD("MarkCycleDependence: tmp_node=%s.", tmp_node->GetName().c_str());
for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) { for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) {
if (switch_nodes.find(out_node) == switch_nodes.end()) { if (switch_nodes.find(out_node) == switch_nodes.end()) {
out_nodes.push(out_node); out_nodes.push(out_node);
continue; continue;
} }
GELOGD("MarkCycleDependence: tmp_node=%s, switch_node=%s.",
tmp_node->GetName().c_str(), out_node->GetName().c_str());
GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS,
GELOGW("set cyclic dependence attr failed."); return ); GELOGW("set cyclic dependence attr failed."); return );
auto map_iter = switch_cyclic_map_.find(out_node); auto map_iter = switch_cyclic_map_.find(out_node);
@@ -598,7 +602,7 @@ Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, cons
/// ///
Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node,
const std::set<NodePtr> &same_cond_switch) { const std::set<NodePtr> &same_cond_switch) {
GELOGD("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(),
GELOGI("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(),
cast_node->GetName().c_str()); cast_node->GetName().c_str());
std::string orig_switch_name = switch_node->GetName(); std::string orig_switch_name = switch_node->GetName();
OpDescPtr switch_desc = switch_node->GetOpDesc(); OpDescPtr switch_desc = switch_node->GetOpDesc();
@@ -649,7 +653,7 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no
/// ///
Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch,
const NodePtr &active_node) { const NodePtr &active_node) {
GELOGD("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(),
GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(),
stream_switch->GetName().c_str(), active_node->GetName().c_str()); stream_switch->GetName().c_str(), active_node->GetName().c_str());
auto find_res = switch_node_map_.find(switch_node); auto find_res = switch_node_map_.find(switch_node);
GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), {


+ 0
- 51
ge/graph/passes/useless_control_out_remove_pass.cc View File

@@ -1,51 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/useless_control_out_remove_pass.h"

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"

namespace ge {
// Clean up Const/ConstantOp nodes that have no control inputs: an isolated
// const (no data outputs either) is deleted outright; otherwise only its
// outgoing control edges are removed. All other node types pass through.
Status UselessControlOutRemovePass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);

  const auto &node_type = node->GetType();
  if ((node_type != CONSTANT) && (node_type != CONSTANTOP)) {
    return SUCCESS;
  }
  GELOGD("UselessControlOutRemovePass running, node: %s.", node->GetName().c_str());

  // Only consts without control inputs are touched.
  if (!node->GetInControlNodes().empty()) {
    return SUCCESS;
  }

  if (node->GetOutDataNodes().empty()) {
    // It is an isolated const, just remove it.
    GELOGI("Delete isolated const: %s.", node->GetName().c_str());
    GE_CHK_STATUS_RET(IsolateAndDeleteNode(node, {}))
    AddNodeDeleted(node);
    return SUCCESS;
  }

  // Const still feeds data consumers: drop only its out control edges.
  auto out_ctrl_anchor = node->GetOutControlAnchor();
  if ((out_ctrl_anchor != nullptr) && !out_ctrl_anchor->GetPeerAnchors().empty()) {
    GELOGI("Node: %s unlink all out control edge.", node->GetName().c_str());
    out_ctrl_anchor->UnlinkAll();
  }

  return SUCCESS;
}
} // namespace ge

+ 0
- 29
ge/graph/passes/useless_control_out_remove_pass.h View File

@@ -1,29 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_
#define GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_

#include "graph/passes/base_pass.h"

namespace ge {
// Node-level pass over Const/ConstantOp nodes with no control inputs:
// deletes isolated consts (no data outputs) and strips the outgoing
// control edges of consts that still have data consumers.
class UselessControlOutRemovePass : public BaseNodePass {
 public:
  // Runs on a single node; non-const nodes are returned unchanged.
  Status Run(NodePtr &node) override;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_

+ 59
- 343
ge/graph/preprocess/multi_batch_copy_graph.cc View File

@@ -44,8 +44,6 @@
using std::set; using std::set;
using std::string; using std::string;
using std::vector; using std::vector;
using std::map;
using std::queue;


namespace ge { namespace ge {
namespace multibatch { namespace multibatch {
@@ -59,15 +57,10 @@ const int kDataInIndex = 0;
const int kMergeDataOutIndex = 0; const int kMergeDataOutIndex = 0;
const int kStaticOutput = -1; const int kStaticOutput = -1;
const int kDivisionConst = 2; const int kDivisionConst = 2;
const int32_t kOneInDataNode = 1;
const int32_t kFindNoMatch = 0;




inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); }


inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); }
const set<string> unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER});

inline bool IsGetNextType(const NodePtr &node) { inline bool IsGetNextType(const NodePtr &node) {
std::string original_type; std::string original_type;
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
@@ -225,6 +218,12 @@ Status MultiBatchGraphCopyer::CopyGraph() {
return ret; return ret;
} }


ret = InsertIdentityAfterSwitchN();
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node.");
return INTERNAL_ERROR;
}

GELOGI("Begin to remove useless nodes by prune pass after copy process"); GELOGI("Begin to remove useless nodes by prune pass after copy process");
PrunePass prune_pass; PrunePass prune_pass;
ret = prune_pass.Run(graph_); ret = prune_pass.Run(graph_);
@@ -241,18 +240,6 @@ Status MultiBatchGraphCopyer::Init() {
return ret; return ret;
} }


ret = RelinkConstCtrlEdge();
if (ret != SUCCESS) {
GELOGE(FAILED, "Relink const's control edge failed.");
return FAILED;
}

ret = ExtractUnchangedStructureOutofCycle();
if (ret != SUCCESS) {
GELOGE(FAILED, "Extract unchanged structure out of cycle failed.");
return FAILED;
}

for (auto &node : graph_->GetAllNodes()) { for (auto &node : graph_->GetAllNodes()) {
origin_all_nodes_.emplace_back(node); origin_all_nodes_.emplace_back(node);
if (IsDataLikeType(node->GetType())) { if (IsDataLikeType(node->GetType())) {
@@ -265,281 +252,6 @@ Status MultiBatchGraphCopyer::Init() {
return SUCCESS; return SUCCESS;
} }


// Detach Const/ConstantOp nodes from the control-edge graph: each const's
// in-control edges are bypassed (its control predecessors are linked straight
// to its consumers) and its out-control edges are removed. Consts feeding a
// Merge/RefMerge are left untouched, as are consts with no data consumers.
Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() {
  for (auto &node : graph_->GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
      // A const with no data consumers needs no relinking here.
      if (node->GetOutDataNodes().empty()) {
        continue;
      }
      if (!node->GetInControlNodes().empty()) {
        // Snapshot predecessors/successors BEFORE UnlinkAll below mutates the
        // anchors — the order of these statements is load-bearing.
        auto in_ctrl_nodes = node->GetInControlNodes();
        auto out_nodes = node->GetOutAllNodes();
        bool has_merge_out = false;
        for (const auto &out_node : out_nodes) {
          GE_CHECK_NOTNULL(out_node);
          if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) {
            has_merge_out = true;
            break;
          }
        }
        // Consts flowing into Merge are kept as-is to preserve Merge semantics.
        if (has_merge_out) {
          continue;
        }
        auto in_ctrl_anchor = node->GetInControlAnchor();
        GE_CHECK_NOTNULL(in_ctrl_anchor);
        in_ctrl_anchor->UnlinkAll();
        // Bypass the const: link every former control predecessor directly to
        // every consumer (skipping Enter/RefEnter), avoiding duplicate edges.
        for (auto &in_ctrl_node : in_ctrl_nodes) {
          auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor();
          GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node);
          for (auto &out_node : out_nodes) {
            if (IsEnterType(out_node->GetType())) {
              continue;
            }
            if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) {
              GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor()))
            }
          }
        }
      }
      // Finally drop the const's own outgoing control edges.
      auto out_ctrl_anchor = node->GetOutControlAnchor();
      if (out_ctrl_anchor != nullptr) {
        out_ctrl_anchor->UnlinkAll();
      }
    }
  }

  return SUCCESS;
}

// Hoists loop-invariant structures out of their cycle (frame): nodes whose
// data inputs are all unchanged across iterations are moved above their Enter
// nodes, and fresh Enter nodes are inserted below them instead. Candidates
// are processed from a worklist seeded at each frame's Enter nodes and
// propagated to newly-qualifying consumers; orphaned Enters are removed last.
Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() {
  map<string, vector<NodePtr>> frame_enter;
  if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) {
    GELOGE(FAILED, "Get enter nodes grouped by frame_name failed.");
    return FAILED;
  }

  // Seed the worklist with qualifying nodes directly below the Enter nodes.
  queue<NodePtr> nodes_to_extract;
  if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) {
    GELOGE(FAILED, "Get nodes needed to extract failed.");
    return FAILED;
  }

  while (!nodes_to_extract.empty()) {
    auto node = nodes_to_extract.front();
    nodes_to_extract.pop();
    // Step 1: bypass the Enter inputs of `node`; one bypassed Enter's OpDesc
    // is remembered so equivalent Enters can be cloned below the node.
    OpDescPtr enter_desc = nullptr;
    if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) {
      GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str());
      return FAILED;
    }
    // Step 2: insert a fresh Enter on each of the node's out data anchors.
    set<NodePtr> out_nodes;
    if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) {
      GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str());
      return FAILED;
    }

    // Step 3: the node now lives outside the frame, so its incoming control
    // edges must be handed over to the consumers that stay inside.
    if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) {
      GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str());
      return FAILED;
    }

    // Consumers may have become extractable in turn; keep propagating.
    for (auto &out_node : out_nodes) {
      GE_CHECK_NOTNULL(out_node);
      if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) {
        nodes_to_extract.push(out_node);
      }
    }
  }

  // Clean up Enter nodes that no longer feed anything.
  if (DeleteEnterWithoutDataOut() != SUCCESS) {
    GELOGE(FAILED, "Delete enter node without out data nodes failed.");
    return FAILED;
  }

  return SUCCESS;
}

// Collects all Enter-type nodes that carry no control edges, grouped by
// their frame name (ENTER_ATTR_FRAME_NAME attribute).
// @param frame_enter  out: frame_name -> Enter nodes belonging to that frame.
// @return SUCCESS, or FAILED when an Enter lacks the frame_name attribute.
Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map<string, vector<NodePtr>> &frame_enter) {
  for (auto &node : graph_->GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    if (IsEnterType(node->GetType())) {
      // Enters that participate in control edges are not safe to regroup.
      if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) {
        continue;
      }
      auto op_desc = node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      string frame_name;
      if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
        // Bug fix: the format string used a bare "%" ("enter[%]") while a
        // C-string argument was supplied; "%s" prints the node name properly.
        GELOGE(FAILED, "Get attr frame_name of enter[%s] failed.", node->GetName().c_str());
        return FAILED;
      }
      frame_enter[frame_name].emplace_back(node);
    }
  }

  return SUCCESS;
}

// Seeds the extraction worklist: for every Enter node of every frame, queue
// those direct data consumers whose inputs are all loop-invariant and which
// do not feed a Merge/RefMerge node.
Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map<string, vector<NodePtr>> &frame_enter,
                                                 queue<NodePtr> &nodes_to_extract) {
  for (const auto &frame_and_enters : frame_enter) {
    for (const auto &enter_node : frame_and_enters.second) {
      // Examine each node fed directly by this Enter.
      for (const auto &consumer : enter_node->GetOutDataNodes()) {
        GE_CHECK_NOTNULL(consumer);
        if (AllInDataNodesUnchangeAndNoMergeOut(consumer)) {
          nodes_to_extract.push(consumer);
        }
      }
    }
  }

  return SUCCESS;
}

// A node qualifies for extraction iff none of its data consumers is a
// Merge/RefMerge, and either it has exactly one data input or every one of
// its data producers is of an "unchanged" type.
bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) {
  // Reject when any data consumer is missing or is a Merge/RefMerge.
  for (const auto &consumer : node->GetOutDataNodes()) {
    if ((consumer == nullptr) || (consumer->GetType() == MERGE) || (consumer->GetType() == REFMERGE)) {
      return false;
    }
  }

  auto producers = node->GetInDataNodes();
  // A single-input node is always considered extractable at this point.
  if (producers.size() == kOneInDataNode) {
    return true;
  }

  // Otherwise every producer must be of an unchanged (loop-invariant) type.
  for (const auto &producer : producers) {
    if ((producer == nullptr) || (unchange_types.count(producer->GetType()) == kFindNoMatch)) {
      return false;
    }
  }

  return true;
}

// For every data input of `node` that comes from an Enter-type node, bypass
// the Enter: unlink Enter->node and link the Enter's own producers directly
// to `node`. The last bypassed Enter's OpDesc is returned via `enter_desc`
// so the caller can clone equivalent Enters below the node.
// NOTE: `enter_desc` stays nullptr when no input comes from an Enter.
Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) {
  auto in_data_anchors = node->GetAllInDataAnchors();
  for (auto &in_data_anchor : in_data_anchors) {
    // NOTE(review): an unconnected in-anchor would make the peer null here
    // and abort via GE_CHECK_NOTNULL — presumably callers guarantee all
    // inputs are connected; confirm against call sites.
    auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(peer_out_data_anchor);
    auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode();
    if (IsEnterType(peer_in_data_node->GetType())) {
      // Cut the Enter -> node edge ...
      GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor))
      GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str());
      // ... and wire each producer of the Enter straight to `node` instead.
      auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors();
      for (auto &enter_in_data_anchor : enter_in_data_anchors) {
        auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor();
        GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter);
        // Skip producers that are already linked to this input.
        if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) {
          continue;
        }
        GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor))
        GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(),
               node->GetName().c_str());
      }
      // Remember this Enter's desc as the template for InsertEnterAfterNode.
      enter_desc = peer_in_data_node->GetOpDesc();
      GE_CHECK_NOTNULL(enter_desc);
    }
  }

  return SUCCESS;
}

// Re-creates Enter nodes below `node` after its own Enter inputs were
// bypassed: one new Enter (cloned from copy_desc) is inserted on every out
// data anchor of `node`, between the anchor and all of its former peers.
// @param node       node that was hoisted out of the frame.
// @param copy_desc  OpDesc of a bypassed Enter used as the clone template;
//                   nullptr means there is nothing to do.
// @param out_nodes  out: the set of `node`'s former direct data consumers.
Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr &copy_desc, set<NodePtr> &out_nodes) {
  if (copy_desc == nullptr) {
    return SUCCESS;
  }
  // Snapshot out_anchor -> [(peer in anchor, peer node)] before any relink.
  map<OutDataAnchorPtr, vector<std::pair<InDataAnchorPtr, NodePtr>>> outanchors_inanchors_nodes;
  auto out_data_anchors = node->GetAllOutDataAnchors();
  for (auto &out_data_anchor : out_data_anchors) {
    auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors();
    for (auto peer_in_data_anchor : peer_in_data_anchors) {
      GE_CHECK_NOTNULL(peer_in_data_anchor);
      auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
      out_nodes.emplace(peer_in_data_node);
      outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node));
    }
  }

  int32_t i = 0;
  auto node_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(node_desc);
  // Insert one enter node after node's per out data anchor
  for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) {
    string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++);
    GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str());
    auto enter_desc = AttrUtils::CopyOpDesc(copy_desc);
    // Bug fix: CopyOpDesc may fail and return nullptr; guard before use.
    GE_CHECK_NOTNULL(enter_desc);
    enter_desc->SetName(name);
    // The Enter's input/output descs mirror the producing anchor's desc.
    GE_CHK_STATUS_RET(
        enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
    GE_CHK_STATUS_RET(
        enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
    auto enter_node = graph_->AddNode(enter_desc);
    GE_CHECK_NOTNULL(enter_node);
    GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex)))
    GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex));
    // Reroute every former peer: node -> enter -> peer.
    for (auto &inanchor_node : outanchor_inanchors_nodes.second) {
      GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first))
      GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first))
      GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(),
             inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(),
             inanchor_node.second->GetName().c_str());
    }
  }

  return SUCCESS;
}

// Move node's in control edges to out data nodes
// After `node` is hoisted out of the frame, its incoming control edges must
// not keep it tied to the cycle: detach each of them and attach the same
// control predecessors to the consumers (`out_nodes`) left inside the frame.
Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set<NodePtr> &out_nodes) {
  auto in_ctrl_anchor = node->GetInControlAnchor();
  GE_CHECK_NOTNULL(in_ctrl_anchor);
  auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors();
  for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) {
    // Remove predecessor -> node ...
    GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor))
    GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
           node->GetName().c_str());
    // ... and add predecessor -> each consumer instead.
    for (auto &out_node : out_nodes) {
      auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor();
      GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node);
      // Avoid duplicating an already-existing control edge.
      if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) {
        GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node))
        GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
               out_node->GetName().c_str());
      }
    }
  }

  return SUCCESS;
}

// Removes every Enter-type node that ended up with no outputs at all after
// the extraction pass: isolate the node first, then detach it from the graph.
Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() {
  for (auto &candidate : graph_->GetAllNodes()) {
    GE_CHECK_NOTNULL(candidate);
    if (!IsEnterType(candidate->GetType())) {
      continue;
    }
    if (!candidate->GetOutAllNodes().empty()) {
      continue;
    }
    GELOGD("Delete enter node: %s which has no output.", candidate->GetName().c_str());
    GE_CHK_STATUS_RET(GraphUtils::IsolateNode(candidate, {}))
    GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, candidate))
  }

  return SUCCESS;
}

void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) {
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(),
@@ -585,9 +297,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
LabelStatusForGetNextSink(data); LabelStatusForGetNextSink(data);
} }
} }

map<string, vector<NodePtr>> frame_enters;
InitStatus(frame_enters);
bool changed = true; bool changed = true;
// If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch
while (changed) { while (changed) {
@@ -597,13 +306,12 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
if (iter != origin_nodes_status_.end()) { if (iter != origin_nodes_status_.end()) {
continue; continue;
} }
for (auto &in_node : node->GetInDataNodes()) {
if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) {
if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
origin_nodes_status_[node.get()] == kNodeInBatchBranch;
ResetEnterStatus(frame_enters, node);
changed = true;
}
for (auto &in_node : node->GetInAllNodes()) {
bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() &&
origin_nodes_status_[in_node.get()] == kNodeInBatchBranch;
if (is_in_batch) {
origin_nodes_status_[node.get()] = kNodeInBatchBranch;
changed = true;
break; break;
} }
} }
@@ -612,45 +320,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
return SUCCESS; return SUCCESS;
} }


// Builds the initial labeling state: groups every Enter node by its frame
// name, and marks each data node with a non-static shape as belonging to the
// batch branch.
void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) {
  for (const auto &candidate : origin_all_nodes_) {
    if (!IsEnterType(candidate->GetType())) {
      continue;
    }
    auto desc = candidate->GetOpDesc();
    if (desc == nullptr) {
      continue;
    }
    // Enters without the frame attribute are silently skipped.
    string frame_name;
    if (AttrUtils::GetStr(desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
      frame_enters[frame_name].emplace_back(candidate);
    }
  }

  // Any data node with an unknown (non-positive) dim starts in the branch.
  for (const auto &data_node : origin_data_nodes_) {
    auto dims = NodeUtils::GetOutputDesc(*data_node, kDataOutIndex).GetShape().GetDims();
    if (!IsAllDimsPositive(dims)) {
      origin_nodes_status_[data_node.get()] = kNodeInBatchBranch;
    }
  }
}

// When an Enter node falls into the batch branch, every Enter of the same
// frame must be labeled in-branch too, so the whole frame moves together.
void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) {
  if (!IsEnterType(node->GetType())) {
    return;
  }

  for (const auto &frame_and_enters : frame_enters) {
    const auto &enters = frame_and_enters.second;
    if (std::find(enters.begin(), enters.end(), node) == enters.end()) {
      continue;
    }
    // Found the node's frame: flag all of its Enters and stop searching.
    for (const auto &peer_enter : enters) {
      origin_nodes_status_[peer_enter.get()] = kNodeInBatchBranch;
    }
    break;
  }
}

Status MultiBatchGraphCopyer::LabelStatus() { Status MultiBatchGraphCopyer::LabelStatus() {
if (LabelInBatchBranchStatus() != SUCCESS) { if (LabelInBatchBranchStatus() != SUCCESS) {
GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); GELOGE(PARAM_INVALID, "Failed to label no in batch branch");
@@ -1691,6 +1360,52 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
return SUCCESS; return SUCCESS;
} }


// Inserts an Identity node between every SwitchN output and each of its
// consumers, except consumers that are Merge nodes this pass itself inserted
// (marked ATTR_INSERT_BY_MBATCH). The Identity mirrors the SwitchN output's
// tensor desc and drives the consumer through a control edge.
Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
  for (auto &node : graph_->GetAllNodes()) {
    if (node->GetType() != SWITCHN) {
      continue;
    }
    auto switchn_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(switchn_desc);
    size_t i = 0;  // running suffix that keeps identity names unique
    for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
      for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        auto out_node = in_data_anchor->GetOwnerNode();
        auto op_desc = out_node->GetOpDesc();
        GE_CHECK_NOTNULL(op_desc);
        if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
          GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str());
          continue;
        }

        auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i++), IDENTITY);
        GE_CHECK_NOTNULL(identity_desc);

        // Propagate the consumer's batch label so the new Identity is
        // assigned to the same batch branch.
        string batch_label;
        if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
          if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
            GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str());
            return FAILED;
          }
        }

        // Bug fix: use the producing anchor's own index for the output desc.
        // The old code indexed with the running counter `i`, which advanced
        // once per inserted identity (across all peers) and stalled on the
        // `continue` above, so outputs with several consumers copied the
        // wrong tensor desc.
        auto data_desc = switchn_desc->GetOutputDesc(out_data_anchor->GetIdx());
        GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc));
        GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc));

        auto identity_node = graph_->AddNode(identity_desc);
        GE_CHECK_NOTNULL(identity_node);
        GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0)));
        GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor());
        GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor()));
      }
    }
  }

  return SUCCESS;
}

Status ProcessMultiBatch(ComputeGraphPtr &graph) { Status ProcessMultiBatch(ComputeGraphPtr &graph) {
if (GetLocalOmgContext().dynamic_node_type.empty()) { if (GetLocalOmgContext().dynamic_node_type.empty()) {
const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN");
@@ -1700,6 +1415,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) {
return pass_manager.Run(graph); return pass_manager.Run(graph);
} }
} }

if (!GetLocalOmgContext().need_multi_batch) { if (!GetLocalOmgContext().need_multi_batch) {
GELOGI("No need to process_multi for no_train graph."); GELOGI("No need to process_multi for no_train graph.");
return SUCCESS; return SUCCESS;


+ 1
- 15
ge/graph/preprocess/multi_batch_copy_graph.h View File

@@ -18,7 +18,6 @@
#include <map> #include <map>
#include <queue> #include <queue>
#include <vector> #include <vector>
#include <set>


#include "external/ge/ge_api_error_codes.h" #include "external/ge/ge_api_error_codes.h"


@@ -65,26 +64,12 @@ class MultiBatchGraphCopyer {
private: private:
Status Init(); Status Init();
Status CheckArguments(); Status CheckArguments();
Status RelinkConstCtrlEdge();

Status ExtractUnchangedStructureOutofCycle();
Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter);
Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter,
std::queue<NodePtr> &nodes_to_extract);
bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node);
Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc);
Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes);
Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes);
Status DeleteEnterWithoutDataOut();


// label status for origin_all_nodes_ // label status for origin_all_nodes_
Status LabelStatus(); Status LabelStatus();
Status LabelInBatchBranchStatus(); Status LabelInBatchBranchStatus();
void LabelStatusForData(const NodePtr &data); void LabelStatusForData(const NodePtr &data);
void LabelStatusForGetNextSink(const NodePtr &data); void LabelStatusForGetNextSink(const NodePtr &data);
void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters);
void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node);

// add nodes functions // add nodes functions
Status CreateNewNodes(); Status CreateNewNodes();


@@ -96,6 +81,7 @@ class MultiBatchGraphCopyer {
Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn); std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);


Status InsertIdentityAfterSwitchN();
Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);




+ 0
- 2
ge/hybrid/executor/hybrid_execution_context.h View File

@@ -22,7 +22,6 @@
#include "common/blocking_queue.h" #include "common/blocking_queue.h"
#include "common/properties_manager.h" #include "common/properties_manager.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "graph/ge_local_context.h"
#include "hybrid/common/npu_memory_allocator.h" #include "hybrid/common/npu_memory_allocator.h"
#include "hybrid/common/tensor_value.h" #include "hybrid/common/tensor_value.h"
#include "hybrid/executor/hybrid_profiler.h" #include "hybrid/executor/hybrid_profiler.h"
@@ -39,7 +38,6 @@ struct GraphExecutionContext {


uint64_t session_id = 0; uint64_t session_id = 0;
const HybridModel *model = nullptr; const HybridModel *model = nullptr;
const GEThreadLocalContext *ge_context = nullptr;
rtStream_t stream = nullptr; rtStream_t stream = nullptr;
rtContext_t rt_context = nullptr; rtContext_t rt_context = nullptr;
rtContext_t rt_gen_context = nullptr; rtContext_t rt_gen_context = nullptr;


+ 0
- 1
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -95,7 +95,6 @@ Status HybridModelExecutor::InitExecutionContext() {
context_.stream = stream_; context_.stream = stream_;
context_.model = model_; context_.model = model_;
context_.session_id = ::ge::GetContext().SessionId(); context_.session_id = ::ge::GetContext().SessionId();
context_.ge_context = &GetThreadLocalContext();
GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id);
context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_);
GE_CHECK_NOTNULL(context_.allocator); GE_CHECK_NOTNULL(context_.allocator);


+ 22
- 34
ge/hybrid/executor/node_state.cc View File

@@ -18,7 +18,6 @@
#include <chrono> #include <chrono>
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
#include "graph/utils/tensor_utils.h"
#include "hybrid_execution_context.h" #include "hybrid_execution_context.h"
#include "subgraph_context.h" #include "subgraph_context.h"


@@ -36,31 +35,29 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
this->num_pending_shapes_); this->num_pending_shapes_);
} }


Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) {
Status ShapeInferenceState::UpdateInputShape(int idx,
const GeShape &ori_shape,
const GeShape &shape) {
if (node_item.IsInputShapeStatic(idx)) { if (node_item.IsInputShapeStatic(idx)) {
GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]",
node_item.NodeName().c_str(), node_item.NodeName().c_str(),
idx, idx,
node_item.MutableInputDesc(idx)->GetShape().ToString().c_str(), node_item.MutableInputDesc(idx)->GetShape().ToString().c_str(),
target.GetShape().ToString().c_str());
shape.ToString().c_str());
return SUCCESS; return SUCCESS;
} }


int64_t tensor_size = -1;
(void) TensorUtils::GetSize(target, tensor_size);
GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld",
GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s]",
node_item.NodeName().c_str(), node_item.NodeName().c_str(),
idx, idx,
target.GetShape().ToString().c_str(),
target.GetOriginShape().ToString().c_str(),
tensor_size);
shape.ToString().c_str(),
ori_shape.ToString().c_str());


std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
auto tensor_desc = node_item.MutableInputDesc(idx); auto tensor_desc = node_item.MutableInputDesc(idx);
GE_CHECK_NOTNULL(tensor_desc); GE_CHECK_NOTNULL(tensor_desc);
tensor_desc->SetShape(target.GetShape());
tensor_desc->SetOriginShape(target.GetOriginShape());
(void) TensorUtils::SetSize(*tensor_desc, tensor_size);
tensor_desc->SetShape(shape);
tensor_desc->SetOriginShape(ori_shape);
if (--num_pending_shapes_ == 0) { if (--num_pending_shapes_ == 0) {
ready_cv_.notify_all(); ready_cv_.notify_all();
} }
@@ -113,24 +110,24 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex
for (auto &p : shape_futures) { for (auto &p : shape_futures) {
auto idx = p.first; auto idx = p.first;
auto &future = p.second; auto &future = p.second;
GeShape shape;
GeShape ori_shape;
RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx);
auto src_tensor_desc = future.GetTensorDesc();
GE_CHECK_NOTNULL(src_tensor_desc);
GE_CHK_STATUS_RET(future.Get(ori_shape, shape),
"[%s] Get shape failed. index = %u",
node_item.NodeName().c_str(),
idx);
RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx);


auto input_desc = node_item.MutableInputDesc(idx);
GE_CHECK_NOTNULL(input_desc);
int64_t tensor_size = -1;
(void) TensorUtils::GetSize(*src_tensor_desc, tensor_size);
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu",
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]",
node_item.NodeName().c_str(), node_item.NodeName().c_str(),
idx, idx,
src_tensor_desc->GetShape().ToString().c_str(),
src_tensor_desc->GetOriginShape().ToString().c_str(),
tensor_size);
input_desc->SetShape(src_tensor_desc->GetShape());
input_desc->SetOriginShape(src_tensor_desc->GetOriginShape());
(void) TensorUtils::SetSize(*input_desc, tensor_size);
shape.ToString().c_str(),
ori_shape.ToString().c_str());
auto input_desc = node_item.MutableInputDesc(idx);
GE_CHECK_NOTNULL(input_desc);
input_desc->SetShape(std::move(shape));
input_desc->SetOriginShape(ori_shape);
} }


return SUCCESS; return SUCCESS;
@@ -193,14 +190,5 @@ Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) {
GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str());
return SUCCESS; return SUCCESS;
} }

// Blocks until the source node has finished executing, then returns its
// src_index_-th mutable output desc; returns nullptr if the wait was
// cancelled.
GeTensorDescPtr ShapeFuture::GetTensorDesc() {
  GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str());
  const bool node_ready = subgraph_context_->Await(src_node_);
  if (!node_ready) {
    GELOGE(INTERNAL_ERROR, "cancelled");
    return nullptr;
  }
  auto src_op_desc = src_node_->GetOpDesc();
  return src_op_desc->MutableOutputDesc(src_index_);
}
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

+ 1
- 2
ge/hybrid/executor/node_state.h View File

@@ -35,7 +35,6 @@ class ShapeFuture {
ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context);
~ShapeFuture() = default; ~ShapeFuture() = default;
Status Get(GeShape &ori_shape, GeShape &shape); Status Get(GeShape &ori_shape, GeShape &shape);
GeTensorDescPtr GetTensorDesc();


private: private:
NodePtr src_node_; NodePtr src_node_;
@@ -46,7 +45,7 @@ class ShapeFuture {
struct ShapeInferenceState { struct ShapeInferenceState {
explicit ShapeInferenceState(const NodeItem &node_item); explicit ShapeInferenceState(const NodeItem &node_item);


Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc);
Status UpdateInputShape(int idx, const GeShape &ori_shape, const GeShape &shape);


void UpdateInputShapeFuture(int idx, ShapeFuture &&future); void UpdateInputShapeFuture(int idx, ShapeFuture &&future);




+ 8
- 1
ge/hybrid/executor/subgraph_executor.cc View File

@@ -96,7 +96,7 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue
GE_CHECK_NOTNULL(tensor_desc); GE_CHECK_NOTNULL(tensor_desc);
auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); auto node_state = subgraph_context_->GetOrCreateNodeState(input_node);
GE_CHECK_NOTNULL(node_state); GE_CHECK_NOTNULL(node_state);
node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc);
node_state->GetShapeInferenceState().UpdateInputShape(0, tensor_desc->GetOriginShape(), tensor_desc->GetShape());
} }
} }


@@ -268,6 +268,13 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
} else { } else {
node_state.SetKernelTask(node_item.kernel_task); node_state.SetKernelTask(node_item.kernel_task);
} }

GELOGD("[%s] Start to invoke CalcOpRunningParam.", node_item.NodeName().c_str());
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start");
GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node),
"[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str());
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] End");
GELOGD("[%s] Done invoking CalcOpRunningParam successfully.", node_item.NodeName().c_str());
return SUCCESS; return SUCCESS;
} }




+ 12
- 16
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -20,9 +20,12 @@
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "hybrid/node_executor/node_executor.h" #include "hybrid/node_executor/node_executor.h"
#include "hybrid/executor//worker//shape_inference_engine.h"
#include "common/dump/dump_manager.h"
#include "common/dump/dump_op.h" #include "common/dump/dump_op.h"
#include "common/types.h"
#include "common/ge_types.h"
#include "common/profiling/profiling_manager.h" #include "common/profiling/profiling_manager.h"
#include "runtime/base.h"


namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
@@ -151,19 +154,18 @@ Status NodeDoneCallback::GetTaskDescInfo(const NodePtr node, const HybridModel *
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(model); GE_CHECK_NOTNULL(model);


// only report aicpu and aicore node
bool is_profiling_report = context_->GetNodeItem().is_profiling_report;
if (!is_profiling_report) {
GELOGD("Node[%s] is not aicore or aicpu, and no need to report data.", node->GetName().c_str());
return SUCCESS;
}

GELOGD("GetTaskDescInfo of node [%s] start.", node->GetName().c_str()); GELOGD("GetTaskDescInfo of node [%s] start.", node->GetName().c_str());
auto op_desc = node->GetOpDesc(); auto op_desc = node->GetOpDesc();
std::string op_name = op_desc->GetName(); std::string op_name = op_desc->GetName();
std::string dynamic_model_name = model->GetModelName(); std::string dynamic_model_name = model->GetModelName();
uint32_t task_id = context_->GetTaskId();
uint32_t stream_id = context_->GetStreamId();

uint32_t task_id = 0;
uint32_t stream_id = 0;
if (rtGetTaskIdAndStreamID(&task_id, &stream_id) != RT_ERROR_NONE) {
GELOGE(PARAM_INVALID, "Get task_id and stream_id failed.");
return PARAM_INVALID;
}

TaskDescInfo tmp_task_desc_info; TaskDescInfo tmp_task_desc_info;
tmp_task_desc_info.model_name = dynamic_model_name; tmp_task_desc_info.model_name = dynamic_model_name;
tmp_task_desc_info.op_name = op_name; tmp_task_desc_info.op_name = op_name;
@@ -175,8 +177,6 @@ Status NodeDoneCallback::GetTaskDescInfo(const NodePtr node, const HybridModel *
} }
tmp_task_desc_info.task_id = task_id; tmp_task_desc_info.task_id = task_id;
tmp_task_desc_info.stream_id = stream_id; tmp_task_desc_info.stream_id = stream_id;
tmp_task_desc_info.shape_type = "dynamic";
tmp_task_desc_info.cur_iter_num = graph_context_->iteration;
GELOGD("GetTaskDescInfo of node [%s] end, task_id[%u], stream_id[%u]", GELOGD("GetTaskDescInfo of node [%s] end, task_id[%u], stream_id[%u]",
node->GetName().c_str(), task_id, stream_id); node->GetName().c_str(), task_id, stream_id);
task_desc_info.emplace_back(tmp_task_desc_info); task_desc_info.emplace_back(tmp_task_desc_info);
@@ -348,10 +348,6 @@ Status NodeDoneCallback::OnNodeDone() {
} }


GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item));
if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) {
// update output tensor sizes
GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item));
}
// PropagateOutputs for type == DEPEND_COMPUTE // PropagateOutputs for type == DEPEND_COMPUTE
if (node_item.shape_inference_type == DEPEND_COMPUTE) { if (node_item.shape_inference_type == DEPEND_COMPUTE) {
if (graph_context_->trace_enabled) { if (graph_context_->trace_enabled) {


+ 18
- 103
ge/hybrid/executor/worker/shape_inference_engine.cc View File

@@ -17,15 +17,9 @@
#include "hybrid/executor/worker/shape_inference_engine.h" #include "hybrid/executor/worker/shape_inference_engine.h"
#include "graph/shape_refiner.h" #include "graph/shape_refiner.h"
#include "graph/utils/node_utils.h" #include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "common/math/math_util.h"
#include "hybrid/node_executor/node_executor.h" #include "hybrid/node_executor/node_executor.h"


namespace ge { namespace ge {
namespace {
const int kAlignment = 32;
}
namespace hybrid { namespace hybrid {
ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context)
: execution_context_(execution_context), : execution_context_(execution_context),
@@ -46,9 +40,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
} }


if (node_item.fused_subgraph != nullptr) { if (node_item.fused_subgraph != nullptr) {
GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph));
GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item));
return SUCCESS;
return InferShapeForSubgraph(node_item, *node_item.fused_subgraph);
} }


// Skip shape inference for node of type DEPEND_COMPUTE // Skip shape inference for node of type DEPEND_COMPUTE
@@ -71,15 +63,21 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start");
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true),
"Invoke InferShapeAndType failed.");
"Invoke InferShapeAndType failed.");
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End");
} }
// Check again to make sure shape is valid after shape inference
if (node_item.shape_inference_type != DEPEND_SHAPE_RANGE) {
bool is_unknown_shape = false;
GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node_item.node, is_unknown_shape),
"Failed to get shape status. node = %s",
node_item.NodeName().c_str());


// update output tensor sizes after shape inference
// error if shape is still unknown and not of type DEPEND_SHAPE_RANGE
RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start");
GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item, node_item.shape_inference_type == DEPEND_SHAPE_RANGE));
RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End");
GE_CHK_BOOL_RET_STATUS(!is_unknown_shape,
INTERNAL_ERROR,
"[%s] Shape is still unknown after shape inference.",
node_item.NodeName().c_str());
}


GELOGD("[%s] [HybridTrace] After shape inference. Node = %s", GELOGD("[%s] [HybridTrace] After shape inference. Node = %s",
node_item.NodeName().c_str(), node_item.NodeName().c_str(),
@@ -129,6 +127,8 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) {
// propagate each output // propagate each output
for (int i = 0; i < node_item.num_outputs; ++i) { for (int i = 0; i < node_item.num_outputs; ++i) {
auto output_desc = node_item.op_desc->MutableOutputDesc(i); auto output_desc = node_item.op_desc->MutableOutputDesc(i);
const auto &shape = output_desc->MutableShape();
const auto &ori_shape = output_desc->GetOriginShape();
auto &output_nodes = node_item.outputs[i]; auto &output_nodes = node_item.outputs[i];


// propagate output to all sub-inputs // propagate output to all sub-inputs
@@ -149,7 +149,9 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) {
infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first,
std::move(future)); std::move(future));
} else { } else {
GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc));
GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first,
ori_shape,
shape));
} }
} }
} }
@@ -228,92 +230,5 @@ Status ShapeInferenceEngine::UpdatePeerNodeShape(const Node &node) {
} }
return SUCCESS; return SUCCESS;
} }

Status ShapeInferenceEngine::CanonicalizeShape(GeTensorDesc &tensor_desc,
std::vector<int64_t> &shape,
bool fallback_with_range) {
const auto &tensor_shape = tensor_desc.MutableShape();
if (tensor_shape.IsUnknownShape()) {
if (!fallback_with_range) {
GELOGE(INTERNAL_ERROR, "Output shape is still unknown after shape inference. shape = [%s]",
tensor_shape.ToString().c_str());
return INTERNAL_ERROR;
}

GELOGD("Calc output size by range");
std::vector<std::pair<int64_t, int64_t>> shape_range;
GE_CHK_GRAPH_STATUS_RET(tensor_desc.GetShapeRange(shape_range), "Failed to get shape range");
if (shape_range.size() != shape.size()) {
GELOGE(INTERNAL_ERROR, "Number of shape ranges (%zu) mismatches that of dims (%zu)",
shape_range.size(),
shape.size());
return INTERNAL_ERROR;
}

for (size_t dim_index = 0; dim_index < shape.size(); ++dim_index) {
if (shape[dim_index] == ge::UNKNOWN_DIM) {
shape[dim_index] = shape_range[dim_index].second;
}
}

GELOGD("After canonicalization, shape = [%s], before = [%s]",
GeShape(shape).ToString().c_str(),
tensor_shape.ToString().c_str());
}

return SUCCESS;
}

Status ShapeInferenceEngine::CalcTensorSize(DataType data_type,
const std::vector<int64_t> &shape,
int64_t &tensor_size) {
GELOGD("To calc tensor size by shape = [%s]", GeShape(shape).ToString().c_str());
uint32_t type_size;
if (!TypeUtils::GetDataTypeLength(data_type, type_size)) {
GELOGE(INTERNAL_ERROR, "Failed to get data type size");
return INTERNAL_ERROR;
}

tensor_size = type_size;
for (const auto &dim : shape) {
GE_CHECK_GE(dim, 0);
GE_CHK_STATUS_RET(Int64MulCheckOverflow(tensor_size, dim),
"Shape size overflow, shape = [%s]",
GeShape(shape).ToString().c_str());
tensor_size *= dim;
}

GE_CHK_STATUS_RET(CheckInt64AddOverflow(tensor_size, kAlignment - 1),
"Tensor size is too large: %ld, shape = [%s]",
tensor_size,
GeShape(shape).ToString().c_str());
tensor_size = (tensor_size + kAlignment - 1) / kAlignment * kAlignment;
return SUCCESS;
}

Status ShapeInferenceEngine::CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range) {
auto op_desc = node_item.GetOpDesc();
for (size_t output_index = 0; output_index < op_desc->GetOutputsSize(); ++output_index) {
auto tensor_desc = op_desc->MutableOutputDesc(output_index);
GE_CHECK_NOTNULL(tensor_desc);
const auto &shape = tensor_desc->MutableShape();
// modify on copy
auto dims = shape.GetDims();
GE_CHK_STATUS_RET(CanonicalizeShape(*tensor_desc, dims, fallback_with_range),
"[%s] Failed to canonicalize shape for output %zu",
node_item.NodeName().c_str(),
output_index);

int64_t tensor_size;
GE_CHK_STATUS_RET(CalcTensorSize(tensor_desc->GetDataType(), dims, tensor_size),
"[%s] Failed to calc tensor size for output %zu",
node_item.NodeName().c_str(),
output_index);
GELOGD("[%s] Tensor size of output %zu = %ld", node_item.NodeName().c_str(), output_index, tensor_size);
(void) TensorUtils::SetSize(*tensor_desc, tensor_size);
}

return SUCCESS;
}
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

+ 0
- 4
ge/hybrid/executor/worker/shape_inference_engine.h View File

@@ -34,11 +34,7 @@ class ShapeInferenceEngine {


Status PropagateOutputShapes(const NodeItem &node_item); Status PropagateOutputShapes(const NodeItem &node_item);


static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false);

private: private:
static Status CanonicalizeShape(GeTensorDesc &tensor_desc, std::vector<int64_t> &shape, bool fallback_with_range);
static Status CalcTensorSize(DataType data_type, const std::vector<int64_t> &shape, int64_t &tensor_size);
static Status UpdatePeerNodeShape(const Node &node); static Status UpdatePeerNodeShape(const Node &node);
Status AwaitDependentNodes(NodeState &node_state); Status AwaitDependentNodes(NodeState &node_state);




+ 0
- 3
ge/hybrid/executor/worker/task_compile_engine.cc View File

@@ -26,9 +26,6 @@ Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start");
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context));


if (context->ge_context != nullptr) {
GetThreadLocalContext() = *context->ge_context;
}
shared_ptr<NodeTask> kernel_task; shared_ptr<NodeTask> kernel_task;
auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task);
RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End"); RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End");


+ 1
- 4
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -226,10 +226,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
new_node->node_id = node_index; new_node->node_id = node_index;
new_node->op_desc->SetId(node_index); new_node->op_desc->SetId(node_index);
node_index += 1; node_index += 1;
NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) ||
(executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) ||
(executor_type == NodeExecutorManager::ExecutorType::AICPU_CUSTOM);

*node_item = new_node.get(); *node_item = new_node.get();
node_items[node] = std::move(new_node); node_items[node] = std::move(new_node);
return SUCCESS; return SUCCESS;


+ 34
- 57
ge/hybrid/model/node_item.cc View File

@@ -22,7 +22,6 @@
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/node_utils.h" #include "graph/utils/node_utils.h"
#include "hybrid/node_executor/node_executor.h" #include "hybrid/node_executor/node_executor.h"
#include "hybrid/executor/worker/shape_inference_engine.h"


namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
@@ -48,7 +47,7 @@ Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgr
GE_CHECK_NOTNULL(dst_op_desc); GE_CHECK_NOTNULL(dst_op_desc);
auto in_idx = node_and_anchor.second->GetIdx(); auto in_idx = node_and_anchor.second->GetIdx();
auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx); auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx);
fused_subgraph.input_mapping[static_cast<int>(parent_index)].emplace_back(tensor_desc);
fused_subgraph.input_mapping[parent_index].emplace_back(tensor_desc);
GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx); GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx);
} }


@@ -65,7 +64,7 @@ Status ParseOutputMapping(const OpDescPtr &op_desc, FusedSubgraph &fused_subgrap
return FAILED; return FAILED;
} }


fused_subgraph.output_mapping.emplace(static_cast<int>(parent_index), op_desc);
fused_subgraph.output_mapping.emplace(parent_index, op_desc);
return SUCCESS; return SUCCESS;
} }


@@ -127,7 +126,12 @@ Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_ite
return SUCCESS; return SUCCESS;
} }


void NodeItem::ResolveOptionalInputs() {
Status NodeItem::Init() {
GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX);
GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX);
num_inputs = static_cast<int>(op_desc->GetInputsSize());
num_outputs = static_cast<int>(op_desc->GetOutputsSize());

if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) { if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) {
has_optional_inputs = true; has_optional_inputs = true;
for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) {
@@ -139,18 +143,7 @@ void NodeItem::ResolveOptionalInputs() {
} }
} }
} }
}


Status NodeItem::InitInputsAndOutputs() {
GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX);
GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX);
num_inputs = static_cast<int>(op_desc->GetInputsSize());
num_outputs = static_cast<int>(op_desc->GetOutputsSize());
ResolveOptionalInputs();
return SUCCESS;
}

Status NodeItem::ResolveDynamicState() {
(void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic); (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic);
GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic); GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic);
if (!is_dynamic) { if (!is_dynamic) {
@@ -158,54 +151,38 @@ Status NodeItem::ResolveDynamicState() {
"[%s] Failed to get shape status.", "[%s] Failed to get shape status.",
node->GetName().c_str()); node->GetName().c_str());
} }
return SUCCESS;
}


Status NodeItem::ResolveStaticInputsAndOutputs() {
for (int i = 0; i < num_inputs; ++i) {
const auto &input_desc = MutableInputDesc(i);
GE_CHECK_NOTNULL(input_desc);
if (input_desc->MutableShape().IsUnknownShape()) {
is_input_shape_static_.push_back(false);
} else {
num_static_input_shapes++;
is_input_shape_static_.push_back(true);
GELOGD("[%s] The shape of input[%d] is static. shape = [%s]",
NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str());
if (is_dynamic) {
for (int i = 0; i < num_inputs; ++i) {
const auto &input_desc = MutableInputDesc(i);
GE_CHECK_NOTNULL(input_desc);
if (input_desc->MutableShape().IsUnknownShape()) {
is_input_shape_static_.push_back(false);
} else {
num_static_input_shapes++;
is_input_shape_static_.push_back(true);
GELOGD("[%s] The shape of input[%d] is static. shape = [%s]",
NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str());
}
} }
}


for (int i = 0; i < num_outputs; ++i) {
const auto &output_desc = op_desc->MutableOutputDesc(i);
GE_CHECK_NOTNULL(output_desc);
if (output_desc->MutableShape().IsUnknownShape()) {
is_output_shape_static = false;
break;
for (int i = 0; i < num_outputs; ++i) {
const auto &output_desc = op_desc->MutableOutputDesc(i);
GE_CHECK_NOTNULL(output_desc);
if (output_desc->MutableShape().IsUnknownShape()) {
is_output_shape_static = false;
break;
}
} }
}

if (is_output_shape_static) {
GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this));
}
return SUCCESS;
}


void NodeItem::ResolveUnknownShapeType() {
if (IsControlOp() || node_type == PARTITIONEDCALL) {
shape_inference_type = DEPEND_COMPUTE;
} else {
int32_t unknown_shape_type_val = 0;
(void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val);
shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
}
}
if (IsControlOp() || node_type == PARTITIONEDCALL) {
shape_inference_type = DEPEND_COMPUTE;
} else {
int32_t unknown_shape_type_val = 0;
(void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val);
shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
}


Status NodeItem::Init() {
GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs());
GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState());
if (is_dynamic) {
ResolveUnknownShapeType();
GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs());
GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str());
} }




+ 0
- 6
ge/hybrid/model/node_item.h View File

@@ -99,16 +99,10 @@ struct NodeItem {
std::map<int, int> reuse_inputs; std::map<int, int> reuse_inputs;
std::map<int, int> reuse_outputs; std::map<int, int> reuse_outputs;
int num_static_input_shapes = 0; int num_static_input_shapes = 0;
bool is_profiling_report = false;


private: private:
explicit NodeItem(NodePtr node); explicit NodeItem(NodePtr node);
Status Init(); Status Init();
Status InitInputsAndOutputs();
void ResolveOptionalInputs();
Status ResolveDynamicState();
Status ResolveStaticInputsAndOutputs();
void ResolveUnknownShapeType();


std::vector<bool> is_input_shape_static_; std::vector<bool> is_input_shape_static_;
std::vector<uint32_t> input_desc_indices_; std::vector<uint32_t> input_desc_indices_;


+ 0
- 10
ge/hybrid/node_executor/aicore/aicore_node_executor.cc View File

@@ -165,16 +165,6 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()>
} }
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start");
GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream())); GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream()));
uint32_t task_id = 0;
uint32_t stream_id = 0;
rtError_t rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(rt_ret, "Get task_id and stream_id failed.");
return rt_ret;
}
context.SetTaskId(task_id);
context.SetStreamId(stream_id);
GELOGD("AiCore node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id);
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End");
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End");
} }


+ 0
- 11
ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc View File

@@ -189,17 +189,6 @@ Status AicpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::function<void(


GE_CHK_STATUS_RET(LaunchTask(context)); GE_CHK_STATUS_RET(LaunchTask(context));


uint32_t task_id = 0;
uint32_t stream_id = 0;
rtError_t rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(rt_ret, "Get task_id and stream_id failed.");
return rt_ret;
}
context.SetTaskId(task_id);
context.SetStreamId(stream_id);
GELOGD("AiCpu node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id);

auto callback = [=, &context]() { auto callback = [=, &context]() {
GELOGD("Node[%s] callback start.", node_name_.c_str()); GELOGD("Node[%s] callback start.", node_name_.c_str());
RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[TaskCallback] Start"); RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[TaskCallback] Start");


+ 0
- 38
ge/hybrid/node_executor/task_context.cc View File

@@ -148,10 +148,6 @@ Status TaskContext::AllocateWorkspaces() {
} }


Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) const { Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) const {
if (callback_fun == nullptr) {
GELOGW("[%s] Callback is NULL", GetNodeName());
return SUCCESS;
}
auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); GELOGE(ret, "[%s] Failed to register callback", GetNodeName());
@@ -319,22 +315,6 @@ void TaskContext::SetStatus(Status status) {
} }
} }


uint32_t TaskContext::GetTaskId() const {
return task_id_;
}

void TaskContext::SetTaskId(uint32_t task_id) {
task_id_ = task_id;
}

uint32_t TaskContext::GetStreamId() const {
return stream_id_;
}

void TaskContext::SetStreamId(uint32_t stream_id) {
stream_id_ = stream_id;
}

Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) { Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) {
GE_CHECK_NOTNULL(buffer); GE_CHECK_NOTNULL(buffer);
if (ori_addr == nullptr) { if (ori_addr == nullptr) {
@@ -404,20 +384,6 @@ const char *TaskContext::GetNodeName() const {
return node_item_->NodeName().c_str(); return node_item_->NodeName().c_str();
} }


void TaskContext::ReleaseInputsAndOutputs() {
for (int i = 0; i < node_item_->num_inputs; ++i) {
auto tensor = inputs_start_ + i;
tensor->Destroy();
GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), i);
}

for (int i = 0; i < node_item_->num_outputs; ++i) {
auto tensor = outputs_start_ + i;
tensor->Destroy();
GELOGD("[%s] Tensor of output[%d] released", GetNodeName(), i);
}
}

void TaskContext::ReleaseInput(int index) { void TaskContext::ReleaseInput(int index) {
auto input_tensor = MutableInput(index); auto input_tensor = MutableInput(index);
if (input_tensor != nullptr) { if (input_tensor != nullptr) {
@@ -490,9 +456,5 @@ Status TaskContext::TryExecuteCallback(const function<void()> &callback_fun) con
const DumpProperties &TaskContext::GetDumpProperties() const { const DumpProperties &TaskContext::GetDumpProperties() const {
return execution_context_->dump_properties; return execution_context_->dump_properties;
} }

bool TaskContext::NeedCallback() {
return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0;
}
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

+ 0
- 10
ge/hybrid/node_executor/task_context.h View File

@@ -50,8 +50,6 @@ class TaskContext {
ConstGeTensorDescPtr GetOutputDesc(int index) const; ConstGeTensorDescPtr GetOutputDesc(int index) const;
GeTensorDescPtr MutableInputDesc(int index) const; GeTensorDescPtr MutableInputDesc(int index) const;
GeTensorDescPtr MutableOutputDesc(int index) const; GeTensorDescPtr MutableOutputDesc(int index) const;
void ReleaseInputsAndOutputs();
bool NeedCallback();
void ReleaseInput(int index); void ReleaseInput(int index);
const TensorValue *GetInput(int index) const; const TensorValue *GetInput(int index) const;
const TensorValue *GetOutput(int index) const; const TensorValue *GetOutput(int index) const;
@@ -96,12 +94,6 @@ class TaskContext {


void SetStatus(Status status); void SetStatus(Status status);


uint32_t GetTaskId() const;
void SetTaskId(uint32_t task_id);

uint32_t GetStreamId() const;
void SetStreamId(uint32_t stream_id);

bool IsForceInferShape() const; bool IsForceInferShape() const;
void SetForceInferShape(bool force_infer_shape); void SetForceInferShape(bool force_infer_shape);
void *handle_ = nullptr; void *handle_ = nullptr;
@@ -123,8 +115,6 @@ class TaskContext {
Status status_ = SUCCESS; Status status_ = SUCCESS;
std::vector<void *> workspaces_; std::vector<void *> workspaces_;
uint64_t iteration_ = 0; uint64_t iteration_ = 0;
uint32_t task_id_= 0;
uint32_t stream_id_ = 0;
}; };
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge


+ 3
- 4
ge/ir_build/atc_ir_common.cc View File

@@ -63,19 +63,18 @@ vector<string> SplitInputShape(const std::string &input_shape) {
} }
} // namespace } // namespace


Status CheckInputFormat(const string &input_format) {
Status CheckInputFormat(const std::string &input_format) {
if (input_format.empty()) { if (input_format.empty()) {
return ge::SUCCESS; return ge::SUCCESS;
} }
if (!ge::TypeUtils::IsFormatValid(input_format.c_str())) { if (!ge::TypeUtils::IsFormatValid(input_format.c_str())) {
ErrorManager::GetInstance().ATCReportErrMessage( ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, "input format is invalid!"});
GELOGE(ge::PARAM_INVALID, "input format [%s] is invalid!", input_format.c_str());
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, "input format not found"});
GELOGE(ge::PARAM_INVALID, "user input format [%s] is not found!", input_format.c_str());
return ge::PARAM_INVALID; return ge::PARAM_INVALID;
} }
return ge::SUCCESS; return ge::SUCCESS;
} }

bool CheckDynamicBatchSizeInputShapeValid(unordered_map<string, vector<int64_t>> shape_map, bool CheckDynamicBatchSizeInputShapeValid(unordered_map<string, vector<int64_t>> shape_map,
std::string &dynamic_batch_size) { std::string &dynamic_batch_size) {
int32_t size = 0; int32_t size = 0;


+ 1
- 1
ge/ir_build/atc_ir_common.h View File

@@ -75,7 +75,7 @@ Status CheckInsertOpConfParamValid(const std::string insert_op_conf);
Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory); Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory);
Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream);
Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode);
Status CheckInputFormat(const string &input_format);
Status CheckInputFormat(const std::string &input_format);
void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips); void PrintOptionMap(std::map<std::string, std::string> &options, std::string tips);
void EraseEndSemicolon(std::string &param); void EraseEndSemicolon(std::string &param);
} }


+ 1
- 1
ge/offline/main.cc View File

@@ -305,7 +305,7 @@ class GFlagUtils {
" --debug_dir Set the save path of operator compilation intermediate files.\n" " --debug_dir Set the save path of operator compilation intermediate files.\n"
"Default value: ./kernel_meta\n" "Default value: ./kernel_meta\n"
" --op_compiler_cache_dir Set the save path of operator compilation cache files.\n" " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n"
"Default value: $HOME/atc_data\n"
"Default value: $HOME/atc_data/kernel_cache\n"
" --op_compiler_cache_mode Set the operator compilation cache mode." " --op_compiler_cache_mode Set the operator compilation cache mode."
"Options are disable(default), enable and force(force to refresh the cache)"); "Options are disable(default), enable and force(force to refresh the cache)");




+ 0
- 2
ge/proto/op_mapping_info.proto View File

@@ -15,7 +15,6 @@ message Output {
int32 original_output_data_type = 7; int32 original_output_data_type = 7;
int32 original_output_format = 8; int32 original_output_format = 8;
uint64 size = 9; uint64 size = 9;
Shape origin_shape = 10;
} }


message Input { message Input {
@@ -24,7 +23,6 @@ message Input {
Shape shape = 3; Shape shape = 3;
uint64 address = 4; uint64 address = 4;
uint64 size = 5; uint64 size = 5;
Shape origin_shape = 6;
} }


enum BufferType { enum BufferType {


+ 4
- 8
ge/single_op/single_op.cc View File

@@ -32,16 +32,14 @@ namespace ge {
namespace { namespace {
const size_t kDataMemAlignSize = 32; const size_t kDataMemAlignSize = 32;
const size_t kDataMemAlignUnit = 2; const size_t kDataMemAlignUnit = 2;
const string kShapeTypeDynamic = "dynamic";
const string kShapeTypeStatic = "static";


size_t GetAlignedSize(size_t size) { size_t GetAlignedSize(size_t size) {
size_t aligned_size = (size + kDataMemAlignUnit * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize; size_t aligned_size = (size + kDataMemAlignUnit * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize;
return aligned_size; return aligned_size;
} }


Status ProfilingTaskInfo(OpTask *op_task, const string &shape_type) {
if (!ProfilingManager::Instance().ProfilingModelLoadOn()) {
Status ProfilingTaskInfo(OpTask *op_task) {
if (!ProfilingManager::Instance().ProfilingModelExecuteOn()) {
return SUCCESS; return SUCCESS;
} }


@@ -68,8 +66,6 @@ Status ProfilingTaskInfo(OpTask *op_task, const string &shape_type) {
tmp_task_desc_info.block_dim = block_dim; tmp_task_desc_info.block_dim = block_dim;
tmp_task_desc_info.task_id = task_id; tmp_task_desc_info.task_id = task_id;
tmp_task_desc_info.stream_id = stream_id; tmp_task_desc_info.stream_id = stream_id;
tmp_task_desc_info.shape_type = shape_type;
tmp_task_desc_info.cur_iter_num = 0;
GELOGD("GetTaskDescInfo of op [%s] end, task_id[%u], stream_id[%u]", op_name.c_str(), task_id, stream_id); GELOGD("GetTaskDescInfo of op [%s] end, task_id[%u], stream_id[%u]", op_name.c_str(), task_id, stream_id);
task_desc_info.emplace_back(tmp_task_desc_info); task_desc_info.emplace_back(tmp_task_desc_info);


@@ -197,7 +193,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c
if (ret != SUCCESS) { if (ret != SUCCESS) {
return ret; return ret;
} }
GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(task, kShapeTypeStatic));
GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(task));
} }


return ret; return ret;
@@ -259,7 +255,7 @@ Status DynamicSingleOp::ExecuteAsync(const vector<GeTensorDesc> &input_desc,
std::lock_guard<std::mutex> lk(*stream_mutex_); std::lock_guard<std::mutex> lk(*stream_mutex_);


GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_)); GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_));
GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get(), kShapeTypeDynamic));
GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get()));
return SUCCESS; return SUCCESS;
} }
} // namespace ge } // namespace ge

+ 4
- 4
ge/single_op/task/op_task.cc View File

@@ -119,11 +119,11 @@ Status OpTask::DoUpdateArgTable(const SingleOpModelParam &param, bool keep_works
uintptr_t *arg_base = nullptr; uintptr_t *arg_base = nullptr;
size_t arg_num = 0; size_t arg_num = 0;
GetIoAddr(arg_base, arg_num); GetIoAddr(arg_base, arg_num);
if (arg_num < all_addresses.size()) {
GELOGE(INTERNAL_ERROR, "[%s] arg number mismatches, expect at least = %zu, but got = %zu",
if (arg_num != all_addresses.size()) {
GELOGE(INTERNAL_ERROR, "[%s] arg number mismatches, expect = %zu, but got = %zu",
op_desc_->GetName().c_str(), op_desc_->GetName().c_str(),
all_addresses.size(),
arg_num);
arg_num,
all_addresses.size());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }




+ 14
- 7
inc/external/ge/ge_api_types.h View File

@@ -293,7 +293,6 @@ const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path";


// Configure op bank path // Configure op bank path
const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path";
const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update";


// Graph run mode // Graph run mode
enum GraphRunMode { PREDICTION = 0, TRAIN }; enum GraphRunMode { PREDICTION = 0, TRAIN };
@@ -367,7 +366,6 @@ static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR;
static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE;
static const char *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str(); static const char *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str();
static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str();
static const char *const OP_BANK_UPDATE = ge::OP_BANK_UPDATE_FLAG.c_str();
static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str();


// for interface: aclgrphBuildModel // for interface: aclgrphBuildModel
@@ -391,13 +389,22 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
OP_COMPILER_CACHE_DIR, OP_COMPILER_CACHE_DIR,
OP_COMPILER_CACHE_MODE, OP_COMPILER_CACHE_MODE,
MDL_BANK_PATH, MDL_BANK_PATH,
OP_BANK_PATH,
OP_BANK_UPDATE};
OP_BANK_PATH};


// for interface: aclgrphParse // for interface: aclgrphParse
const std::set<std::string> ir_parser_suppported_options = {
INPUT_FP16_NODES, IS_INPUT_ADJUST_HW_LAYOUT, IS_OUTPUT_ADJUST_HW_LAYOUT, OUTPUT,
OUT_NODES, COMPRESS_WEIGHT_CONF, ENABLE_SCOPE_FUSION_PASSES};
const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT,
INPUT_SHAPE,
OP_NAME_MAP,
IS_DYNAMIC_INPUT,
INPUT_FP16_NODES,
IS_INPUT_ADJUST_HW_LAYOUT,
IS_OUTPUT_ADJUST_HW_LAYOUT,
OUTPUT,
OUTPUT_TYPE,
OUT_NODES,
COMPRESS_WEIGHT_CONF,
ENABLE_SCOPE_FUSION_PASSES,
LOG_LEVEL};


// for interface: aclgrphBuildInitialize // for interface: aclgrphBuildInitialize
const std::set<std::string> global_options = {CORE_TYPE, const std::set<std::string> global_options = {CORE_TYPE,


+ 0
- 4
inc/framework/common/ge_types.h View File

@@ -37,9 +37,7 @@ enum FrameworkType {
MINDSPORE = 1, MINDSPORE = 1,
TENSORFLOW = 3, TENSORFLOW = 3,
ANDROID_NN, ANDROID_NN,
#ifndef ONLY_COMPILE_OPEN_SRC
ONNX, ONNX,
#endif
FRAMEWORK_RESERVED, FRAMEWORK_RESERVED,
}; };


@@ -248,8 +246,6 @@ struct TaskDescInfo {
uint32_t block_dim; uint32_t block_dim;
uint32_t task_id; uint32_t task_id;
uint32_t stream_id; uint32_t stream_id;
std::string shape_type;
int64_t cur_iter_num;
}; };


// Profiling info of graph // Profiling info of graph


+ 13
- 0
inc/framework/executor/ge_executor.h View File

@@ -30,6 +30,8 @@
#include "runtime/base.h" #include "runtime/base.h"


namespace ge { namespace ge {
class ModelListenerAdapter;

class SingleOp; class SingleOp;
class DynamicSingleOp; class DynamicSingleOp;


@@ -53,8 +55,14 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
ge::Status Initialize(); ge::Status Initialize();
ge::Status Finalize(); ge::Status Finalize();


// Load model
ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority,
std::shared_ptr<ge::ModelListener> listener);

ge::Status UnloadModel(uint32_t modelId); ge::Status UnloadModel(uint32_t modelId);


ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data);

// Get input and output descriptor // Get input and output descriptor
ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc,
std::vector<ge::TensorDesc> &output_desc, bool new_model_desc = false); std::vector<ge::TensorDesc> &output_desc, bool new_model_desc = false);
@@ -160,6 +168,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc,
std::vector<ge::TensorDesc> &output_desc); std::vector<ge::TensorDesc> &output_desc);


ge::Status LoadModel(uint32_t &model_id, const ge::ModelData &model_data,
std::shared_ptr<ge::ModelListener> listener);

ge::Status CommandHandle(const ge::Command &command); ge::Status CommandHandle(const ge::Command &command);


ge::Status SetDump(const DumpConfig &dump_config); ge::Status SetDump(const DumpConfig &dump_config);
@@ -286,6 +297,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
private: private:
static bool isInit_; static bool isInit_;
}; };

ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info);
} // namespace ge } // namespace ge


#endif // INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_H_ #endif // INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_H_

+ 32
- 34
inc/framework/omg/parser/model_parser.h View File

@@ -36,7 +36,7 @@ using Status = domi::Status;


namespace domi { namespace domi {
using GetGraphCallback = std::function<std::unique_ptr<google::protobuf::Message>( using GetGraphCallback = std::function<std::unique_ptr<google::protobuf::Message>(
const google::protobuf::Message *root_proto, const std::string &graph)>;
const google::protobuf::Message *root_proto, const std::string &graph)>;
class ModelParser { class ModelParser {
public: public:
ModelParser() {} ModelParser() {}
@@ -44,20 +44,19 @@ class ModelParser {
virtual ~ModelParser() {} virtual ~ModelParser() {}


/** /**
* @ingroup domi_omg
* @brief Analyze network model data
* @param [in] file Network model file path
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
* @ingroup domi_omg
* @brief Analyze network model data
* @param [in] file Network model file path
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
virtual Status Parse(const char *file, ge::Graph &graph) = 0; virtual Status Parse(const char *file, ge::Graph &graph) = 0;


/** /**
* @ingroup domi_omg * @ingroup domi_omg
* @brief Parse relevant data from memory and save it to graph * @brief Parse relevant data from memory and save it to graph
* @param [in] input Model file memory data * @param [in] input Model file memory data
* @param [in] input Model file memory size
* @param [in|out] graph A graph for saving the model information after analysis * @param [in|out] graph A graph for saving the model information after analysis
* @return SUCCESS * @return SUCCESS
* @return FAILED * @return FAILED
@@ -65,7 +64,6 @@ class ModelParser {
*/ */
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0; virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0;


#ifndef ONLY_COMPILE_OPEN_SRC
/** /**
* @ingroup domi_omg * @ingroup domi_omg
* @brief Parse relevant data from memory and save it to graph * @brief Parse relevant data from memory and save it to graph
@@ -77,37 +75,37 @@ class ModelParser {
* @author * @author
*/ */
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) = 0; virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) = 0;
#endif


/** /**
* @ingroup domi_omg
* @brief Analyze network model data
* @param [in] proto network model
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
* @ingroup domi_omg
* @brief Analyze network model data
* @param [in] proto network model
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0; virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0;


/** /**
* @ingroup domi_omg
* @brief Analyze callback model data in subgraph
* @param [in] proto network model
* @param [in] callback callback of subgraph
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto, GetGraphCallback callback,
* @ingroup domi_omg
* @brief Analyze callback model data in subgraph
* @param [in] proto network model
* @param [in] callback callback of subgraph
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto,
GetGraphCallback callback,
ge::ComputeGraphPtr &graph) = 0; ge::ComputeGraphPtr &graph) = 0;
/** /**
* @ingroup domi_omg
* @brief Convert model files to JSON format
* @param [in] model_file Model file path to be converted
* @param [out] json_file Converted JSON file path
* @return SUCCESS
* @return Others failed
*/
* @ingroup domi_omg
* @brief Convert model files to JSON format
* @param [in] model_file Model file path to be converted
* @param [out] json_file Converted JSON file path
* @return SUCCESS
* @return Others failed
*/
virtual Status ToJson(const char *model_file, const char *json_file) { return domi::SUCCESS; } virtual Status ToJson(const char *model_file, const char *json_file) { return domi::SUCCESS; }


/* /*


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 8c89c521f5d682327b2f975cf06f7093960eb2f0
Subproject commit 5a1b0ab95e2d205ee9ee578ac4bcde4f4fbed6d8

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 54ec7731e3a2951191693e02ff3165220975ed0c
Subproject commit 77dc42c383e416ed4a0f606ddc3c02cdaa082ac3

+ 0
- 5
tests/depends/runtime/src/runtime_stub.cc View File

@@ -384,8 +384,3 @@ rtError_t rtModelExit(rtModel_t model, rtStream_t stream)
{ {
return RT_ERROR_NONE; return RT_ERROR_NONE;
} }

rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId)
{
return RT_ERROR_NONE;
}

+ 35
- 44
tests/ut/common/graph/CMakeLists.txt View File

@@ -61,67 +61,58 @@ set(UT_FILES
) )


set(SRC_FILES set(SRC_FILES
"${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc"
"${GE_CODE_DIR}/metadef/graph/option/ge_context.cc"
"${GE_CODE_DIR}/metadef/graph/anchor.cc"
"${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc"
"${GE_CODE_DIR}/metadef/graph/attr_value.cc"
"${GE_CODE_DIR}/metadef/graph/buffer.cc"
"${GE_CODE_DIR}/metadef/graph/compute_graph.cc"
"${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc"
"${GE_CODE_DIR}/metadef/graph/graph.cc"
"${GE_CODE_DIR}/metadef/graph/gnode.cc"
"${GE_CODE_DIR}/metadef/graph/ascend_string.cc"
"${GE_CODE_DIR}/metadef/graph/model.cc"
"${GE_CODE_DIR}/metadef/graph/model_serialize.cc"
"${GE_CODE_DIR}/metadef/graph/node.cc"
"${GE_CODE_DIR}/metadef/graph/op_desc.cc"
"${GE_CODE_DIR}/metadef/graph/operator.cc"
"${GE_CODE_DIR}/metadef/graph/operator_factory.cc"
"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc"
"${GE_CODE_DIR}/metadef/graph/tensor.cc"
"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc"
"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc"
"${GE_CODE_DIR}/metadef/graph/format_refiner.cc"
"${GE_CODE_DIR}/metadef/graph/inference_context.cc"
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc"
"${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc"
#"${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc"
#"${GE_CODE_DIR}/metadef/graph/option/ge_context.cc"
#"${GE_CODE_DIR}/metadef/graph/anchor.cc"
#"${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc"
#"${GE_CODE_DIR}/metadef/graph/attr_value.cc"
#"${GE_CODE_DIR}/metadef/graph/buffer.cc"
#"${GE_CODE_DIR}/metadef/graph/compute_graph.cc"
#"${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc"
#"${GE_CODE_DIR}/metadef/graph/graph.cc"
#"${GE_CODE_DIR}/metadef/graph/gnode.cc"
#"${GE_CODE_DIR}/metadef/graph/ascend_string.cc"
#"${GE_CODE_DIR}/metadef/graph/model.cc"
#"${GE_CODE_DIR}/metadef/graph/model_serialize.cc"
#"${GE_CODE_DIR}/metadef/graph/node.cc"
#"${GE_CODE_DIR}/metadef/graph/op_desc.cc"
#"${GE_CODE_DIR}/metadef/graph/operator.cc"
#"${GE_CODE_DIR}/metadef/graph/operator_reg.cc"
#"${GE_CODE_DIR}/metadef/graph/operator_factory.cc"
#"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc"
#"${GE_CODE_DIR}/metadef/graph/range_vistor.cc"
#"${GE_CODE_DIR}/metadef/graph/tensor.cc"
#"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc"
#"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc"
#"${GE_CODE_DIR}/metadef/graph/format_refiner.cc"
#"${GE_CODE_DIR}/metadef/graph/inference_context.cc"
#"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc"
#"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc"
#"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc"
#"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc"
#"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc"
#"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc"
#"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc"
#"${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc"
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp" "${GE_CODE_DIR}/metadef/ops/op_imp.cpp"
"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc"
"${GE_CODE_DIR}/metadef/graph/utils/transformer_utils.cc"
"${GE_CODE_DIR}/metadef/graph/runtime_inference_context.cc"
"${GE_CODE_DIR}/metadef/graph/ref_relation.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cpp"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/axis_util.cpp"
#"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc"
) )


#add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) #add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS})
add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS})


target_compile_options(ut_libgraph PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
)

target_compile_definitions(ut_libgraph PRIVATE target_compile_definitions(ut_libgraph PRIVATE
google=ascend_private google=ascend_private
) )


target_link_libraries(ut_libgraph target_link_libraries(ut_libgraph
$<BUILD_INTERFACE:intf_pub> $<BUILD_INTERFACE:intf_pub>
graph
gtest gtest
gtest_main gtest_main
slog_stub slog_stub
ascend_protobuf ascend_protobuf
c_sec c_sec
error_manager_stub
mmpa_stub
-lrt -lrt
-ldl -ldl
-lgcov
) )

+ 7
- 29
tests/ut/ge/CMakeLists.txt View File

@@ -245,8 +245,6 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/model/ge_model.cc" "${GE_CODE_DIR}/ge/model/ge_model.cc"
"${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc"
"${GE_CODE_DIR}/ge/graph/load/new_model_manager/model_utils.cc" "${GE_CODE_DIR}/ge/graph/load/new_model_manager/model_utils.cc"
@@ -477,8 +475,6 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc"
@@ -487,7 +483,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc"
"${GE_CODE_DIR}/ge/graph/common/transop_util.cc" "${GE_CODE_DIR}/ge/graph/common/transop_util.cc"
"${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc"
#"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/folding_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/variable_op_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transpose_transdata_pass.cc"
@@ -675,13 +671,13 @@ set(MULTI_PARTS_TEST_FILES
) )


set(SINGLE_OP_TEST_FILES set(SINGLE_OP_TEST_FILES
#"single_op/single_op_model_unittest.cc"
"single_op/single_op_model_unittest.cc"
"single_op/single_op_manager_unittest.cc" "single_op/single_op_manager_unittest.cc"
"single_op/stream_resource_unittest.cc" "single_op/stream_resource_unittest.cc"
) )


set(PROFILING_MNG_TEST_FILES set(PROFILING_MNG_TEST_FILES
#"profiling/ge_profiling_manager_unittest.cc"
"profiling/ge_profiling_manager_unittest.cc"
) )


set(OTHERS_TEST_FILES set(OTHERS_TEST_FILES
@@ -848,17 +844,13 @@ add_executable(ut_libge_multiparts_utest
${MULTI_PARTS_TEST_FILES} ${MULTI_PARTS_TEST_FILES}
) )


target_compile_options(ut_libge_multiparts_utest PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
)

target_compile_definitions(ut_libge_multiparts_utest PRIVATE target_compile_definitions(ut_libge_multiparts_utest PRIVATE
google=ascend_private google=ascend_private
) )


target_link_libraries(ut_libge_multiparts_utest target_link_libraries(ut_libge_multiparts_utest
$<BUILD_INTERFACE:intf_pub> $<BUILD_INTERFACE:intf_pub>
ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov
ge_build_common ge_load_common ge_execute_common ge_optimize_common ge_partition_common ge_prepare_common ge_single_op ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl
) )


# libge_others_utest # libge_others_utest
@@ -869,14 +861,9 @@ add_executable(ut_libge_others_utest
${EXECUTE_TEST_FILES} ${EXECUTE_TEST_FILES}
${OTHERS_TEST_FILES} ${OTHERS_TEST_FILES}
) )

target_compile_options(ut_libge_others_utest PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
)

target_link_libraries(ut_libge_others_utest target_link_libraries(ut_libge_others_utest
$<BUILD_INTERFACE:intf_pub> $<BUILD_INTERFACE:intf_pub>
ge_load_common ge_execute_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov
ge_load_common ge_execute_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl
) )


# libge_kernel_utest # libge_kernel_utest
@@ -886,14 +873,9 @@ add_executable(ut_libge_kernel_utest
${KERNEL_TEST_FILES} ${KERNEL_TEST_FILES}
${KERNEL_SRC_FILES} ${KERNEL_SRC_FILES}
) )

target_compile_options(ut_libge_kernel_utest PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
)

target_link_libraries(ut_libge_kernel_utest target_link_libraries(ut_libge_kernel_utest
$<BUILD_INTERFACE:intf_pub> $<BUILD_INTERFACE:intf_pub>
ge_load_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl -lgcov
ge_load_common ge_ut_common gtest gtest_main ascend_protobuf ${COMMON_SHARED_LIBRARIES} json -lrt -ldl
) )


# libge_distinct_load_utest # libge_distinct_load_utest
@@ -905,10 +887,6 @@ add_executable(ut_libge_distinct_load_utest
${PROFILING_MNG_TEST_FILES} ${PROFILING_MNG_TEST_FILES}
) )


target_compile_options(ut_libge_distinct_load_utest PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
)

target_compile_definitions(ut_libge_distinct_load_utest PRIVATE target_compile_definitions(ut_libge_distinct_load_utest PRIVATE
google=ascend_private google=ascend_private
) )
@@ -919,5 +897,5 @@ target_link_libraries(ut_libge_distinct_load_utest
ge_execute_common ge_ut_common_format ge_load_common ge_execute_common ge_ut_common_format ge_load_common
ge_single_op ge_prepare_common ge_single_op ge_prepare_common
ge_optimize_common ge_build_common ge_partition_common ge_ut_common ge_optimize_common ge_build_common ge_partition_common ge_ut_common
gtest gtest_main ascend_protobuf json c_sec -lrt -ldl -lpthread -lgcov
gtest gtest_main ascend_protobuf json c_sec -lrt -ldl -lpthread
) )

+ 0
- 2
tests/ut/ge/graph/build/mem_assigner_unittest.cc View File

@@ -147,7 +147,6 @@ class UtestMemoryAssignerTest : public testing::Test {
void TearDown() { GetContext().out_nodes_map.clear(); } void TearDown() { GetContext().out_nodes_map.clear(); }
}; };


/*
TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) { TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) {
ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>(""); ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("");
ge::OpDescPtr op_def_a = createOpWithWsSize("A", 6000); ge::OpDescPtr op_def_a = createOpWithWsSize("A", 6000);
@@ -161,7 +160,6 @@ TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) {


delete memory_block; delete memory_block;
} }
*/


namespace ge { namespace ge {




+ 1
- 0
tests/ut/ge/graph/passes/folding_kernel/broadcast_args_kernel_unittest.cc View File

@@ -52,6 +52,7 @@


using namespace testing; using namespace testing;
using namespace ge; using namespace ge;
using namespace cce;
using namespace ge::test; using namespace ge::test;


#define TEST_OPERATOR(op_, input_shapes, output_shapes) \ #define TEST_OPERATOR(op_, input_shapes, output_shapes) \


+ 1
- 0
tests/ut/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel_unittest.cc View File

@@ -52,6 +52,7 @@


using namespace testing; using namespace testing;
using namespace ge; using namespace ge;
using namespace cce;


class UtestBroadcastGradientArgsKernel : public testing::Test { class UtestBroadcastGradientArgsKernel : public testing::Test {
protected: protected:


+ 1
- 0
tests/ut/ge/graph/passes/folding_kernel/empty_kernel_unittest.cc View File

@@ -53,6 +53,7 @@


using namespace testing; using namespace testing;
using namespace ge; using namespace ge;
using namespace cce;
using namespace ge::test; using namespace ge::test;


class UtestEmptyKernel : public testing::Test { class UtestEmptyKernel : public testing::Test {


+ 0
- 1
tests/ut/ge/graph/passes/variable_op_pass_unittest.cc View File

@@ -38,7 +38,6 @@
#include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_mem_allocator.h"
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
#include "graph_builder_utils.h" #include "graph_builder_utils.h"
#include "cce/dnn.h"
#include "cce/dnn_struct_base.hpp" #include "cce/dnn_struct_base.hpp"
#include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h"
#include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h"


+ 2
- 3
tests/ut/ge/graph_ir/ge_operator_factory_unittest.cc View File

@@ -84,7 +84,7 @@ TEST(UtestGeOperatorFactory, register_func) {
status = OperatorFactoryImpl::RegisterVerifyFunc("ABC", nullptr); status = OperatorFactoryImpl::RegisterVerifyFunc("ABC", nullptr);
EXPECT_EQ(GRAPH_SUCCESS, status); EXPECT_EQ(GRAPH_SUCCESS, status);
} }
/*
TEST(UtestGeOperatorFactory, get_ops_type_list_fail) { TEST(UtestGeOperatorFactory, get_ops_type_list_fail) {
auto operator_creators_temp = OperatorFactoryImpl::operator_creators_; auto operator_creators_temp = OperatorFactoryImpl::operator_creators_;
OperatorFactoryImpl::operator_creators_ = nullptr; OperatorFactoryImpl::operator_creators_ = nullptr;
@@ -92,5 +92,4 @@ TEST(UtestGeOperatorFactory, get_ops_type_list_fail) {
graphStatus status = OperatorFactoryImpl::GetOpsTypeList(all_ops); graphStatus status = OperatorFactoryImpl::GetOpsTypeList(all_ops);
EXPECT_EQ(GRAPH_FAILED, status); EXPECT_EQ(GRAPH_FAILED, status);
OperatorFactoryImpl::operator_creators_ = operator_creators_temp; OperatorFactoryImpl::operator_creators_ = operator_creators_temp;
}
*/
}

+ 1
- 1
tests/ut/ge/single_op/single_op_model_unittest.cc View File

@@ -17,7 +17,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>


//#include "cce/taskdown_common.hpp"
#include "cce/taskdown_common.hpp"
#include "graph/load/new_model_manager/model_utils.h" #include "graph/load/new_model_manager/model_utils.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "runtime/rt.h" #include "runtime/rt.h"


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save