!1880 code_sync_0626

Merge pull request !1880 from mindspore_ding/code_sync_0626
tags/v1.3.0
i-robot (Gitee), 4 years ago
commit f565c42679
100 changed files with 1291 additions and 7414 deletions
  1. .gitmodules (+2, -2)
  2. CMakeLists.txt (+1, -0)
  3. build.sh (+4, -4)
  4. cmake/external_libs/protobuf_shared.cmake (+6, -6)
  5. cmake/external_libs/protobuf_static.cmake (+5, -7)
  6. cmake/external_libs/protoc.cmake (+5, -7)
  7. ge/CMakeLists.txt (+10, -0)
  8. ge/client/proto/ge_api.proto (+0, -1)
  9. ge/client/proto/ge_ir.proto (+0, -193)
  10. ge/client/proto/insert_op.proto (+0, -140)
  11. ge/client/proto/om.proto (+0, -396)
  12. ge/client/proto/task.proto (+0, -179)
  13. ge/common/CMakeLists.txt (+1, -0)
  14. ge/common/dump/dump_manager.cc (+2, -2)
  15. ge/common/ge/tbe_plugin_manager.cc (+10, -2)
  16. ge/common/ge/tbe_plugin_manager.h (+2, -1)
  17. ge/common/proto/ge_ir.proto (+0, -193)
  18. ge/common/proto/insert_op.proto (+0, -140)
  19. ge/common/proto/om.proto (+0, -396)
  20. ge/common/proto/op_mapping.proto (+0, -75)
  21. ge/common/proto/task.proto (+0, -179)
  22. ge/common/proto/tensorflow/attr_value.proto (+0, -70)
  23. ge/common/proto/tensorflow/function.proto (+0, -108)
  24. ge/common/proto/tensorflow/graph.proto (+0, -64)
  25. ge/common/proto/tensorflow/graph_library.proto (+0, -22)
  26. ge/common/proto/tensorflow/node_def.proto (+0, -71)
  27. ge/common/proto/tensorflow/op_def.proto (+0, -172)
  28. ge/common/proto/tensorflow/resource_handle.proto (+0, -37)
  29. ge/common/proto/tensorflow/tensor.proto (+0, -102)
  30. ge/common/proto/tensorflow/tensor_shape.proto (+0, -53)
  31. ge/common/proto/tensorflow/types.proto (+0, -82)
  32. ge/common/proto/tensorflow/versions.proto (+0, -39)
  33. ge/common/util.cc (+26, -56)
  34. ge/executor/CMakeLists.txt (+1, -0)
  35. ge/executor/proto/dump_task.proto (+0, -113)
  36. ge/executor/proto/ge_ir.proto (+0, -193)
  37. ge/executor/proto/insert_op.proto (+0, -140)
  38. ge/executor/proto/om.proto (+0, -396)
  39. ge/executor/proto/op_mapping.proto (+0, -75)
  40. ge/executor/proto/task.proto (+0, -179)
  41. ge/ge_local_engine/proto/task.proto (+0, -179)
  42. ge/ge_opt_info/ge_opt_info.cc (+58, -0)
  43. ge/ge_opt_info/ge_opt_info.h (+31, -0)
  44. ge/generator/ge_generator.cc (+6, -5)
  45. ge/graph/build/label_allocator.cc (+5, -0)
  46. ge/graph/build/logical_stream_allocator.cc (+5, -0)
  47. ge/graph/build/stream_allocator.cc (+10, -1)
  48. ge/graph/build/task_generator.cc (+27, -5)
  49. ge/graph/build/task_generator.h (+1, -0)
  50. ge/graph/common/omg_util.cc (+0, -15)
  51. ge/graph/common/omg_util.h (+0, -9)
  52. ge/graph/load/model_manager/davinci_model.cc (+160, -51)
  53. ge/graph/load/model_manager/davinci_model.h (+6, -0)
  54. ge/graph/load/model_manager/task_info/ffts_task_info.cc (+393, -0)
  55. ge/graph/load/model_manager/task_info/ffts_task_info.h (+66, -0)
  56. ge/graph/manager/graph_manager.cc (+8, -1)
  57. ge/graph/optimize/graph_optimize.cc (+0, -2)
  58. ge/graph/partition/dynamic_shape_partition.cc (+22, -1)
  59. ge/graph/partition/dynamic_shape_partition.h (+1, -1)
  60. ge/graph/partition/graph_partition.cc (+14, -7)
  61. ge/graph/passes/mark_force_unknown_for_cond_pass.cc (+8, -30)
  62. ge/graph/passes/mark_graph_unknown_status_pass.cc (+6, -0)
  63. ge/graph/passes/merge_to_stream_merge_pass.cc (+2, -3)
  64. ge/graph/passes/next_iteration_pass.cc (+9, -1)
  65. ge/graph/passes/replace_with_empty_const_pass.cc (+20, -0)
  66. ge/graph/passes/switch_to_stream_switch_pass.cc (+8, -8)
  67. ge/graph/preprocess/graph_preprocess.cc (+66, -53)
  68. ge/graph/preprocess/graph_preprocess.h (+2, -1)
  69. ge/graph/preprocess/insert_op/ge_aipp_op.cc (+1, -1)
  70. ge/hybrid/executor/hybrid_model_executor.cc (+8, -7)
  71. ge/hybrid/executor/hybrid_model_executor.h (+2, -1)
  72. ge/hybrid/executor/node_state.cc (+58, -8)
  73. ge/hybrid/executor/node_state.h (+11, -8)
  74. ge/hybrid/executor/subgraph_context.cc (+19, -8)
  75. ge/hybrid/executor/subgraph_context.h (+3, -2)
  76. ge/hybrid/executor/subgraph_executor.cc (+11, -15)
  77. ge/hybrid/executor/subgraph_executor.h (+2, -1)
  78. ge/hybrid/executor/worker/shape_inference_engine.cc (+2, -1)
  79. ge/hybrid/model/hybrid_model_builder.cc (+22, -0)
  80. ge/hybrid/model/hybrid_model_builder.h (+1, -0)
  81. ge/hybrid/model/node_item.cc (+2, -3)
  82. ge/hybrid/model/node_item.h (+2, -3)
  83. ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc (+0, -4)
  84. ge/hybrid/node_executor/hccl/hccl_node_executor.cc (+2, -1)
  85. ge/hybrid/node_executor/node_executor.cc (+40, -40)
  86. ge/hybrid/node_executor/node_executor.h (+5, -4)
  87. ge/hybrid/node_executor/task_context.cc (+7, -10)
  88. ge/hybrid/node_executor/task_context.h (+1, -3)
  89. ge/ir_build/ge_ir_build.cc (+49, -22)
  90. ge/ir_build/option_utils.cc (+1, -1)
  91. ge/offline/main.cc (+33, -7)
  92. ge/offline/proto/ge_ir.proto (+0, -193)
  93. ge/offline/proto/insert_op.proto (+0, -140)
  94. ge/offline/proto/om.proto (+0, -396)
  95. ge/offline/proto/task.proto (+0, -179)
  96. ge/proto/caffe/caffe.proto (+0, -1829)
  97. ge/proto/dump_task.proto (+0, -113)
  98. ge/proto/fusion_model.proto (+0, -21)
  99. ge/proto/fwk_adapter.proto (+0, -37)
  100. ge/proto/ge_api.proto (+0, -88)

.gitmodules (+2, -2)

@@ -1,8 +1,8 @@
 [submodule "parser"]
   path = parser
   url = https://gitee.com/ascend/parser.git
-  branch = master
+  branch = r1.5.0
 [submodule "metadef"]
   path = metadef
   url = https://gitee.com/ascend/metadef.git
-  branch = master
+  branch = r1.5.0

CMakeLists.txt (+1, -0)

@@ -95,6 +95,7 @@ else ()
     #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
 else()
     find_module(slog libalog.so ${ASCEND_ATC_DIR})
+    find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR})
     find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
     if(PLATFORM STREQUAL "train")
         find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})


build.sh (+4, -4)

@@ -355,13 +355,13 @@ generate_package()

   if [ "x${PLATFORM}" = "xtrain" ]
   then
-    tar -cf graphengine_lib.tar fwkacllib
+    tar -zcf graphengine_lib.tar fwkacllib
   elif [ "x${PLATFORM}" = "xinference" ]
   then
-    tar -cf graphengine_lib.tar acllib atc
+    tar -zcf graphengine_lib.tar acllib atc
   elif [ "x${PLATFORM}" = "xall" ]
   then
-    tar -cf graphengine_lib.tar fwkacllib acllib atc
+    tar -zcf graphengine_lib.tar fwkacllib acllib atc
   fi
 }

@@ -371,6 +371,6 @@ elif [ "X$MINDSPORE_MODE" = "Xon" ]
 then
   cd "${OUTPUT_PATH}"
   find ./ -name graphengine_lib.tar -exec rm {} \;
-  tar -cf graphengine_lib.tar lib
+  tar -zcf graphengine_lib.tar lib
 fi
 echo "---------------- GraphEngine package archive generated ----------------"

cmake/external_libs/protobuf_shared.cmake (+6, -6)

@@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
     message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
 endif()
 if (GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()

@@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/
 set(INSTALL_BASE_DIR "")
 set(INSTALL_LIBRARY_DIR lib)

-install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL
+install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
 install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})


cmake/external_libs/protobuf_static.cmake (+5, -7)

@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
 endif()

 if(GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()

@@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
 ExternalProject_Add(protobuf_static_build
                     URL ${REQ_URL}
-                    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-                    #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
                     TLS_VERIFY OFF
                     CONFIGURE_COMMAND ${CMAKE_COMMAND}
                     -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}


cmake/external_libs/protoc.cmake (+5, -7)

@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
 endif()

 if(GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()

@@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
 set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 ExternalProject_Add(protoc_build
                     URL ${REQ_URL}
-                    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-                    #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
                     TLS_VERIFY OFF
                     CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
                     BUILD_COMMAND $(MAKE)


ge/CMakeLists.txt (+10, -0)

@@ -174,6 +174,7 @@ set(TRAIN_SRC_LIST
     "graph/load/model_manager/task_info/model_exit_task_info.cc"
     "graph/load/model_manager/task_info/event_record_task_info.cc"
     "graph/load/model_manager/task_info/event_wait_task_info.cc"
+    "graph/load/model_manager/task_info/ffts_task_info.cc"
     "graph/load/model_manager/task_info/fusion_start_task_info.cc"
     "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
     "graph/load/model_manager/task_info/hccl_task_info.cc"

@@ -433,6 +434,7 @@ set(TRAIN_SRC_LIST
     "graph/build/memory/max_block_mem_assigner.cc"
     "graph/build/memory/var_mem_assign_util.cc"
     "graph/build/memory/buffer_pool_mem_assigner.cc"
+    "ge_opt_info/ge_opt_info.cc"
 )

 set(INFER_SRC_LIST

@@ -662,6 +664,7 @@ set(INFER_SRC_LIST
     "graph/load/model_manager/task_info/task_info.cc"
     "graph/load/model_manager/task_info/event_record_task_info.cc"
     "graph/load/model_manager/task_info/event_wait_task_info.cc"
+    "graph/load/model_manager/task_info/ffts_task_info.cc"
     "graph/load/model_manager/task_info/fusion_start_task_info.cc"
     "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
     "graph/load/model_manager/task_info/kernel_ex_task_info.cc"

@@ -709,6 +712,7 @@ set(INFER_SRC_LIST
     "graph/build/memory/max_block_mem_assigner.cc"
     "graph/build/memory/var_mem_assign_util.cc"
     "graph/build/memory/buffer_pool_mem_assigner.cc"
+    "ge_opt_info/ge_opt_info.cc"
 )

 if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)

@@ -770,11 +774,13 @@ target_include_directories(ge_runner SYSTEM PRIVATE
     ${GE_CODE_DIR}/../inc/cce
     ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
     ${GE_CODE_DIR}/../abl/adump/external
+    ${GE_CODE_DIR}/../abl/licctrl
     #### blue zone
     ${ASCEND_DIR}/driver/include
     ${ASCEND_DIR}/fwkacllib/include
     ${GE_CODE_DIR}/third_party/fwkacllib/inc
     ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
 )

 target_link_options(ge_runner PRIVATE

@@ -797,6 +803,7 @@ target_link_libraries(ge_runner PRIVATE
     runtime
     error_manager
     ascend_hal_stub
+    opt_feature
     -Wl,--as-needed
     json
     -lrt

@@ -851,11 +858,13 @@ target_include_directories(ge_compiler SYSTEM PRIVATE
     ${GE_CODE_DIR}/../inc/cce
     ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
     ${GE_CODE_DIR}/../abl/adump/external
+    ${GE_CODE_DIR}/../abl/licctrl
     #### blue zone ####
     ${ASCEND_DIR}/driver/include
     ${ASCEND_DIR}/fwkacllib/include
     ${GE_CODE_DIR}/third_party/fwkacllib/inc
     ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
 )

 target_link_options(ge_compiler PRIVATE

@@ -875,6 +884,7 @@ target_link_libraries(ge_compiler PRIVATE
     error_manager
     slog
     runtime_compile
+    opt_feature
     -Wl,--as-needed
     json
     -lrt


ge/client/proto/ge_api.proto (+0, -1)

@@ -1 +0,0 @@
-../../proto/ge_api.proto

ge/client/proto/ge_ir.proto (+0, -193)

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto version
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition; graph[0] represents the main graph in ModelDef

map<string, AttrDef> attr = 11; // Extended field
}


ge/client/proto/insert_op.proto (+0, -140)

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;

// related_input_rank is required; integer; valid range >= 0 and <= the number of input Data operators; default 0.
// It selects which model input receives AIPP processing; e.g. with two model inputs, set related_input_rank to 1 to apply AIPP to the second input.
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx is optional; integer; valid range >= 0.
// It applies different AIPP processing to different outputs of the Data operator; if unset, AIPP is applied by default to all output edges of the model input selected by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;

// [Begin] dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;

// Whether rotation is supported. Disabled by default; enabling rotation incurs extra memory and performance cost.
bool support_rotation = 5;

// [End] dynamic AIPP parameters


// [Begin] static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] static AIPP parameters

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}

MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which input the new operator is inserted at


repeated uint32 batch_list = 11; // batch_list values; the number of entries is between 2 and 8
}

ge/client/proto/om.proto (+0, -396)

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4; // optional, [default = 1e-5]
bool use_global_stats = 5; // optional, true by default (testing mode)
float moving_average_fraction = 6; // optional, [default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


ge/client/proto/task.proto (+0, -179)

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

ge/common/CMakeLists.txt (+1, -0)

@@ -106,6 +106,7 @@ target_link_libraries(ge_common PRIVATE
     c_sec
     error_manager
     slog
+    opt_feature
     -Wl,--as-needed
     json
     $<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt>


ge/common/dump/dump_manager.cc (+2, -2)

@@ -33,7 +33,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetIn

 bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) {
   if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) {
-    dump_properties_map_.emplace(kInferSessionId, dump_properties);
+    dump_properties_map_[kInferSessionId] = dump_properties;
     GELOGI("Dump does not open");
     return false;
   }
@@ -41,7 +41,7 @@ bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump
   if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) &&
       dump_config.dump_debug == kDumpoff) {
     dump_properties.ClearDumpPropertyValue();
-    dump_properties_map_.emplace(kInferSessionId, dump_properties);
+    dump_properties_map_[kInferSessionId] = dump_properties;
     return false;
   }
   if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) {

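The emplace-to-assignment switch above is behavioral, not cosmetic: std::map::emplace is a no-op when the key is already present, so a second dump configuration for the same session would silently keep the stale properties, while operator[] assignment always overwrites. A minimal standalone sketch of the difference (toy value type, not the real DumpProperties):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<uint64_t, std::string> dump_properties_map;
  const uint64_t kInferSessionId = 0;

  // emplace() inserts only if the key is absent; the second call is a no-op.
  dump_properties_map.emplace(kInferSessionId, "dump off");
  dump_properties_map.emplace(kInferSessionId, "dump on");
  std::cout << dump_properties_map[kInferSessionId] << '\n';  // prints "dump off"

  // operator[] assignment overwrites unconditionally, matching the new code.
  dump_properties_map[kInferSessionId] = "dump on";
  std::cout << dump_properties_map[kInferSessionId] << '\n';  // prints "dump on"
  return 0;
}

With the old emplace form, reconfiguring dump settings within one process could leave the first configuration in effect; the assignment form makes the latest DumpConfig win.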

ge/common/ge/tbe_plugin_manager.cc (+10, -2)

@@ -104,7 +104,15 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff
   }
 }

-void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) {
+void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list,
+                                    string &caffe_parser_path, int recursive_depth) {
+  static const int kMaxRecursiveDepth = 20; // For recursive depth protection
+
+  if (recursive_depth >= kMaxRecursiveDepth) {
+    GELOGW("Recursive depth is become %d, Please check input!", recursive_depth);
+    return;
+  }
+
   // Path, change to absolute path
   string real_path = RealPath(path.c_str());
   // Plugin path does not exist
@@ -138,7 +146,7 @@ void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_lis
       ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff,
                         aicpu_host_so_suff);
     } else {
      FindParserSo(full_name, file_list, caffe_parser_path, recursive_depth + 1);
     }
   }
   mmScandirFree(entries, ret);

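The new recursive_depth parameter above is a plain depth guard: each self-call passes recursive_depth + 1 and the scan stops once the limit is reached, so a symlink cycle or a degenerate directory tree can no longer recurse without bound. A self-contained sketch of the same pattern, with std::filesystem standing in for the mmScandir-based walk (names and the main() driver are illustrative, not GraphEngine code):

#include <filesystem>
#include <iostream>
#include <string>
#include <vector>

namespace fs = std::filesystem;

constexpr int kMaxRecursiveDepth = 20;  // same limit the diff hard-codes

// Depth-limited recursive scan for .so files, mirroring FindParserSo's guard.
void FindSoFiles(const fs::path &dir, std::vector<std::string> &file_list, int recursive_depth = 0) {
  if (recursive_depth >= kMaxRecursiveDepth) {
    std::cerr << "recursive depth " << recursive_depth << " reached, aborting scan\n";
    return;
  }
  std::error_code ec;  // swallow permission errors instead of throwing
  for (const auto &entry : fs::directory_iterator(dir, ec)) {
    if (entry.is_directory()) {
      FindSoFiles(entry.path(), file_list, recursive_depth + 1);  // depth + 1 on descent
    } else if (entry.path().extension() == ".so") {
      file_list.push_back(entry.path().string());
    }
  }
}

int main() {
  std::vector<std::string> so_files;
  FindSoFiles("/tmp", so_files);
  std::cout << so_files.size() << " shared objects found\n";
  return 0;
}

Defaulting recursive_depth to 0 (as the header change below does) keeps every existing call site source-compatible.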

ge/common/ge/tbe_plugin_manager.h (+2, -1)

@@ -57,7 +57,8 @@ class TBEPluginManager {
   static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
                                 const string &caffe_parser_so_suff, const string &aicpu_so_suff,
                                 const string &aicpu_host_so_suff);
-  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);
+  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path,
+                           int recursive_depth = 0);
   static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
   static void GetCustomOpPath(std::string &customop_path);
   void LoadCustomOpLib();


ge/common/proto/ge_ir.proto (+0, -193)

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto version
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition; graph[0] represents the main graph in ModelDef

map<string, AttrDef> attr = 11; // Extended field
}


ge/common/proto/insert_op.proto (+0, -140)

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;

// related_input_rank is required; integer; valid range >= 0 and <= the number of input Data operators; default 0.
// It selects which model input receives AIPP processing; e.g. with two model inputs, set related_input_rank to 1 to apply AIPP to the second input.
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx is optional; integer; valid range >= 0.
// It applies different AIPP processing to different outputs of the Data operator; if unset, AIPP is applied by default to all output edges of the model input selected by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;

// [Begin] dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;

// Whether rotation is supported. Disabled by default; enabling rotation incurs extra memory and performance cost.
bool support_rotation = 5;

// [End] dynamic AIPP parameters


// [Begin] static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] static AIPP parameters

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}

MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which input the new operator is inserted at


repeated uint32 batch_list = 11; // batch_list values; the number of entries is between 2 and 8
}

ge/common/proto/om.proto (+0, -396)

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4; // optional, [default = 1e-5]
bool use_global_stats = 5; // optional, true by default (testing mode)
float moving_average_fraction = 6; // optional, [default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}
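For reference, the oneof in AttrDef means at most one value member is set at a time; setting one clears the others. Below is a minimal sketch of filling attributes through the C++ classes protoc would generate from this file; the header name, the FillAttrs helper, and the attribute names are assumptions for illustration, not part of this diff:

#include "om.pb.h"  // assumed output of: protoc --cpp_out=. om.proto

void FillAttrs(domi::OpDef &op) {
  // Scalar attr: selects the "i" branch of the oneof.
  domi::AttrDef alpha;
  alpha.set_i(1);
  (*op.mutable_attr())["alpha"] = alpha;  // "alpha" is a hypothetical key

  // List attr: selects the "list" branch; any previously set branch on the
  // same AttrDef instance would be cleared.
  domi::AttrDef dims;
  dims.mutable_list()->add_i(224);
  dims.mutable_list()->add_i(224);
  (*op.mutable_attr())["input_dims"] = dims;
}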


+ 0
- 75
ge/common/proto/op_mapping.proto

@@ -1,75 +0,0 @@
syntax = "proto3";
package toolkit.aicpu.dump;

message Shape {
repeated uint64 dim = 1;
}

message Output {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
string original_name = 5;
int32 original_output_index = 6;
int32 original_output_data_type = 7;
int32 original_output_format = 8;
uint64 size = 9;
Shape origin_shape = 10;
}

message Input {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
uint64 size = 5;
Shape origin_shape = 6;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
uint64 address = 2;
uint64 size = 3;
}

message Op {
string op_name = 1;
string op_type = 2;
}

message Task {
uint32 task_id = 1;
uint32 stream_id = 2;
Op op = 3;
repeated Output output = 4;
bool end_graph = 5;
repeated Input input = 6;
repeated OpBuffer buffer = 7;
}

message OpMappingInfo {
string dump_path = 1;
oneof model_name_param {
string model_name = 2;
}
oneof model_id_param {
uint32 model_id = 3;
}
oneof step_id {
uint64 step_id_addr = 4;
}
oneof iterations_per_loop {
uint64 iterations_per_loop_addr = 5;
}
oneof loop_cond {
uint64 loop_cond_addr = 6;
}
uint32 flag = 7; // 0x01 load, 0x00 unload
repeated Task task = 8;
string dump_step = 9;
}
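The oneof wrappers above let a loader distinguish "field not set" from "set to zero". A hedged sketch of building a load-time OpMappingInfo with the standard protoc-generated accessors follows; the header name, MakeLoadInfo, and the op names are illustrative assumptions:

#include "op_mapping.pb.h"  // assumed protoc output for this file

toolkit::aicpu::dump::OpMappingInfo MakeLoadInfo(uint32_t model_id) {
  toolkit::aicpu::dump::OpMappingInfo info;
  info.set_dump_path("/tmp/dump/");   // hypothetical path
  info.set_model_id(model_id);        // selects the model_id_param oneof branch
  info.set_flag(0x01);                // 0x01 load, 0x00 unload (see comment above)
  auto *task = info.add_task();
  task->set_task_id(0);
  task->set_stream_id(0);
  task->mutable_op()->set_op_name("conv1");   // hypothetical op
  task->mutable_op()->set_op_type("Conv2D");
  return info;
}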

+ 0
- 179
ge/common/proto/task.proto

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, input/output
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}
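In practice a TaskDef carries exactly one payload, selected by the type field, even though the payload members are plain optional fields rather than a oneof. A minimal sketch of serializing a one-task ModelTaskDef with the assumed generated classes (the function, stub name, and type value are illustrative):

#include <string>
#include "task.pb.h"  // assumed protoc output

std::string BuildSingleKernelModel() {
  domi::ModelTaskDef model_task;
  model_task.set_version("1.0");
  model_task.set_stream_num(1);
  domi::TaskDef *task = model_task.add_task();
  task->set_type(0);        // the task type is an opaque uint32 at this layer
  task->set_stream_id(0);
  domi::KernelDef *kernel = task->mutable_kernel();
  kernel->set_stub_func("te_conv2d_stub");  // hypothetical stub name
  kernel->set_block_dim(32);
  return model_task.SerializeAsString();
}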

+ 0
- 70
ge/common/proto/tensorflow/attr_value.proto

@@ -1,70 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
// LINT.IfChange
message ListValue {
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated DataType type = 6 [packed = true]; // "list(type)"
repeated TensorShapeProto shape = 7; // "list(shape)"
repeated TensorProto tensor = 8; // "list(tensor)"
repeated NameAttrList func = 9; // "list(attr)"
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)

oneof value {
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
DataType type = 6; // "type"
TensorShapeProto shape = 7; // "shape"
TensorProto tensor = 8; // "tensor"
ListValue list = 1; // any "list(...)"

// "func" represents a function. func.name is a function's name or
// a primitive op's name. func.attr.first is the name of an attr
// defined for that function. func.attr.second is the value for
// that attr in the instantiation.
NameAttrList func = 10;

// This is a placeholder only used in nodes defined inside a
// function. It indicates the attr value will be supplied when
// the function is instantiated. For example, let us suppose a
// node "N" in function "FN". "N" has an attr "A" with value
// placeholder = "foo". When FN is instantiated with attr "foo"
// set to "bar", the instantiated node N's attr A will have been
// given the value "bar".
string placeholder = 9;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
string name = 1;
map<string, AttrValue> attr = 2;
}
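The placeholder branch described above is the one non-obvious case: it only makes sense inside a function body, where the real value arrives at instantiation time. A small hedged sketch (header name and DeferredAttr are assumptions):

#include "attr_value.pb.h"  // assumed protoc output

// Inside a function body, an attr can defer its value to instantiation time.
domi::tensorflow::AttrValue DeferredAttr() {
  domi::tensorflow::AttrValue v;
  v.set_placeholder("foo");  // resolved when the enclosing function is
  return v;                  // instantiated with attr "foo", e.g. set to "bar"
}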

+ 0
- 108
ge/common/proto/tensorflow/function.proto

@@ -1,108 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";

// A library is a set of named functions.
message FunctionDefLibrary {
repeated FunctionDef function = 1;
repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
// * device spec, etc.
message FunctionDef {
// The definition of the function's name, arguments, return values,
// attrs etc.
OpDef signature = 1;

// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;

// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
reserved 2;

// In both of the following fields, there is the need to specify an
// output that is used as either the input to another node (in
// `node_def`) or as a return value of the function (in `ret`).
// Unlike the NodeDefs in GraphDef, we need to be able to specify a
// list in some cases (instead of just single outputs). Also, we
// need to be able to deal with lists of unknown length (so the
// output index may not be known at function definition time). So
// we use the following format instead:
// * "fun_in" where "fun_in" is the name of a function input arg in
// the `signature` field above. This represents that input, whether
// it is a single tensor or a list.
// * "fun_in:0" gives the first element of a function input arg (a
// non-list input is considered a list of length 1 for these
// purposes).
// * "node:out" where "node" is the name of a node in `node_def` and
// "out" is the name one of its op's output arguments (the name
// comes from the OpDef of the node's op). This represents that
// node's output, whether it is a single tensor or a list.
// Note: We enforce that an op's output arguments are never
// renamed in the backwards-compatibility test.
// * "node:out:0" gives the first element of a node output arg (a
// non-list output is considered a list of length 1 for these
// purposes).
//
// NOT CURRENTLY SUPPORTED (but may be in the future):
// * "node:out:-1" gives last element in a node output list
// * "node:out:1:" gives a list with all but the first element in a
// node output list
// * "node:out::-1" gives a list with all but the last element in a
// node output list

// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
// may have values of type `placeholder` and the `input` field uses
// the "output" format above.

// By convention, "op" in node_def is resolved by consulting with a
// user-defined library first. If not resolved, "func" is assumed to
// be a builtin op.
repeated NodeDef node_def = 3;

// A mapping from the output arg names from `signature` to the
// outputs from `node_def` that should be returned by the function.
map<string, string> ret = 4;
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g' is a
// function that takes N + M inputs and produces N outputs.
//
// I.e. if we have
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
// dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
string function_name = 1; // The function name.
string gradient_func = 2; // The gradient function's name.
}
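As a concrete instance of the contract above: for f(x1, x2) = y (N = 2, M = 1), the gradient function takes (x1, x2, dL/dy) and returns (dL/dx1, dL/dx2). Registering the pairing is just two strings; a hedged sketch with hypothetical function names:

#include "function.pb.h"  // assumed protoc output

void RegisterGradient(domi::tensorflow::FunctionDefLibrary &lib) {
  domi::tensorflow::GradientDef *grad = lib.add_gradient();
  grad->set_function_name("MyMatMul");      // f: hypothetical function name
  grad->set_gradient_func("MyMatMulGrad");  // g: takes f's inputs plus dL/dy
}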

+ 0
- 64
ge/common/proto/tensorflow/graph.proto

@@ -1,64 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "node_def.proto";
import "function.proto";
import "versions.proto";

// Represents the graph of operations
message GraphDef {
repeated NodeDef node = 1;

// Compatibility versions of the graph. See core/public/version.h for version
// history. The GraphDef version is distinct from the TensorFlow version, and
// each release of TensorFlow will support a range of GraphDef versions.
VersionDef versions = 4;

// Deprecated single version field; use versions above instead. Since all
// GraphDef changes before "versions" was introduced were forward
// compatible, this field is entirely ignored.
int32 version = 3 [deprecated = true];

// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
//
// Naming:
// * library.function.name are in a flat namespace.
// NOTE: We may need to change it to be hierarchical to support
// different orgs. E.g.,
// { "/google/nn", { ... }},
// { "/google/vision", { ... }}
// { "/org_foo/module_bar", { ... }}
// map<string, FunctionDefLib> named_lib;
// * If node[i].op is the name of one function in "library",
// node[i] is deemed as a function call. Otherwise, node[i].op
// must be a primitive operation supported by the runtime.
//
//
// Function call semantics:
//
// * The callee may start execution as soon as some of its inputs
// are ready. The caller may want to use Tuple() mechanism to
// ensure all inputs are ready at the same time.
//
// * The consumer of return values may start executing as soon as
// the return values the consumer depends on are ready. The
// consumer may want to use Tuple() mechanism to ensure the
// consumer does not start until all return values of the callee
// function are ready.
FunctionDefLibrary library = 2;
};

+ 0
- 22
ge/common/proto/tensorflow/graph_library.proto

@@ -1,22 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;

import "graph.proto";

message GeGraphDef {
string name = 1;
GraphDef graph = 2;
}

message GraphDefLibrary {
repeated GeGraphDef graph_def = 1;
};

+ 0
- 71
ge/common/proto/tensorflow/node_def.proto

@@ -1,71 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";

message NodeDef {
// The name given to this operator. Used for naming inputs,
// logging, visualization, etc. Unique within a single GraphDef.
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
string name = 1;

// The operation name. There may be custom parameters in attrs.
// Op names starting with an underscore are reserved for internal use.
string op = 2;

// Each input is "node:src_output" with "node" being a string name and
// "src_output" indicating which output tensor to use from "node". If
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
// may optionally be followed by control inputs that have the format
// "^node".
repeated string input = 3;

// A (possibly partial) specification for the device on which this
// node should be placed.
// The expected syntax for this string is as follows:
//
// DEVICE_SPEC ::= PARTIAL_SPEC
//
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
// CONSTRAINT ::= ("job:" JOB_NAME)
// | ("replica:" [1-9][0-9]*)
// | ("task:" [1-9][0-9]*)
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
//
// Valid values for this string include:
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification)
// * "/job:worker/device:GPU:3" (partial specification)
// * "" (no specification)
//
// If the constraints do not resolve to a single device (or if this
// field is empty or not present), the runtime will attempt to
// choose a device automatically.
string device = 4;

// Operation-specific graph-construction-time configuration.
// Note that this should include all attrs defined in the
// corresponding OpDef, including those with a value matching
// the default -- this allows the default to change and makes
// NodeDefs easier to interpret on their own. However, if
// an attr with a default is not specified in this list, the
// default will be used.
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
// one of the names from the corresponding OpDef's attr field).
// The values must have a type matching the corresponding OpDef
// attr's type field.
// Add some examples here showing best practices.
map<string, AttrValue> attr = 5;
};
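The input format described above is easiest to see in code: regular inputs are "node:src_output" (the ":0" suffix may be omitted) and control inputs carry a "^" prefix. A small sketch using the assumed generated classes; the node, op, and device names are illustrative only:

#include "node_def.pb.h"  // assumed protoc output

domi::tensorflow::NodeDef MakeMatMulNode() {
  domi::tensorflow::NodeDef node;
  node.set_name("matmul_1");
  node.set_op("MatMul");
  node.add_input("weights");        // output 0 of node "weights" (":0" omitted)
  node.add_input("activations:1");  // output 1 of node "activations"
  node.add_input("^init");          // control dependency on node "init"
  node.set_device("/job:worker/device:GPU:3");  // partial spec, see grammar above
  return node;
}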

+ 0
- 172
ge/common/proto/tensorflow/op_def.proto

@@ -1,172 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "types.proto";

// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field which should match the name of a OpDef.
// LINT.IfChange
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;

// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;

// Human readable description.
string description = 2;

// Describes the type of one or more tensors that are accepted/produced
// by this input/output arg. The only legal combinations are:
// * For a single tensor: either the "type" field is set or the
// "type_attr" field is set to the name of an attr with type "type".
// * For a sequence of tensors with the same type: the "number_attr"
// field will be set to the name of an attr with type "int", and
// either the "type" or "type_attr" field will be set as for
// single tensors.
// * For a sequence of tensors, the "type_list_attr" field will be set
// to the name of an attr with type "list(type)".
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;

// For inputs: if true, the inputs are required to be refs.
// By default, inputs can be either refs or non-refs.
// For outputs: if true, outputs are refs, otherwise they are not.
bool is_ref = 16;
};

// Description of the input(s).
repeated ArgDef input_arg = 2;

// Description of the output(s).
repeated ArgDef output_arg = 3;

// Description of the graph-construction-time configuration of this
// Op. That is to say, this describes the attr fields that will
// be specified in the NodeDef.
message AttrDef {
// A descriptive name for the argument. May be used, e.g. by the
// Python client, as a keyword argument name, and so should match
// the regexp "[a-z][a-z0-9_]+".
string name = 1;

// One of the type names from attr_value.proto ("string", "list(string)",
// "int", etc.).
string type = 2;

// A reasonable default for this attribute if the user does not supply
// a value. If not specified, the user must supply a value.
AttrValue default_value = 3;

// Human-readable description.
string description = 4;


// --- Constraints ---
// These constraints are only in effect if specified. Default is no
// constraints.

// For type == "int", this is a minimum value. For "list(___)"
// types, this is the minimum length.
bool has_minimum = 5;
int64 minimum = 6;

// The set of allowed values. Has type that is the "list" version
// of the "type" field above (uses the "list" field of AttrValue).
// If type == "type" or "list(type)" above, then the "type" field
// of "allowed_values.list" has the set of allowed DataTypes.
// If type == "string" or "list(string)", then the "s" field of
// "allowed_values.list" has the set of allowed strings.
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;

// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;

// One-line human-readable description of what the Op does.
string summary = 5;

// Additional, longer human-readable description of what the Op does.
string description = 6;

// -------------------------------------------------------------------------
// Which optimizations this operation can participate in.

// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
bool is_commutative = 18;

// If is_aggregate is true, then this operation accepts N >= 2
// inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same
// shape as the input. The optimizer may replace an aggregate op
// taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating.
bool is_aggregate = 16; // for things like add

// Other optimizations go here, like
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.

// -------------------------------------------------------------------------
// Optimization constraints.

// Ops are marked as stateful if their behavior depends on some state beyond
// their input tensors (e.g. variable reading op) or if they have
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have
// no side-effects.
//
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
// Stateful ops are guaranteed to never be optimized away by Common
// Subexpression Elimination (CSE).
bool is_stateful = 17; // for things like variables, queue

// -------------------------------------------------------------------------
// Non-standard options.

// By default, all inputs to an Op must be initialized Tensors. Ops
// that may initialize tensors for the first time should set this
// field to true, to allow the Op to take an uninitialized Tensor as
// input.
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc)

// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;

// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};

// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
};
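The ArgDef combination rules above are best shown by the "sequence of same-typed tensors" case: number_attr names an int attr and type_attr names a type attr, both declared in the OpDef's attr list. A hedged sketch (op and attr names are hypothetical):

#include "op_def.pb.h"  // assumed protoc output

// Sketch of an op taking N tensors of one type T, per the ArgDef rules above.
domi::tensorflow::OpDef MakeAddNLikeOp() {
  domi::tensorflow::OpDef op;
  op.set_name("MyAddN");  // hypothetical op name
  auto *arg = op.add_input_arg();
  arg->set_name("inputs");
  arg->set_number_attr("N");  // attr "N" must have type "int"
  arg->set_type_attr("T");    // attr "T" must have type "type"
  auto *n = op.add_attr();
  n->set_name("N");
  n->set_type("int");
  n->set_has_minimum(true);
  n->set_minimum(2);          // at least two inputs
  return op;
}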

+ 0
- 37
ge/common/proto/tensorflow/resource_handle.proto

@@ -1,37 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
// Unique name for the device containing the resource.
string device = 1;

// Container in which this resource is placed.
string container = 2;

// Unique name of this resource.
string name = 3;

// Hash code for the type of the resource. Is only valid in the same device
// and in the same execution.
uint64 hash_code = 4;

// For debug-only, the name of the type pointed to by this handle, if
// available.
string maybe_type_name = 5;
};

+ 0
- 102
ge/common/proto/tensorflow/tensor.proto

@@ -1,102 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;

// Shape of the tensor.
TensorShapeProto tensor_shape = 2;

// Only one of the representations below is set, one of "tensor_contents" and
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot
// contain repeated fields it would require another extra set of messages.

// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;

// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;

// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.

// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];

// DT_FLOAT.
repeated float float_val = 5 [packed = true];

// DT_DOUBLE.
repeated double double_val = 6 [packed = true];

// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];

// DT_STRING
repeated bytes string_val = 8;

// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];

// DT_INT64
repeated int64 int64_val = 10 [packed = true];

// DT_BOOL
repeated bool bool_val = 11 [packed = true];

// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];

// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;

// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;

// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];

// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}
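The version-0 "repeated to fill the shape" rule above means a single scalar can stand in for an entire constant tensor. A hedged sketch of building a 2x2 float constant both the compact way and the explicit row-major way (header and function names assumed):

#include "tensor.pb.h"  // assumed protoc output

domi::tensorflow::TensorProto MakeTwoByTwo(bool compact) {
  domi::tensorflow::TensorProto t;
  t.set_dtype(domi::tensorflow::DT_FLOAT);
  t.mutable_tensor_shape()->add_dim()->set_size(2);
  t.mutable_tensor_shape()->add_dim()->set_size(2);
  if (compact) {
    t.add_float_val(1.0f);  // version 0: one value fills the whole 2x2 shape
  } else {
    for (int i = 0; i < 4; ++i) t.add_float_val(1.0f);  // row-major order
  }
  return t;
}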

+ 0
- 53
ge/common/proto/tensorflow/tensor_shape.proto

@@ -1,53 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

// Protocol buffer representing the shape of tensors.

syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

package domi.tensorflow;

// Dimensions of a tensor.
message TensorShapeProto {
// One dimension of the tensor.
message Dim {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension). Certain wrappers
// that work with TensorShapeProto may fail at runtime when deserializing
// a TensorShapeProto containing a dim value of -1.
int64 size = 1;

// Optional name of the tensor dimension.
string name = 2;
};

// Dimensions of the tensor, such as {"input", 30}, {"output", 40}
// for a 30 x 40 2D tensor. If an entry has size -1, this
// corresponds to a dimension of unknown size. The names are
// optional.
//
// The order of entries in "dim" matters: It indicates the layout of the
// values in the tensor in-memory representation.
//
// The first entry in "dim" is the outermost dimension used to layout the
// values, the last entry is the innermost dimension. This matches the
// in-memory layout of RowMajor Eigen tensors.
//
// If "dim.size()" > 0, "unknown_rank" must be false.
repeated Dim dim = 2;

// If true, the number of dimensions in the shape is unknown.
//
// If true, "dim.size()" must be 0.
bool unknown_rank = 3;
};
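The constraints in the comments (a dim size of -1 means unknown; unknown_rank true forces an empty dim list) are easy to encode wrong. A small sketch of both legal cases, with generated names assumed:

#include "tensor_shape.pb.h"  // assumed protoc output

domi::tensorflow::TensorShapeProto PartialShape() {
  domi::tensorflow::TensorShapeProto shape;
  shape.add_dim()->set_size(-1);  // e.g. batch dimension of unknown size
  shape.add_dim()->set_size(128);
  return shape;                   // known rank 2, first dim unknown
}

domi::tensorflow::TensorShapeProto UnknownRank() {
  domi::tensorflow::TensorShapeProto shape;
  shape.set_unknown_rank(true);   // dim must stay empty in this case
  return shape;
}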

+ 0
- 82
ge/common/proto/tensorflow/types.proto

@@ -1,82 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
DT_INVALID = 0;

// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8; // Single-precision complex
DT_INT64 = 9;
DT_BOOL = 10;
DT_QINT8 = 11; // Quantized int8
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;

// Do not use! These are only for parameters. Every enum above
// should have a corresponding value below (verified by types_test).
DT_FLOAT_REF = 101;
DT_DOUBLE_REF = 102;
DT_INT32_REF = 103;
DT_UINT8_REF = 104;
DT_INT16_REF = 105;
DT_INT8_REF = 106;
DT_STRING_REF = 107;
DT_COMPLEX64_REF = 108;
DT_INT64_REF = 109;
DT_BOOL_REF = 110;
DT_QINT8_REF = 111;
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/c/c_api.h,
// https://www.tensorflow.org/code/tensorflow/go/tensor.go,
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)

+ 0
- 39
ge/common/proto/tensorflow/versions.proto

@@ -1,39 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;

// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;

// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};
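The three rules in the comment translate directly into a consumer-side check. A sketch under the assumption that the consumer knows its own version and its minimum accepted producer (both specified elsewhere, as the comment says); MayConsume is a hypothetical helper:

#include <algorithm>
#include "versions.pb.h"  // assumed protoc output

// Returns true if data carrying this VersionDef may be consumed.
bool MayConsume(const domi::tensorflow::VersionDef &v,
                int32_t consumer, int32_t min_producer) {
  if (v.producer() < min_producer) return false;  // producer >= min_producer
  if (consumer < v.min_consumer()) return false;  // consumer >= min_consumer
  const auto &bad = v.bad_consumers();            // consumer not in bad_consumers
  return std::find(bad.begin(), bad.end(), consumer) == bad.end();
}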

+ 26
- 56
ge/common/util.cc

@@ -340,15 +340,24 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char
   return res;
 }
 
+void PathValidErrReport(const std::string &file_path, const std::string &atc_param, const std::string &reason) {
+  if (!atc_param.empty()) {
+    REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}),
+                       std::vector<std::string>({atc_param, file_path, reason}));
+  } else {
+    REPORT_INNER_ERROR("E19999", "Path[%s] invalid, reason:%s", file_path.c_str(), reason.c_str());
+  }
+}
+
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path,
                                                                           const std::string &atc_param) {
   // The specified path is empty
   std::map<std::string, std::string> args_map;
   if (file_path.empty()) {
-    if (atc_param != "") {
-      ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
+    if (!atc_param.empty()) {
+      REPORT_INPUT_ERROR("E10004", std::vector<std::string>({"parameter"}), std::vector<std::string>({atc_param}));
     } else {
-      REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid");
+      REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid.");
     }
     GELOGW("Input parameter %s is empty.", file_path.c_str());
     return false;
@@ -356,13 +365,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const
   std::string real_path = RealPath(file_path.c_str());
   // Unable to get absolute path (does not exist or does not have permission to access)
   if (real_path.empty()) {
-    if (atc_param != "") {
-      std::string reason = "realpath error, errmsg:" + std::string(strerror(errno));
-      ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
-                                                      {atc_param, file_path, reason});
-    } else {
-      REPORT_INNER_ERROR("E19999", "Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno));
-    }
+    std::string reason = "realpath error, errmsg:" + std::string(strerror(errno));
+    PathValidErrReport(file_path, atc_param, reason);
     GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno));
     return false;
   }
@@ -378,23 +382,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const
 
   GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
       !ValidateStr(real_path, mode),
-      if (atc_param != "") {
-        ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
-                                                        {atc_param, real_path, kPathValidReason});
-      } else {
-        REPORT_INNER_ERROR("E19999", "Path[%s] has invalid char, %s", file_path.c_str(), kPathValidReason);
-      }
+      PathValidErrReport(file_path, atc_param, kPathValidReason);
       return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason);
 
   // The absolute path points to a file that is not readable
   if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) {
-    if (atc_param != "") {
-      std::string reason = "cat not access, errmsg:" + std::string(strerror(errno));
-      ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
-                                                      {atc_param, file_path, reason});
-    } else {
-      REPORT_INNER_ERROR("E19999", "Path[%s] can't acccess, errmsg:%s", file_path.c_str(), strerror(errno));
-    }
+    PathValidErrReport(file_path, atc_param, "cat not access, errmsg:" + std::string(strerror(errno)));
     GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno));
     return false;
   }
@@ -406,10 +399,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
                                                                            const std::string &atc_param) {
   // The specified path is empty
   if (file_path.empty()) {
-    if (atc_param != "") {
-      ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
+    if (!atc_param.empty()) {
+      REPORT_INPUT_ERROR("E10004", std::vector<std::string>({"parameter"}), std::vector<std::string>({atc_param}));
     } else {
-      REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid");
+      REPORT_INNER_ERROR("E19999", "Param file_path is empty, check invalid.");
     }
     ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param});
     GELOGW("Input parameter's value is empty.");
@@ -417,17 +410,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
   }
 
   GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH,
-                                 if (atc_param != "") {
-                                   std::string reason = "len is too long, it must be less than " +
-                                                        std::to_string(MMPA_MAX_PATH);
-                                   ErrorManager::GetInstance().ATCReportErrMessage(
-                                       "E10001", {"parameter", "value", "reason"},
-                                       {atc_param, file_path, reason});
-                                 } else {
-                                   REPORT_INNER_ERROR("E19999", "Path[%s] len is too long, it must be less than %d",
-                                                      file_path.c_str(), MMPA_MAX_PATH);
-                                 }
-                                 return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(),
+                                 std::string reason = "len is too long, it must be less than " +
+                                                      std::to_string(MMPA_MAX_PATH);
+                                 PathValidErrReport(file_path, atc_param, reason);
+                                 return false, "Path[%s] len is too long, it must be less than %d", file_path.c_str(),
                                  MMPA_MAX_PATH);
 
   // A regular matching expression to verify the validity of the input file path
@@ -441,12 +427,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
 
   GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
       !ValidateStr(file_path, mode),
-      if (atc_param != "") {
-        ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
-                                                        {atc_param, file_path, kPathValidReason});
-      } else {
-        REPORT_INNER_ERROR("E19999", "Path[%s] has invalid char, %s", file_path.c_str(), kPathValidReason);
-      }
+      PathValidErrReport(file_path, atc_param, kPathValidReason);
       return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason);
 
   std::string real_path = RealPath(file_path.c_str());
@@ -454,13 +435,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
   if (!real_path.empty()) {
     // File is not readable or writable
     if (mmAccess2(real_path.c_str(), M_W_OK | M_F_OK) != EN_OK) {
-      if (atc_param != "") {
-        std::string reason = "cat not access, errmsg:" + std::string(strerror(errno));
-        ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
-                                                        {atc_param, file_path, reason});
-      } else {
-        REPORT_INNER_ERROR("E19999", "Path[%s] can't acccess, errmsg:%s", file_path.c_str(), strerror(errno));
-      }
+      PathValidErrReport(file_path, atc_param, "cat not access, errmsg:" + std::string(strerror(errno)));
       GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno));
       return false;
     }
@@ -479,12 +454,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const
     std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos));
     // Determine whether the specified path is valid by creating the path
     if (CreateDirectory(prefix_path) != 0) {
-      if (atc_param != "") {
-        ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
-                                                        {atc_param, file_path, "Can not create directory"});
-      } else {
-        REPORT_INNER_ERROR("E19999", "Path[%s] Can not create directory", file_path.c_str());
-      }
+      PathValidErrReport(file_path, atc_param, "Can not create directory");
       GELOGW("Can not create directory[%s].", file_path.c_str());
       return false;
     }


+ 1
- 0
ge/executor/CMakeLists.txt

@@ -37,6 +37,7 @@ set(SRC_LIST
"../graph/load/model_manager/task_info/task_info.cc" "../graph/load/model_manager/task_info/task_info.cc"
"../graph/load/model_manager/task_info/event_record_task_info.cc" "../graph/load/model_manager/task_info/event_record_task_info.cc"
"../graph/load/model_manager/task_info/event_wait_task_info.cc" "../graph/load/model_manager/task_info/event_wait_task_info.cc"
"../graph/load/model_manager/task_info/ffts_task_info.cc"
"../graph/load/model_manager/task_info/fusion_start_task_info.cc" "../graph/load/model_manager/task_info/fusion_start_task_info.cc"
"../graph/load/model_manager/task_info/fusion_stop_task_info.cc" "../graph/load/model_manager/task_info/fusion_stop_task_info.cc"
"../graph/load/model_manager/task_info/kernel_ex_task_info.cc" "../graph/load/model_manager/task_info/kernel_ex_task_info.cc"


+ 0
- 113
ge/executor/proto/dump_task.proto

@@ -1,113 +0,0 @@
syntax = "proto3";
package toolkit.dump;

enum OutputDataType {
DT_UNDEFINED = 0;
DT_FLOAT = 1;
DT_FLOAT16 = 2;
DT_INT8 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_UINT16 = 6;
DT_INT32 = 7;
DT_INT64 = 8;
DT_UINT32 = 9;
DT_UINT64 = 10;
DT_BOOL = 11;
DT_DOUBLE = 12;
DT_STRING = 13;
DT_DUAL_SUB_INT8 = 14;
DT_DUAL_SUB_UINT8 = 15;
DT_COMPLEX64 = 16;
DT_COMPLEX128 = 17;
DT_QINT8 = 18;
DT_QINT16 = 19;
DT_QINT32 = 20;
DT_QUINT8 = 21;
DT_QUINT16 = 22;
DT_RESOURCE = 23;
DT_STRING_REF = 24;
DT_DUAL = 25;
DT_VARIANT = 26;
}

enum OutputFormat {
FORMAT_NCHW = 0;
FORMAT_NHWC = 1;
FORMAT_ND = 2;
FORMAT_NC1HWC0 = 3;
FORMAT_FRACTAL_Z = 4;
FORMAT_NC1C0HWPAD = 5;
FORMAT_NHWC1C0 = 6;
FORMAT_FSR_NCHW = 7;
FORMAT_FRACTAL_DECONV = 8;
FORMAT_C1HWNC0 = 9;
FORMAT_FRACTAL_DECONV_TRANSPOSE = 10;
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11;
FORMAT_NC1HWC0_C04 = 12;
FORMAT_FRACTAL_Z_C04 = 13;
FORMAT_CHWN = 14;
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15;
FORMAT_HWCN = 16;
FORMAT_NC1KHKWHWC0 = 17;
FORMAT_BN_WEIGHT = 18;
FORMAT_FILTER_HWCK = 19;
FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20;
FORMAT_HASHTABLE_LOOKUP_KEYS = 21;
FORMAT_HASHTABLE_LOOKUP_VALUE = 22;
FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23;
FORMAT_HASHTABLE_LOOKUP_HITS = 24;
FORMAT_C1HWNCoC0 = 25;
FORMAT_MD = 26;
FORMAT_NDHWC = 27;
FORMAT_FRACTAL_ZZ = 28;
FORMAT_FRACTAL_NZ = 29;
FORMAT_RESERVED = 30;
}

message OriginalOp {
string name = 1;
uint32 output_index = 2;
OutputDataType data_type = 3;
OutputFormat format = 4;
}

message Shape {
repeated uint64 dim = 1;
}

message OpOutput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
OriginalOp original_op = 4; // the original op corresponding to the output
bytes data = 5;
uint64 size = 6;
}

message OpInput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
bytes data = 4;
uint64 size = 5;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
bytes data = 2;
uint64 size = 3;
}

message DumpData {
string version = 1;
uint64 dump_time = 2;
repeated OpOutput output = 3;
repeated OpInput input = 4;
repeated OpBuffer buffer = 5;
string op_name = 6;
}
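A dump record is self-describing: each OpOutput and OpInput pairs its raw bytes with dtype, format, and shape, so a reader can reinterpret the buffer without outside metadata. A hedged parsing sketch follows; the file layout (one serialized DumpData per file) and the PrintDump helper are assumptions:

#include <cstdio>
#include <fstream>
#include <string>
#include "dump_task.pb.h"  // assumed protoc output

bool PrintDump(const std::string &path) {
  std::ifstream in(path, std::ios::binary);
  toolkit::dump::DumpData dump;
  if (!dump.ParseFromIstream(&in)) return false;  // standard protobuf parse
  for (const auto &out : dump.output()) {
    std::printf("op %s output: dtype=%d format=%d bytes=%llu\n",
                dump.op_name().c_str(), out.data_type(), out.format(),
                static_cast<unsigned long long>(out.size()));
  }
  return true;
}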

+ 0
- 193
ge/executor/proto/ge_ir.proto

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; // int32 type
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor = 15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // inputs in the form "op_name:index": source op name plus output index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id = 22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR proto version
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definitions; graph[0] is the main graph of the model

map<string, AttrDef> attr = 11; // Extended field
}
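Tying the pieces together: a ModelDef owns GraphDefs, a GraphDef owns OpDefs, and graph[0] is the main graph. A minimal hedged sketch of building a one-op model with the classes protoc would generate from this file (MakeTinyModel and all names within are illustrative):

#include "ge_ir.pb.h"  // assumed protoc output

ge::proto::ModelDef MakeTinyModel() {
  ge::proto::ModelDef model;
  model.set_name("tiny");
  model.set_version(1);
  ge::proto::GraphDef *graph = model.add_graph();  // graph[0]: the main graph
  graph->set_name("main");
  ge::proto::OpDef *op = graph->add_op();
  op->set_name("data_0");
  op->set_type("Data");
  graph->add_input("data_0:0");  // "op_name:index" convention from OpDef.input
  return model;
}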


+ 0
- 140
ge/executor/proto/insert_op.proto

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP mode, distinguishing static AIPP from dynamic AIPP
AippMode aipp_mode = 1;

// related_input_rank is required; it is an integer ranging from 0 to the number of input Data operators, default 0.
// It identifies which model input AIPP processes; e.g., if the model has two inputs and the second one needs AIPP, set related_input_rank to 1.
uint32 related_input_rank = 2;

// related_input_name is optional; it is the top name of the Data node at which AIPP is inserted
string related_input_name = 6;

// input_edge_idx is optional; it is an integer >= 0.
// It allows different AIPP processing on different outputs of the Data operator; if unset, AIPP is applied by default to all output edges of the model input specified by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;

// [Begin] Dynamic AIPP parameters, ignored when static AIPP is configured
uint32 max_src_image_size = 4;

// Whether rotation is supported. Off by default; enabling rotation incurs extra memory and performance cost.
bool support_rotation = 5;

// [End] Dynamic AIPP parameters


// [Begin] Static AIPP parameters, ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] Static AIPP parameters

// The exponent n used when transforming raw/rgbir data to f16.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}

MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which input the new operator is inserted at


repeated uint32 batch_list = 11; // batch_list values; the number of entries must be between 2 and 8
}

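For context on the configuration format removed above: an AippOpParams message is normally filled through the protoc-generated C++ API and serialized before being handed to the model builder. A minimal sketch, assuming the protoc-generated header insert_op.pb.h for this file; the field choices are illustrative only:

#include <string>
#include "insert_op.pb.h"  // assumed protoc output for the proto above

std::string BuildStaticAippConfig() {
  domi::InsertNewOps ops;
  domi::AippOpParams *aipp = ops.add_aipp_op();
  // Value 1 == static; the enum value names "static"/"dynamic" collide with
  // C++ keywords, so the generated constant name depends on the protoc version.
  aipp->set_aipp_mode(static_cast<domi::AippOpParams::AippMode>(1));
  aipp->set_related_input_rank(0);  // apply AIPP to the first Data input
  aipp->set_input_format(domi::AippOpParams::YUV420SP_U8);
  aipp->set_src_image_size_w(224);
  aipp->set_src_image_size_h(224);
  aipp->set_csc_switch(true);       // enable color space conversion
  return ops.SerializeAsString();   // wire format consumed by the loader
}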
+ 0
- 396
ge/executor/proto/om.proto View File

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4; // optional, [default = 1e-5]
bool use_global_stats = 5; // optional, default true, testing mode
float moving_average_fraction = 6; // optional, [default = .999]

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // By default, NPU is used.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}

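The AttrDef oneof above is the extension point for both OpDef and ModelDef attributes. A hedged sketch of how a consumer would dispatch on it, assuming the protoc-generated header om.pb.h for this file:

#include <iostream>
#include "om.pb.h"  // assumed protoc output for the proto above

void PrintAttr(const domi::AttrDef &attr) {
  switch (attr.value_case()) {
    case domi::AttrDef::kS: std::cout << attr.s() << '\n'; break;  // string
    case domi::AttrDef::kI: std::cout << attr.i() << '\n'; break;  // int64
    case domi::AttrDef::kF: std::cout << attr.f() << '\n'; break;  // float
    case domi::AttrDef::kB: std::cout << attr.b() << '\n'; break;  // bool
    case domi::AttrDef::kList:
      std::cout << "list with " << attr.list().i_size() << " ints\n"; break;
    default: break;  // u, bt, func, or VALUE_NOT_SET
  }
}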

+ 0
- 75
ge/executor/proto/op_mapping.proto View File

@@ -1,75 +0,0 @@
syntax = "proto3";
package toolkit.aicpu.dump;

message Shape {
repeated uint64 dim = 1;
}

message Output {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
string original_name = 5;
int32 original_output_index = 6;
int32 original_output_data_type = 7;
int32 original_output_format = 8;
uint64 size = 9;
Shape origin_shape = 10;
}

message Input {
int32 data_type =1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
uint64 size = 5;
Shape origin_shape = 6;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
uint64 address = 2;
uint64 size = 3;
}

message Op {
string op_name = 1;
string op_type = 2;
}

message Task {
uint32 task_id = 1;
uint32 stream_id = 2;
Op op = 3;
repeated Output output = 4;
bool end_graph = 5;
repeated Input input = 6;
repeated OpBuffer buffer = 7;
}

message OpMappingInfo {
string dump_path = 1;
oneof model_name_param {
string model_name = 2;
}
oneof model_id_param {
uint32 model_id = 3;
}
oneof step_id {
uint64 step_id_addr = 4;
}
oneof iterations_per_loop {
uint64 iterations_per_loop_addr = 5;
}
oneof loop_cond {
uint64 loop_cond_addr = 6;
}
uint32 flag = 7; // 0x01 load, 0x00 unload
repeated Task task = 8;
string dump_step = 9;
}

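OpMappingInfo above is the dump-side description pairing tasks with their op names and addresses. A sketch of building a "load" message, assuming the generated header op_mapping.pb.h; the model and op names are hypothetical:

#include "op_mapping.pb.h"  // assumed protoc output for the proto above

toolkit::aicpu::dump::OpMappingInfo BuildLoadInfo() {
  toolkit::aicpu::dump::OpMappingInfo info;
  info.set_dump_path("/tmp/dump");
  info.set_model_name("resnet50");  // hypothetical model name
  info.set_model_id(1);
  info.set_flag(0x01);              // 0x01 load, 0x00 unload (see field comment)
  auto *task = info.add_task();
  task->set_task_id(0);
  task->set_stream_id(0);
  task->mutable_op()->set_op_name("conv1");  // hypothetical op
  task->mutable_op()->set_op_type("Conv2D");
  return info;
}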
+ 0
- 179
ge/executor/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

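ModelTaskDef above is the container for all runtime tasks of a model, and each TaskDef carries at most one populated payload such as kernel or kernel_ex. A minimal consumer sketch, assuming the generated header task.pb.h:

#include <iostream>
#include <string>
#include "task.pb.h"  // assumed protoc output for the proto above

void DumpTasks(const std::string &serialized) {
  domi::ModelTaskDef model_task;
  if (!model_task.ParseFromString(serialized)) {
    std::cerr << "parse failed\n";
    return;
  }
  for (const domi::TaskDef &task : model_task.task()) {
    std::cout << "task id=" << task.id() << " type=" << task.type()
              << " stream=" << task.stream_id() << '\n';
    if (task.has_kernel()) {  // singular message fields keep has_ accessors
      std::cout << "  kernel: " << task.kernel().kernel_name() << '\n';
    }
  }
}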
+ 0
- 179
ge/ge_local_engine/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 58
- 0
ge/ge_opt_info/ge_opt_info.cc View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_opt_info/ge_opt_info.h"

#include <string>
#include <map>
#include "graph/ge_local_context.h"
#include "ge/ge_api_types.h"
#include "common/debug/ge_log.h"
#include "opt_info.h"

namespace ge {
Status GeOptInfo::SetOptInfo() {
std::string soc_ver;
graphStatus ret = GetThreadLocalContext().GetOption(SOC_VERSION, soc_ver);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get soc version failed.");
GELOGE(FAILED, "[Get][SocVersion]Get soc version failed.");
return FAILED;
}
GELOGD("Soc version:%s.", soc_ver.c_str());
std::map<std::string, std::string> opt_info;
// the first arg does not work at present.
if (gelc::GetOptInfo(gelc::kOffline, soc_ver, opt_info) != gelc::SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get optional information failed, is_offline:%d, soc version:%s",
gelc::kOffline, soc_ver.c_str());
GELOGE(FAILED, "[Get][OptInfo]Get optional information failed, is_offline:%d, soc version:%s",
gelc::kOffline, soc_ver.c_str());
return FAILED;
}
// do nothing if get empty information
if (opt_info.empty()) {
GELOGI("Optional information is empty.");
return SUCCESS;
}
std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
for (const auto &itr : opt_info) {
graph_options.emplace(itr.first, itr.second);
GELOGI("Get optional information success, key:%s, value:%s.", itr.first.c_str(), itr.second.c_str());
}
GetThreadLocalContext().SetGraphOption(graph_options);
return SUCCESS;
}
} // namespace ge

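Note the merge semantics SetOptInfo relies on: std::map::emplace inserts only when the key is absent, so graph options already set by the user take precedence over values returned by gelc::GetOptInfo. A standalone sketch of that behavior:

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> graph_options{{"opt_level", "user"}};
  std::map<std::string, std::string> opt_info{{"opt_level", "tool"}, {"extra", "1"}};
  for (const auto &kv : opt_info) {
    graph_options.emplace(kv.first, kv.second);  // no-op for existing "opt_level"
  }
  std::cout << graph_options["opt_level"] << ' ' << graph_options["extra"] << '\n';
  // prints: user 1
  return 0;
}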
+ 31
- 0
ge/ge_opt_info/ge_opt_info.h View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_OPT_INFO_GE_OPT_INFO_H_
#define GE_OPT_INFO_GE_OPT_INFO_H_

#include "ge/ge_api_error_codes.h"
#include "register/register_types.h"

namespace ge {
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeOptInfo {
public:
GeOptInfo() = default;
static Status SetOptInfo();
};
} // namespace ge

#endif // GE_OPT_INFO_GE_OPT_INFO_H_

+ 6
- 5
ge/generator/ge_generator.cc View File

@@ -674,6 +674,12 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
GELOGD("Current ctx is null.");
ctx = nullptr;
}
std::function<void()> callback = [&]() {
if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx);
}
};
GE_MAKE_GUARD(restore, callback);


GeRootModelPtr ge_root_model = nullptr;
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
@@ -712,11 +718,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
}
return ret;
}

if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx);
}

return SUCCESS;
}

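The change above moves the rtCtxSetCurrent restore into a GE_MAKE_GUARD callback so the saved runtime context is restored on every exit path, not only the final successful return. A generic sketch of the underlying RAII pattern (not the actual GE_MAKE_GUARD definition, which lives elsewhere in the codebase):

#include <functional>
#include <utility>

class ScopeGuard {
 public:
  explicit ScopeGuard(std::function<void()> cb) : cb_(std::move(cb)) {}
  ~ScopeGuard() { if (cb_) cb_(); }            // runs on every scope exit
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;
 private:
  std::function<void()> cb_;
};

// Usage mirroring the change above: restore the saved context no matter
// which branch returns first.
// ScopeGuard restore([&]() { if (ctx != nullptr) { (void)rtCtxSetCurrent(ctx); } });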



+ 5
- 0
ge/graph/build/label_allocator.cc View File

@@ -86,6 +86,11 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node
return false;
}


if (func_node->GetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str());
return true;
}

ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph();
if (owner_graph == nullptr) {
REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s",


+ 5
- 0
ge/graph/build/logical_stream_allocator.cc View File

@@ -474,6 +474,11 @@ Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<Subgr
for (ge::NodePtr &node : graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID)) {
op_desc->SetStreamId(kInvalidStream);
GELOGI("Ffts node %s of type %s reassign to invalid stream.", node->GetName().c_str(), node->GetType().c_str());
continue;
}
int64_t stream_id = op_desc->GetStreamId();
if (ops_without_label.find(op_desc) != ops_without_label.end()) {
if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) {


+ 10
- 1
ge/graph/build/stream_allocator.cc View File

@@ -432,7 +432,11 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() {


// Insert the send/recv event id to the graph
Status StreamAllocator::InsertSyncEvents() {
for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
};

for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
// Take the adjacent points, then judge whether need to insert the event
for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) {
for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
@@ -531,6 +535,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const
Status StreamAllocator::InsertEventsForSubgraph() {
for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) {
GE_CHECK_NOTNULL(subgraph);
const auto parent_node = subgraph->GetParentNode();
if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str());
continue;
}
for (const auto &node : subgraph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);


+ 27
- 5
ge/graph/build/task_generator.cc View File

@@ -354,7 +354,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
};
GE_MAKE_GUARD(release, callback);


for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
};
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
node_index++;
@@ -380,10 +383,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str());
continue;
}
if (op_kernel_lib_name.empty()) {
GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
continue;
}
GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue,
"Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name);
if (kernel_info_store == nullptr) {
REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s",
@@ -394,6 +395,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
}
GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(),
type.c_str());
if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed",
name.c_str(), type.c_str());
}
// Profiling task
size_t task_list_size_before = task_def_list.size();
GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
@@ -571,7 +576,24 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info
return ret;
}


Status TaskGenerator::UpdateAnchorStatusForFfts(const NodePtr &node) {
GELOGD("Start UpdateAnchorStatusForFfts for %s.", node->GetName().c_str());
if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_graph = NodeUtils::GetSubgraph(*node, i);
GE_CHECK_NOTNULL(sub_graph);
GELOGD("Start update anchor status for %s.", sub_graph->GetName().c_str());
for (auto &ffts_node : sub_graph->GetDirectNode()) {
GE_CHK_STATUS_RET(UpdateAnchorStatus(ffts_node), "[Call][UpdateAnchorStatus] node:%s(%s) failed",
ffts_node->GetName().c_str(), ffts_node->GetType().c_str());
}
}
}
return SUCCESS;
}

Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) {
GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str());
if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)",
node->GetName().c_str(), node->GetType().c_str());


+ 1
- 0
ge/graph/build/task_generator.h View File

@@ -80,6 +80,7 @@ class TaskGenerator {
Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
std::vector<uint32_t> &all_reduce_nodes);
private:
Status UpdateAnchorStatusForFfts(const NodePtr &node);
Status UpdateAnchorStatus(const NodePtr &node);


Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id);


+ 0
- 15
ge/graph/common/omg_util.cc View File

@@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
}


///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) {
if (!force_unknown) {
return;
}

SetControlFlowGroup(node, group_index);
}

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.


+ 0
- 9
ge/graph/common/omg_util.h View File

@@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size);
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);


///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.


+ 160
- 51
ge/graph/load/model_manager/davinci_model.cc View File

@@ -99,6 +99,9 @@ const uint32_t kEndOfSequenceNew = 507005;
const int32_t kModelAbortNormal = 0x0704000e;
const int32_t kModelAbortNormalNew = 507024;
const uint32_t kInteval = 2;
const uint32_t kFftsTbeHandleElementSize = 2;
const uint32_t kNonTailBlock = 0;
const uint32_t kTailBlock = 1;
const char *const kModelName = "model_name";
const char *const kModeleId = "model_id";
const char *const kLoadStartTime = "load_start_time";
@@ -116,14 +119,15 @@ const char *const kWorkSpaceSize = "workspace_size";
const char *const kTotalSize = "total_size";
const char *const kTaskCount = "task_count";
const char *const kTaskId = "task_id";
const char* const kRequestId = "request_id";
const char* const kThreadId = "thread_id";
const char* const kInputBeginTime = "input_begin_time";
const char* const kInputEndTime = "input_end_time";
const char* const kInferBeginTime = "infer_begin_time";
const char* const kInferEndTime = "infer_end_time";
const char* const kOutputBeginTime = "output_start_time";
const char* const kOutputEndTime = "output_end_time";
const char *const kRequestId = "request_id";
const char *const kThreadId = "thread_id";
const char *const kInputBeginTime = "input_begin_time";
const char *const kInputEndTime = "input_end_time";
const char *const kInferBeginTime = "infer_begin_time";
const char *const kInferEndTime = "infer_end_time";
const char *const kOutputBeginTime = "output_start_time";
const char *const kOutputEndTime = "output_end_time";
const char *const kStubFuncName = "_register_stub_func";
const uint32_t kStringHeadElems = 2;
const uint32_t kPlacementHostData = 0;
const size_t kAlignment = 64;
@@ -902,10 +906,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
SetLabelForDynamic(node);
auto it = op_desc_handle.find(op_desc->GetType());
if (it != op_desc_handle.end()) {
if ((this->*it->second)(op_desc) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
return PARAM_INVALID;
}
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((this->*it->second)(op_desc) != SUCCESS, return PARAM_INVALID,
"[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
continue;
}


@@ -935,7 +937,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {


GE_TIMESTAMP_RESTART(InitTbeHandle);
if (IsTbeTask(op_desc)) {
Status status = InitTbeHandle(op_desc);
Status status =
op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc);
if (status != SUCCESS) {
GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str());
return status;
@@ -3463,11 +3466,11 @@ bool DavinciModel::CheckUserAndModelSize(const int64_t &size, const int64_t &op_
}
// The input and model input size can not be exactly equal because user input is not definite.
if ((size + kDataMemAlignSizeCompare) < op_size) {
REPORT_INNER_ERROR("E19999", "%s size:%ld from user add align:%u < input_op_size:%ld in model, model_id:%u, "
REPORT_INNER_ERROR("E19999", "%s size:%ld from user add align:%u < op_size:%ld in model, model_id:%u, "
"check invalid",
input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_);
GELOGE(ACL_ERROR_GE_PARAM_INVALID,
"[Check][Param] %s size:%ld from user add align:%u < input_op_size:%ld in model, model_id:%u",
"[Check][Param] %s size:%ld from user add align:%u < op_size:%ld in model, model_id:%u",
input_or_output.c_str(), size, kDataMemAlignSizeCompare, op_size, model_id_);
return false;
}
@@ -3700,6 +3703,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) {
/// @return Status
///
Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
string bin_file = op_desc->GetName();
auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName());
auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
if (tbe_kernel == nullptr) {
@@ -3708,12 +3712,61 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed",
bin_file.c_str());
return SUCCESS;
}


std::string session_graph_model_id;
GetUniqueId(op_desc, session_graph_model_id);
const char *bin_file_key = GetRegisterStub(op_desc->GetName(), session_graph_model_id); // from set, always valid.
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) {
std::vector<OpKernelBinPtr> tbe_kernel;
tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel);
GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size());
if (tbe_kernel.size() != kFftsTbeHandleElementSize) {
REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts",
op_desc->GetName().c_str(), tbe_kernel.size());
return INTERNAL_ERROR;
}
if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) {
REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
vector<string> bin_file_keys;
(void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys);
if (bin_file_keys.size() != kFftsTbeHandleElementSize) {
REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(),
op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts",
op_desc->GetName().c_str(), bin_file_keys.size());
return INTERNAL_ERROR;
}
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true,
kNonTailBlock),
"Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str());
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock),
"Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str());
return SUCCESS;
}


Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel,
bool is_ffts, size_t thread_index) {
if (thread_index > 1) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu should less than 1.", thread_index);
return INTERNAL_ERROR;
}
const char *bin_file_key;
if (is_ffts) {
bin_file_key = GetRegisterStub(bin_file, "");
GELOGI("Node:%s inherit func name:%s directly.", op_desc->GetName().c_str(), bin_file_key);
} else {
std::string session_graph_model_id;
GetUniqueId(op_desc, session_graph_model_id);
bin_file_key = GetRegisterStub(bin_file, session_graph_model_id); // from set, always valid.
}

TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
std::lock_guard<std::mutex> lock(tvm_bin_mutex_);
if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) {
void *bin_handle = nullptr;
@@ -3721,59 +3774,115 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key);


rtDevBinary_t binary;
std::string json_string;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string),
GELOGD("Get original type of session_graph_id."));
if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICPU") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICPU;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE;
} else {
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
return PARAM_INVALID;
}

GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.",
op_desc->GetName().c_str());
binary.version = 0;
binary.data = tbe_kernel->GetBinData();
binary.length = tbe_kernel->GetBinDataSize();

GELOGD("TBE: binary.length: %lu", binary.length);
GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle));


std::string meta_data;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data),
GELOGI("Get original type of json_string"));
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
GE_IF_BOOL_EXEC(!meta_data.empty(), GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str())));

GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.",
op_desc->GetName().c_str());
kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel);
} else {
GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key);
kernel_store.ReferTBEHandle(bin_file_key);
}

std::string kernel_name;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name),
GELOGD("Get original type of kernel_name"));
GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.",
op_desc->GetName().c_str());
GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0));
used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1.
return SUCCESS;
}

// Kernel registered, increase used num in store.
StoreTbeHandle(bin_file_key);
return SUCCESS;
}


Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index,
rtDevBinary_t &binary) {
string json_string;
const string &tvm_magic = is_ffts ? TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC;
const static std::map<std::string, uint32_t> binary_magics = {
{"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU},
{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF},
{"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC},
{"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE}
};
if (is_ffts) {
vector<string> json_list;
(void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list);
if (json_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Attr is %s, thread index is %zu, json list size is %zu.",
tvm_magic.c_str(), thread_index, json_list.size());
return INTERNAL_ERROR;
}
json_string = json_list[thread_index];
} else {
(void)AttrUtils::GetStr(op_desc, tvm_magic, json_string);
}
auto iter = binary_magics.find(json_string);
if (iter == binary_magics.end()) {
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
tvm_magic.c_str(), json_string.c_str(), op_desc->GetName().c_str(),
op_desc->GetType().c_str(), model_id_);
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
return PARAM_INVALID;
}
binary.magic = iter->second;
return SUCCESS;
}

Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) {
string meta_data;
const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA;
if (is_ffts) {
vector<string> meta_data_list;
(void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list);
if (meta_data_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.",
tvm_metadata.c_str(), thread_index, meta_data_list.size());
return INTERNAL_ERROR;
}
meta_data = meta_data_list[thread_index];
} else {
(void)AttrUtils::GetStr(op_desc, tvm_metadata, meta_data);
}
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
if (!meta_data.empty()) {
GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str()));
}
return SUCCESS;
}

Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) {
if (is_ffts) {
// delete prefix, eg: *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile
vector<string> kernel_name_list;
auto pos = op_desc->GetName().find("/");
if (pos == std::string::npos) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, subgraph node name: %s.", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname";
(void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list);
if (kernel_name_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.",
attr_kernel_name.c_str(), thread_index, kernel_name_list.size());
return INTERNAL_ERROR;
}
kernel_name = kernel_name_list[thread_index];
} else {
string attr_kernel_name = op_desc->GetName() + "_kernelname";
(void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name);
}
return SUCCESS;
}

void DavinciModel::StoreTbeHandle(const std::string &handle_key) {
// Online mode FE may call rtFunctionRegister.
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();

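A worked example of the key derivation in InitKernelName above: for FFTS nodes, the per-thread kernel names are stored under an attribute keyed by the node name with its first path segment stripped, plus the "_thread_kernelname" suffix. Sketch only, with a hypothetical node name:

#include <string>

// Mirrors the prefix-stripping logic in InitKernelName above.
std::string ThreadKernelAttrKey(const std::string &node_name) {
  const auto pos = node_name.find('/');
  if (pos == std::string::npos) {
    return "";  // corresponds to the INTERNAL_ERROR branch above
  }
  return node_name.substr(pos + 1) + "_thread_kernelname";
}

// ThreadKernelAttrKey("*sgt_graph_nodes*/fp32_vals/Mean_grad/Tile")
//   returns "fp32_vals/Mean_grad/Tile_thread_kernelname"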

+ 6
- 0
ge/graph/load/model_manager/davinci_model.h View File

@@ -771,6 +771,12 @@ class DavinciModel {
/// @return Status
///
Status InitTbeHandle(const OpDescPtr &op_desc);
Status InitTbeHandleWithFfts(const OpDescPtr &op_desc);
Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts,
size_t thread_index = 0);
Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary);
Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle);
Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name);


void StoreTbeHandle(const string &handle_key);
void CleanTbeHandle();


+ 393
- 0
ge/graph/load/model_manager/task_info/ffts_task_info.cc View File

@@ -0,0 +1,393 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/load/model_manager/task_info/ffts_task_info.h"

#include <vector>

#include "graph/load/model_manager/davinci_model.h"

namespace {
constexpr uint32_t kAddrLen = sizeof(void *);
}
namespace ge {
FftsTaskInfo::~FftsTaskInfo() {
GE_FREE_RT_LOG(args_);
}

Status FftsTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
GELOGI("FftsTaskInfo Init Start.");
GE_CHECK_NOTNULL(davinci_model);
davinci_model_ = davinci_model;
GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList()));

const domi::FftsTaskDef &ffts_task_def = task_def.ffts_task();
OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_task_def.op_index());
GE_CHECK_NOTNULL(op_desc);

if ((ffts_task_def.sub_task_size() > static_cast<int>(RT_FFTS_MAX_SUB_TASK_NUM)) ||
(ffts_task_def.ticket_cache_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_NUM))) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Node: %s, sub task desc size: %d, ticket cache size: %d",
op_desc->GetName().c_str(), ffts_task_def.sub_task_size(), ffts_task_def.ticket_cache_size());
return INTERNAL_ERROR;
}
args_size_ = kAddrLen * ffts_task_def.addr_size();
GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM));
InitFftsDescInfo(ffts_task_def.ffts_desc(), sub_task_info_.fftsDesc);

sub_task_info_.fftsType = static_cast<rtFftsType_t>(ffts_task_def.ffts_type());
sub_task_info_.subTaskNum = ffts_task_def.sub_task_size();
for (int idx = 0; idx < ffts_task_def.sub_task_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitSubTaskInfo(ffts_task_def.sub_task(idx), sub_task_info_.subTask[idx]));
}

sub_task_info_.tickCacheNum = ffts_task_def.ticket_cache_size();
for (int idx = 0; idx < ffts_task_def.ticket_cache_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitTicketCache(ffts_task_def.ticket_cache(idx), sub_task_info_.ticketCache[idx]));
}

size_t data_size = kAddrLen * io_addrs_.size();
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE));
GELOGI("FftsTaskInfo::Init Success. Node: %s, input/output size: %zu", op_desc->GetName().c_str(), io_addrs_.size());
return SUCCESS;
}

void FftsTaskInfo::InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc) {
ffts_desc.tm = static_cast<uint8_t>(ffts_desc_def.tm());
ffts_desc.di = static_cast<uint8_t>(ffts_desc_def.di());
ffts_desc.dw = static_cast<uint8_t>(ffts_desc_def.dw());
ffts_desc.df = static_cast<uint8_t>(ffts_desc_def.df());
ffts_desc.dataSplitUnit = static_cast<uint8_t>(ffts_desc_def.data_split_unit());
ffts_desc.prefetchOstNum = static_cast<uint8_t>(ffts_desc_def.prefetch_ost_num());
ffts_desc.cacheMaintainOstNum = static_cast<uint8_t>(ffts_desc_def.cache_maintain_ost_num());
ffts_desc.aicPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_upper());
ffts_desc.aicPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_lower());
ffts_desc.aivPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_upper());
ffts_desc.aivPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_lower());
}

Status FftsTaskInfo::InitSubTaskInfo(const domi::FftsSubTaskDef &sub_task_def, rtFftsSubTaskInfo_t &sub_task_desc) {
if ((sub_task_def.dst_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) ||
(sub_task_def.src_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, dst tick cache id size: %d, src tick cache id size: %d",
sub_task_def.dst_tick_cache_id_size(), sub_task_def.src_tick_cache_id_size());
return FAILED;
}

if (sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv()) {
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, auto thread aic/aiv: %d, manual thread aic/aiv: %d",
sub_task_def.has_auto_thread_aic_aiv(), sub_task_def.has_manual_thread_aic_aiv());
return FAILED;
}

thread_dim_ = sub_task_def.thread_dim();
GE_CHK_BOOL_RET_STATUS(thread_dim_ != 0, FAILED, "[Get][thread_dim] failed, Invalid thread dim: %u!", thread_dim_);

sub_task_desc.subTaskType = static_cast<rtFftsSubTaskType_t>(sub_task_def.sub_task_type());
sub_task_desc.threadDim = sub_task_def.thread_dim();

sub_task_desc.dstTickCacheVldBitmap = sub_task_def.dst_tick_cache_vld_bitmap();
sub_task_desc.srcTickCacheVldBitmap = sub_task_def.src_tick_cache_vld_bitmap();
sub_task_desc.srcDataOutOfSubGraphBitmap = sub_task_def.src_data_out_of_subgraph_bitmap();

for (int idx = 0; idx < sub_task_def.dst_tick_cache_id_size(); ++idx) {
sub_task_desc.dstTickCacheID[idx] = sub_task_def.dst_tick_cache_id(idx);
}

for (int idx = 0; idx < sub_task_def.src_tick_cache_id_size(); ++idx) {
sub_task_desc.srcTickCacheID[idx] = sub_task_def.src_tick_cache_id(idx);
}

if (sub_task_def.has_auto_thread_aic_aiv()) {
GE_CHK_STATUS_RET_NOLOG(InitAutoAicAiv(sub_task_def.auto_thread_aic_aiv(), sub_task_desc.custom.autoThreadAicAiv));
}

if (sub_task_def.has_manual_thread_aic_aiv()) {
GE_CHK_STATUS_RET_NOLOG(
InitManualAicAiv(sub_task_def.manual_thread_aic_aiv(), sub_task_desc.custom.manualThreadAicAiv));
}

if (sub_task_def.has_manual_thread_nop()) {
GE_CHK_STATUS_RET_NOLOG(InitManualNop(sub_task_def.manual_thread_nop(), sub_task_desc.custom.manualThreadNop));
}

return SUCCESS;
}

Status FftsTaskInfo::InitTicketCache(const domi::TicketCacheDef &ticket_cache_def, rtTicketCache_t &ticket_cache) {
if (ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache()) {
GELOGE(FAILED, "[Check][Param] Invalid TicketCacheDef, has auto thread cache: %d, has manual thread cache: %d",
ticket_cache_def.has_auto_thread_cache(), ticket_cache_def.has_manual_thread_cache());
return FAILED;
}

ticket_cache.cacheOption = static_cast<rtCacheOp_t>(ticket_cache_def.cache_option());
ticket_cache.ticketCacheWindow = ticket_cache_def.ticket_cache_window();

if (ticket_cache_def.has_auto_thread_cache()) {
InitAutoCacheInfo(ticket_cache_def.auto_thread_cache(), ticket_cache.custom.autoThreadCache);
}
if (ticket_cache_def.has_manual_thread_cache()) {
GE_CHK_STATUS_RET_NOLOG(
InitManualCacheInfo(ticket_cache_def.manual_thread_cache(), ticket_cache.custom.manualThreadCache));
}

return SUCCESS;
}

// task_addr = {0,200,700,1000,2000, 3500}
// task_addr_offset = {20,40,2,100,200}
template <typename T>
Status FftsTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim,
uint32_t addr_count) {
for (uint32_t i = 0; i < addr_count; ++i) {
uintptr_t logic_addr = aic_aiv_def.task_addr(i) + thread_dim * aic_aiv_def.task_addr_offset(i);
uint8_t *io_addr = nullptr;
if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress]GetRtAddress failed.");
return INTERNAL_ERROR;
}
GELOGD("aic_aiv_def task base addr is %ld, offset is %ld, thread is %d, logic addrs is 0x%lx, io addr is %p",
aic_aiv_def.task_addr(i), aic_aiv_def.task_addr_offset(i), thread_dim, logic_addr, io_addr);
io_addrs_.emplace_back(io_addr);
}
return SUCCESS;
}

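The comment above InitIoAddrs captures the per-thread addressing rule: thread t reads entry i at task_addr[i] + t * task_addr_offset[i], and trailing task_addr entries without a matching offset (3500 in the sample) belong to the tail thread's workspace, which InitAutoAicAiv handles separately. A standalone sketch with the sample numbers (illustration only, not part of this file):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<uint64_t> task_addr{0, 200, 700, 1000, 2000, 3500};
  const std::vector<uint64_t> task_addr_offset{20, 40, 2, 100, 200};
  const uint32_t thread_dim = 3;  // hypothetical thread count
  for (uint32_t t = 0; t < thread_dim; ++t) {
    for (size_t i = 0; i < task_addr_offset.size(); ++i) {
      // e.g. thread 2, entry 1: 200 + 2 * 40 = 280
      std::cout << "thread " << t << " addr[" << i << "] = "
                << task_addr[i] + t * task_addr_offset[i] << '\n';
    }
  }
  return 0;
}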
Status FftsTaskInfo::InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv) {
if (aic_aiv_def.src_prefetch_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
GELOGE(FAILED, "[Check][Param] Invalid AutoThreadAicAivInfo, prefetch size: %d", aic_aiv_def.src_prefetch_size());
return FAILED;
}

aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
GELOGD("AutoThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
const auto &rts_param = davinci_model_->GetRuntimeParam();
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
}
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
for (int k = 0; k < last_thread_workspace_size; ++k) {
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
uint8_t *io_addr = nullptr;
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
GELOGD("logic addr is 0x%lx, io addr is %p.", logic_addr, io_addr);
io_addrs_.emplace_back(io_addr);
}

aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();
GELOGD("args_: %p, io_addrs size: %zu, task param offset: %u.", args_, io_addrs_.size(), aic_aiv.taskParamOffset);
aic_aiv.satMode = aic_aiv_def.sat_mode();
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();

aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap();
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap();

aic_aiv.tailBlkDim = aic_aiv_def.tail_blk_dim();
aic_aiv.nonTailBlkDim = aic_aiv_def.non_tail_blk_dim();

aic_aiv.nonTailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.non_tail_task_func_stub(), "");
aic_aiv.tailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.tail_task_func_stub(), "");

GELOGI("Set func name[%s][%s] succ.", aic_aiv.nonTailTaskFuncStub, aic_aiv.tailTaskFuncStub);
for (int idx = 0; idx < aic_aiv_def.src_prefetch_size(); ++idx) {
InitAutoPrefetch(aic_aiv_def.src_prefetch(idx), aic_aiv.srcPrefetch[idx]);
}

return SUCCESS;
}

void FftsTaskInfo::InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache) {
cache.dataAddr = cache_def.data_addr();
cache.dataAddrOffset = cache_def.data_addr_offset();
cache.nonTailDataLen = cache_def.non_tail_data_len();
cache.tailDataLen = cache_def.tail_data_len();
cache.ticketCacheRefCnt = cache_def.ticket_cache_ref_cnt();
}

void FftsTaskInfo::InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch) {
prefetch.dataAddr = prefetch_def.data_addr();
prefetch.dataAddrOffset = prefetch_def.data_addr_offset();
prefetch.nonTailDataLen = prefetch_def.non_tail_data_len();
prefetch.tailDataLen = prefetch_def.tail_data_len();
}

Status FftsTaskInfo::InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def,
rtManualThreadAicAivInfo_t &aic_aiv) {
if ((aic_aiv_def.thread_prefetch_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.thread_blk_dim_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.thread_task_func_stub_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadAicAivInfo, thread prefetch dmu desc size: %d, "
"thread blk dim size: %d, thread task func stub size: %d, src dep tbl size: %d",
aic_aiv_def.thread_prefetch_dmu_idx_size(), aic_aiv_def.thread_blk_dim_size(),
aic_aiv_def.thread_task_func_stub_size(), aic_aiv_def.src_dep_tbl_size());
return FAILED;
}
aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
GELOGD("ManualThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
const auto &rts_param = davinci_model_->GetRuntimeParam();
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
}
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
for (int k = 0; k < last_thread_workspace_size; ++k) {
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
uint8_t *io_addr = nullptr;
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
io_addrs_.emplace_back(io_addr);
}
aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();

aic_aiv.satMode = aic_aiv_def.sat_mode();
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();

aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); // 8 bit bitmap 1 0 1 0
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); // 8 bit bitmap 1 0 1 0
aic_aiv.prefetchOnceDmuNum = aic_aiv_def.prefetch_once_dmu_num();

for (int idx = 0; idx < aic_aiv_def.thread_prefetch_dmu_idx_size(); ++idx) {
aic_aiv.threadPrefetchDmuIdx[idx] = aic_aiv_def.thread_prefetch_dmu_idx(idx);
}
for (int idx = 0; idx < aic_aiv_def.thread_blk_dim_size(); ++idx) {
aic_aiv.threadBlkDim[idx] = aic_aiv_def.thread_blk_dim(idx);
}
for (int idx = 0; idx < aic_aiv_def.thread_task_func_stub_size(); ++idx) {
aic_aiv.threadTaskFuncStub[idx] = aic_aiv_def.thread_task_func_stub(idx).c_str();
}

InitManualDmuInfo(aic_aiv_def, aic_aiv.prefetchList);
for (int idx = 0; idx < aic_aiv_def.src_dep_tbl_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(aic_aiv_def.src_dep_tbl(idx), aic_aiv.srcDepTbl[idx]));
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def,
rtManualThreadCacheInfo_t &cache_info) {
if ((cache_def.slice_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(cache_def.ticket_cache_ref_cnt_tbl_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM))) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadCacheInfo slice dum desc index %d, ticket cache ref cnt %d",
cache_def.slice_dmu_idx_size(), cache_def.ticket_cache_ref_cnt_tbl_size());
return FAILED;
}

InitManualDmuInfo(cache_def, cache_info.dmuList);
for (int idx = 0; idx < cache_def.slice_dmu_idx_size(); ++idx) {
cache_info.sliceDmuIdx[idx] = cache_def.slice_dmu_idx(idx);
}

for (int idx = 0; idx < cache_def.ticket_cache_ref_cnt_tbl_size(); ++idx) {
cache_info.ticketCacheRefCntTbl[idx] = cache_def.ticket_cache_ref_cnt_tbl(idx);
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualDependency(const domi::ManualThreadDependencyDef &dependency_def,
rtManualThreadDependency_t &dependency) {
if (dependency_def.dependency_size() > static_cast<int>(RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN)) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadDependency size: %d", dependency_def.dependency_size());
return FAILED;
}

for (int idx = 0; idx < dependency_def.dependency_size(); ++idx) {
dependency.dependency[idx] = dependency_def.dependency(idx);
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop_info) {
if (nop_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadNopInfo, src dep tbl size: %d", nop_def.src_dep_tbl_size());
return FAILED;
}

for (int idx = 0; idx < nop_def.src_dep_tbl_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(nop_def.src_dep_tbl(idx), nop_info.srcDepTbl[idx]));
}

return SUCCESS;
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu) {
if (aic_aiv_def.prefetch_list().empty()) {
return;
}

// Keep the backing storage alive in the task object; a function-local buffer here would
// leave dmu dangling as soon as this call returns.
dmu_buffers_.emplace_back(sizeof(rtManualThreadDmuInfo_t) * aic_aiv_def.prefetch_list_size());
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(dmu_buffers_.back().data());
for (int idx = 0; idx < aic_aiv_def.prefetch_list_size(); ++idx) {
InitManualDmuInfo(aic_aiv_def.prefetch_list(idx), dmu[idx]);
}
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu) {
if (cache_def.dmu_list().empty()) {
return;
}

// Same lifetime rule as above: the DMU array must stay valid until the task is launched.
dmu_buffers_.emplace_back(sizeof(rtManualThreadDmuInfo_t) * cache_def.dmu_list_size());
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(dmu_buffers_.back().data());
for (int idx = 0; idx < cache_def.dmu_list_size(); ++idx) {
InitManualDmuInfo(cache_def.dmu_list(idx), dmu[idx]);
}
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu) {
dmu.dataAddr = dmu_def.data_addr();
dmu.numOuter = dmu_def.num_outer();
dmu.numInner = dmu_def.num_inner();
dmu.strideOuter = dmu_def.stride_outer();
dmu.lenInner = dmu_def.len_inner();
dmu.strideInner = dmu_def.stride_inner();
}

Status FftsTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
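// FFTS args memory is laid out during Init and refreshed in UpdateArgs, so there is nothing to calculate here.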
return SUCCESS;
}

Status FftsTaskInfo::UpdateArgs() {
GE_CHECK_NOTNULL(davinci_model_);
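// Map the logical IO addresses to their zero-copy equivalents for this run, then push the
// whole address table to the device-side args buffer in a single copy.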
std::vector<void *> io_addrs = io_addrs_;
davinci_model_->UpdateKnownZeroCopyAddr(io_addrs);
auto addr_size = kAddrLen * io_addrs.size();
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE));
return SUCCESS;
}

Status FftsTaskInfo::Distribute() {
GELOGI("FftsTaskInfo Distribute Start.");
rtError_t rt_ret = rtFftsTaskLaunch(&sub_task_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsTaskLaunch failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

GELOGI("FftsTaskInfo Distribute Success.");
return SUCCESS;
}

REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_TASK, FftsTaskInfo);
} // namespace ge

+66 -0   ge/graph/load/model_manager/task_info/ffts_task_info.h

@@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_

#include "graph/load/model_manager/task_info/task_info.h"
#include "graph/op_desc.h"

namespace ge {
class FftsTaskInfo : public TaskInfo {
public:
FftsTaskInfo() = default;
~FftsTaskInfo() override;

Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

Status Distribute() override;

Status UpdateArgs() override;

Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

private:
void InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc);
Status InitSubTaskInfo(const domi::FftsSubTaskDef &task_def, rtFftsSubTaskInfo_t &task);
Status InitTicketCache(const domi::TicketCacheDef &cache_def, rtTicketCache_t &cache);

Status InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv);
void InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache);
void InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch);

Status InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadAicAivInfo_t &aic_aiv);
Status InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadCacheInfo_t &cache);
Status InitManualDependency(const domi::ManualThreadDependencyDef &depend_def, rtManualThreadDependency_t &depend);
Status InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop);

void InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu);
void InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu);
void InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu);

template<typename T>
Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, uint32_t addr_count);

DavinciModel *davinci_model_{nullptr};
rtFftsTaskInfo_t sub_task_info_;
std::vector<void *> io_addrs_;
std::vector<std::vector<uint8_t>> dmu_buffers_;  // backing storage for manual-thread DMU lists
uint32_t thread_dim_{0};
void *args_{nullptr}; // runtime args memory
uint32_t args_size_{0}; // runtime args memory length
};
} // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_
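
Taken together, the header describes the standard TaskInfo lifecycle: Init parses the TaskDef and prepares args memory, UpdateArgs refreshes device addresses, and Distribute launches onto the stream. A minimal sketch of a driver loop, assuming a hypothetical CreateTask() factory in place of the REGISTER_TASK_INFO lookup (RunFftsTasks is illustrative only; the real dispatch lives in DavinciModel):

Status RunFftsTasks(const std::vector<domi::TaskDef> &task_defs, DavinciModel *model) {
  for (const auto &task_def : task_defs) {
    std::shared_ptr<TaskInfo> task = CreateTask(task_def.type());      // hypothetical factory lookup
    GE_CHECK_NOTNULL(task);
    GE_CHK_STATUS_RET(task->Init(task_def, model), "Init failed");     // parse TaskDef, build args
    GE_CHK_STATUS_RET(task->UpdateArgs(), "UpdateArgs failed");        // refresh zero-copy IO addresses
    GE_CHK_STATUS_RET(task->Distribute(), "Distribute failed");        // launch onto the runtime stream
  }
  return SUCCESS;
}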

+8 -1   ge/graph/manager/graph_manager.cc

@@ -27,6 +27,7 @@
#include "common/math/math_util.h" #include "common/math/math_util.h"
#include "common/thread_pool.h" #include "common/thread_pool.h"
#include "common/dump/dump_manager.h" #include "common/dump/dump_manager.h"
#include "ge_opt_info/ge_opt_info.h"
#include "analyzer/analyzer.h" #include "analyzer/analyzer.h"
#include "graph/common/ge_call_wrapper.h" #include "graph/common/ge_call_wrapper.h"
#include "graph/common/local_context.h" #include "graph/common/local_context.h"
@@ -949,7 +950,7 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint


rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate faileded, session_id:%lu, graph_id:%u, mode:%d",
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, session_id:%lu, graph_id:%u, mode:%d",
session_id, graph_id, mode);
GELOGE(FAILED, "[Call][RtCtxCreate] faileded, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode);
return FAILED;
@@ -1001,6 +1002,12 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge
return ret;
}


ret = GeOptInfo::SetOptInfo();
if (ret != SUCCESS) {
GELOGE(ret, "[Set][OptInfo] Set optional information failed.");
return ret;
}

/// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph;
/// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph.
/// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph.


+0 -2   ge/graph/optimize/graph_optimize.cc

@@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) {
GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str()); GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str());
continue; continue;
} }
#ifndef ONLY_COMPILE_OPEN_SRC
GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str()); GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str());
ret = (iter->second)->OptimizeAfterStage1(*compute_graph); ret = (iter->second)->OptimizeAfterStage1(*compute_graph);
#endif
if (ret != SUCCESS) { if (ret != SUCCESS) {
REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, " REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, "
"graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str()); "graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str());


+22 -1   ge/graph/partition/dynamic_shape_partition.cc

@@ -364,6 +364,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
}

void DynamicShapePartitioner::MergeClustersControlFlow() {
std::unordered_set<ClusterPtr> all_merged_clusters;
for (const auto &item : control_clusters_) {
const auto &control_cluster = item.second;
auto rit = control_cluster.rbegin();
@@ -373,17 +374,32 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
}

const auto &cluster = *rit;
if (all_merged_clusters.count(cluster) > 0) {
continue;
}

bool is_unknown_cluster = cluster->IsUnknownShape();
for (++rit; rit != control_cluster.rend(); ++rit) {
const auto &cluster_from = *rit;
if (all_merged_clusters.count(cluster_from) > 0) {
continue;
}

auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
ToString(merged_clusters).c_str());
for (const auto &merged_cluster : merged_clusters) {
all_merged_clusters.emplace(merged_cluster);
for (const auto &node : merged_cluster->Nodes()) {
node_2_cluster_[node] = cluster;
}
}
}

if (!is_unknown_cluster && cluster->IsUnknownShape()) {
GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str());
ordered_cluster_.push_back(cluster);
}
}
}
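
The all_merged_clusters guard above keeps a cluster that an earlier control-flow group already absorbed from being merged a second time. A self-contained sketch of the pattern with simplified stand-in types (a real Cluster carries nodes and shape state; MergeGroups here is purely illustrative):

#include <memory>
#include <unordered_set>
#include <vector>

struct Cluster { int id; };
using ClusterPtr = std::shared_ptr<Cluster>;

void MergeGroups(const std::vector<std::vector<ClusterPtr>> &groups) {
  std::unordered_set<ClusterPtr> merged;  // clusters already absorbed into some head
  for (const auto &group : groups) {
    auto rit = group.rbegin();
    if (rit == group.rend() || merged.count(*rit) > 0) {
      continue;  // this group's head was itself merged by an earlier group
    }
    const auto &head = *rit;
    for (++rit; rit != group.rend(); ++rit) {
      if (merged.count(*rit) > 0) {
        continue;  // skip clusters merged elsewhere
      }
      merged.emplace(*rit);  // record absorption; real code calls head->MergeAllPathFrom(*rit)
    }
    (void)head;
  }
}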


@@ -703,7 +719,12 @@ void Cluster::Merge(ClusterPtr other) {
if (other->min_ < min_) {
min_ = other->min_;
}
};

if (!IsUnknownShape() && other->IsUnknownShape()) {
type_ = UNKNOWN_SHAPE;
}
}

bool Cluster::TryMerge(ClusterPtr other) {
std::queue<ClusterPtr> forward_reached;
forward_reached.push(other);


+1 -1   ge/graph/partition/dynamic_shape_partition.h

@@ -161,7 +161,7 @@ class DynamicShapePartitioner {
ge::ComputeGraphPtr root_graph_;  // The original graph to partition
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_;  // Record nodes and the cluster it belongs to
// V1 control flow cluster, need merge to one Graph.
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
// topological sorted clusters, this field will change with the splitting.
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters


+14 -7   ge/graph/partition/graph_partition.cc

@@ -179,6 +179,7 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret); GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
} }
GE_CHECK_NOTNULL(original_compute_graph); GE_CHECK_NOTNULL(original_compute_graph);
output_merged_compute_graph->SetName(original_compute_graph->GetName());
// partition sub graph // partition sub graph
for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) { for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) {
ComputeGraphPtr merged_sub_graph = nullptr; ComputeGraphPtr merged_sub_graph = nullptr;
@@ -188,8 +189,16 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
continue;
}
// this means subgraph added in optimize subgraph and without partitions, so just add to root graph
if (merged_sub_graph == sub_graph) {
GELOGI("Just add subgraph %s (parent node is %s) to root graph %s.", sub_graph->GetName().c_str(),
sub_graph->GetParentNode()->GetName().c_str(), output_merged_compute_graph->GetName().c_str());
sub_graph->SetParentGraph(sub_graph->GetParentNode()->GetOwnerComputeGraph());
GE_IF_BOOL_EXEC(output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph) != SUCCESS,
return FAILED;)
continue;
}
// add sub graph
output_merged_compute_graph->SetName(original_compute_graph->GetName());
merged_sub_graph->SetName(sub_graph->GetName());
merged_sub_graph->SetInputSize(sub_graph->GetInputSize());
merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize());
@@ -245,12 +254,9 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co
}
if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) ||
(graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) {
REPORT_INNER_ERROR("E19999", "original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
original_compute_graph->GetName().c_str());
GELOGE(GE_GRAPH_NULL_INPUT,
"[Check][Param] original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
original_compute_graph->GetName().c_str());
return FAILED;
GELOGW("[GraphPartition]: compute_graph has not found, just return original.");
output_merged_compute_graph = original_compute_graph;
return SUCCESS;
}
GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph];
const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph];
@@ -708,6 +714,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr
}
auto &engine_name = graph_info_.partitions_.at(sub_graph);
(void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName());
(void)sub_graph->SetExtAttr("part_src_graph", compute_graph);
GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(),
compute_graph->GetName().c_str());
GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]);


+8 -30   ge/graph/passes/mark_force_unknown_for_cond_pass.cc

@@ -132,39 +132,17 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
/// @return
///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP);
};

for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) {
const auto &op_node1 = it1->first;
const auto &op_desc1 = op_node1->GetOpDesc();
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) {
const auto &op_node = it->first;
const auto &op_desc = op_node->GetOpDesc();
if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}


if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) {
int64_t group_index = op_desc1->GetId();
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
MarkForceUnknownShape(op_node1, true, group_index);
for (const auto &n : it1->second) {
MarkForceUnknownShape(n, true, group_index);
}

for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
const auto &op_node2 = it2->first;
const auto &op_desc2 = op_node2->GetOpDesc();
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}

if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
MarkForceUnknownShape(op_node2, true, group_index);
for (const auto &n : it2->second) {
MarkForceUnknownShape(n, true, group_index);
}
}
}
int64_t group_index = op_desc->GetId();
SetControlFlowGroup(op_node, group_index);
for (const auto &n : it->second) {
SetControlFlowGroup(n, group_index);
}
}
}


+6 -0   ge/graph/passes/mark_graph_unknown_status_pass.cc

@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) {
}
}


const auto &node = graph->GetParentNode();
if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) {
GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape),
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
}

for (const auto &node : graph->GetDirectNode()) {
GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str());
(void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape);


+2 -3   ge/graph/passes/merge_to_stream_merge_pass.cc

@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
return FAILED, "[Check][Param] Param of pre node is nullptr.");
int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(node, force_unknown, group_index);
(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str());
return FAILED; return FAILED;
} }
MarkForceUnknownShape(active_node, force_unknown, group_index);
SetControlFlowGroup(active_node, group_index);
}


return SUCCESS;


+9 -1   ge/graph/passes/next_iteration_pass.cc

@@ -284,13 +284,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
/// @return void
///
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) {
std::string node_type;
for (const auto &switch_node : loop_group.switch_nodes) {
SetControlFlowGroup(switch_node, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) {
std::string node_type;
(void)GetOriginalType(node, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(node, group_index);
} else {
// For: Switch -> Cast -> Exit
for (const auto &n : node->GetOutDataNodes()) {
(void)GetOriginalType(n, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(n, group_index);
}
}
}
}
}


+20 -0   ge/graph/passes/replace_with_empty_const_pass.cc

@@ -21,7 +21,23 @@
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_inner_error_codes.h"
#include "graph/utils/graph_utils.h" #include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"


namespace {
const std::unordered_set<std::string> kControlFlowOps = {
ge::SWITCH,
ge::REFSWITCH,
ge::MERGE,
ge::REFMERGE,
ge::ENTER,
ge::REFENTER,
ge::NEXTITERATION,
ge::REFNEXTITERATION,
ge::EXIT,
ge::REFEXIT,
ge::LOOPCOND
};
}
namespace ge {
Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGD("ReplaceWithEmptyConstPass in.");
@@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str()); GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str());
return SUCCESS; return SUCCESS;
} }
if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) {
GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str());
return SUCCESS;
}
// Node like no op, it has no output
if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) {
GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str());


+8 -8   ge/graph/passes/switch_to_stream_switch_pass.cc

@@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());


int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(stream_switch, force_unknown, group_index);
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
SetControlFlowGroup(stream_switch, group_index);
return stream_switch;
}


@@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
std::set<NodePtr> same_cond_switch;
same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
@@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) {
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
};
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
MarkForceUnknownShape(active_node, is_unknown_shape, group_index);
(void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
SetControlFlowGroup(active_node, group_index);


const std::string &cond_group = cond_node->GetName();
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
GE_IF_BOOL_EXEC(switch_list.empty(), continue);


// select first stream_switch
@@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
"[Add][Edge] between %s and %s failed.", "[Add][Edge] between %s and %s failed.",
cast_node->GetName().c_str(), stream_switch->GetName().c_str()); cast_node->GetName().c_str(), stream_switch->GetName().c_str());


MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index);
SetControlFlowGroup(stream_switch, group_index);
for (const NodePtr &node : switch_list) {
GE_IF_BOOL_EXEC(node != stream_switch, {
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),


+66 -53   ge/graph/preprocess/graph_preprocess.cc

@@ -1420,9 +1420,10 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) {
return SUCCESS;
}


Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) {
Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc) {
auto format = desc.GetFormat();
auto origin_format = desc.GetOriginFormat();
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag);
if (need_check_internal_format) {
bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format);
@@ -1439,6 +1440,63 @@ Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTens
return SUCCESS;
}


Status GraphPrepare::UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc) {
auto data_type = desc.GetDataType();
uint32_t length = 1;
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);
if (!type_ret) {
std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" +
std::to_string(index) + " input tensor is not support";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.",
TypeUtils::DataTypeToSerialString(data_type).c_str());
return FAILED;
}
int64_t desc_shape = desc.GetShape().GetShapeSize();
FMK_INT64_UINT32_MULCHECK(desc_shape, length);
int64_t shape_size = desc_shape * length;
GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length));
int64_t size = 0;
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS,
REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index);
GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); return FAILED);
bool size_check = (size != 0 && shape_size != size);
if (size_check) {
std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) +
"] != shape_size[" + std::to_string(size) + "], check invalid";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size);
return FAILED;
}
ge::TensorUtils::SetSize(desc, shape_size);

auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
if (!tune_flag) {
graphStatus graph_ret = op->UpdateInputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
// Size will be recalculated in the build stage
ge::TensorUtils::SetSize(desc, 0);
graph_ret = op->UpdateOutputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
} else {
GELOGI("data %s skip update info in tune mode", op->GetName().c_str());
}

return SUCCESS;
}

Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input,
const std::map<string, string> &graph_option) {
// Get shape range of input in dynamic_execute mode
@@ -1471,63 +1529,18 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input,
}
GeTensorDesc desc(user_input[index].GetTensorDesc());
// data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM.
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
ret = CheckInternalFormat(input_node, desc, tune_flag);
ret = CheckInternalFormat(input_node, desc);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str());
return ret;
}
auto data_type = desc.GetDataType();
uint32_t length = 1;
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);
if (!type_ret) {
std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" +
std::to_string(index) + " input tensor is not support";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.",
TypeUtils::DataTypeToSerialString(data_type).c_str());
return FAILED;
}
int64_t desc_shape = desc.GetShape().GetShapeSize();
FMK_INT64_UINT32_MULCHECK(desc_shape, length);
int64_t shape_size = desc_shape * length;
GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length));
int64_t size = 0;
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS,
REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index);
GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index);
return FAILED);
bool size_check = (size != 0 && shape_size != size);
if (size_check) {
std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) +
"] != shape_size[" + std::to_string(size) + "], check invalid";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size);
return FAILED;
}
ge::TensorUtils::SetSize(desc, shape_size);
if (!tune_flag) {
graphStatus graph_ret = op->UpdateInputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
// Size will be recalculated in the build stage
ge::TensorUtils::SetSize(desc, 0);
graph_ret = op->UpdateOutputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
} else {
GELOGI("data %s skip update info in tune mode", op->GetName().c_str());

ret = UpdateDataInputOutputDesc(index, op, desc);
if (ret != SUCCESS) {
GELOGE(FAILED, "[Update][DataInputOutputDesc] on %s failed", op->GetName().c_str());
return ret;
}

if (!dynamic_shape_range_vec.empty()) {
ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc);
GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str());


+2 -1   ge/graph/preprocess/graph_preprocess.h

@@ -63,7 +63,8 @@ class GraphPrepare {
Status CheckRefOp();
Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode);
Status AdjustDataOpOutput(const NodePtr &node);
Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag);
Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc);
Status UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc);
Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option);
Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option);
Status CheckConstOp();


+1 -1   ge/graph/preprocess/insert_op/ge_aipp_op.cc

@@ -114,7 +114,7 @@ Status GetDataDimN(const ge::NodePtr &data_node, ge::Format format, int64_t &bat
std::vector<std::string>({
data_node->GetName() + " format",
TypeUtils::FormatToSerialString(format),
"only format " + TypeUtils::FormatToSerialString(FORMAT_NCHW) + " and "+
"only format " + TypeUtils::FormatToSerialString(FORMAT_NCHW) + " and " +
TypeUtils::FormatToSerialString(FORMAT_NHWC) +
" supported which dynamic aipp is linked"}));
GELOGE(PARAM_INVALID, "[Check][Param] Not support data format:%s, node:%s",


+8 -7   ge/hybrid/executor/hybrid_model_executor.cc

@@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() {
Status HybridModelExecutor::Init() {
GELOGD("Start to init HybridGraphEngine.");
GE_CHK_STATUS_RET_NOLOG(InitExecutionContext());
root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_));
GE_CHECK_NOTNULL(root_graph_executor_);
GELOGD("HybridGraphEngine initialized successfully."); GELOGD("HybridGraphEngine initialized successfully.");
return SUCCESS; return SUCCESS;
} }
@@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
}
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
auto ret = ExecuteGraphInternal(executor, args);
auto ret = ExecuteGraphInternal(args);
Cleanup();
RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End");
GELOGD("Model executed successfully.");
@@ -69,6 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
context_.profiler->Dump(std::cout);
context_.profiler->Reset();
}
root_graph_executor_->ReleaseContext();


context_.iteration += 1;
if (ret == END_OF_SEQUENCE) {
@@ -79,8 +81,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
return SUCCESS;
}


Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
HybridModelExecutor::ExecuteArgs &args) {
Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) {
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start");
GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_));
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End");
@@ -94,7 +95,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id));
}


HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs),
HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs),
"Failed to execute partitioned call."); "Failed to execute partitioned call.");
RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End");


@@ -103,7 +104,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
}


if (!model_->IsSingleOp()) {
Status ret = executor.Synchronize();
Status ret = root_graph_executor_->Synchronize();
if (ret != ge::SUCCESS) {
auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
@@ -123,7 +124,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
}


args.outputs.clear();
HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End");
return SUCCESS; return SUCCESS;
} }


+2 -1   ge/hybrid/executor/hybrid_model_executor.h

@@ -48,7 +48,7 @@ class HybridModelExecutor {
Status Execute(ExecuteArgs &args);


private:
Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args);
Status ExecuteGraphInternal(ExecuteArgs &args);
Status Cleanup();
Status InitExecutionContext();
static Status ResetExecutionContext(GraphExecutionContext &context);
@@ -58,6 +58,7 @@ class HybridModelExecutor {
uint32_t device_id_;
rtStream_t stream_;
GraphExecutionContext context_;
std::unique_ptr<SubgraphExecutor> root_graph_executor_;
};
} // namespace hybrid
} // namespace ge


+58 -8   ge/hybrid/executor/node_state.cc

@@ -19,8 +19,9 @@
#include "framework/common/debug/log.h" #include "framework/common/debug/log.h"
#include "graph/compute_graph.h" #include "graph/compute_graph.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "hybrid_execution_context.h"
#include "subgraph_context.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/task_context.h"


#define INC_ITERATION_COUNT(iteration) \
do { \
@@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex
this->op_desc_ = node_item.node->GetOpDesc();
}


Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) {
GE_CHECK_NOTNULL(frame_state);
group_ = group;
frame_state_ = frame_state;
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
return SUCCESS;
}

Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
if (node_item_->IsMergeOp()) {
GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size());
@@ -314,15 +325,54 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_;
}


void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
if (node_item_->root_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
}

if (node_item_->enter_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
}
}

void NodeState::UpdatePersistTensor(int input_idx) {
const auto it = root_tensor_values_.find(input_idx);
if (it == root_tensor_values_.end()) {
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx);
return;
}

auto tensor = task_context_->MutableInput(input_idx);
if (tensor == nullptr) {
GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx);
return;
}

*tensor = it->second;
GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx);
}

void NodeState::ResetContext(uint64_t iteration) {
switch_index_ = -1;
subgraph_context_->ResetContext(node_item_->node);
if (iteration == 0) {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
} else {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size());
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());

data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
for (auto item : node_item_->root_data_) {
UpdatePersistTensor(item.first);
}

if (iteration > 0) {
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size());
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size());
for (auto item : node_item_->enter_data_) {
UpdatePersistTensor(item.first);
}
}


iteration_count_ = iteration;


+11 -8   ge/hybrid/executor/node_state.h

@@ -100,6 +100,8 @@ struct NodeState {
NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context);
~NodeState() = default;


Status Init(int group, const shared_ptr<FrameState> &frame_state);

OpDesc *GetOpDesc() const {
return op_desc_.get();
}
@@ -129,6 +131,8 @@ struct NodeState {
void RunStreamActive();
void RunNextIteration();


void SavePersistTensor(int input_idx, const TensorValue &tensor);

Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;


void SetScheduleFuture(std::future<Status> &&future);
@@ -150,18 +154,10 @@ struct NodeState {
return merge_index_;
}


void SetGroup(int group) {
group_ = group;
}

int GetGroup() const {
return group_;
}


void SetFrameState(const shared_ptr<FrameState> &frame_state) {
frame_state_ = frame_state;
}

const shared_ptr<NodeTask> &GetKernelTask() const {
return kernel_task_;
}
@@ -181,12 +177,17 @@ struct NodeState {
void SetTaskContext(std::shared_ptr<TaskContext> &task_context);
std::shared_ptr<TaskContext> GetTaskContext();


void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; }

bool MaySkipShapeInference() const { return skip_infershape_; }

private:
bool IsScheduleReady() const;
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void ResetContext(uint64_t iteration);
void ScheduleContext(const NodeState &node_state);
void UpdatePersistTensor(int input_idx);


const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -199,6 +200,7 @@ struct NodeState {


std::future<Status> schedule_future_; std::future<Status> schedule_future_;
std::shared_ptr<FrameState> frame_state_; std::shared_ptr<FrameState> frame_state_;
std::map<int, TensorValue> root_tensor_values_;
uint64_t active_count_ = 0; uint64_t active_count_ = 0;
uint64_t iteration_count_ = 0; uint64_t iteration_count_ = 0;
uint32_t ctrl_scheduled_ = 0; uint32_t ctrl_scheduled_ = 0;
@@ -206,6 +208,7 @@ struct NodeState {
int merge_index_ = -1; // Use for Execute (Reset after Executed). int merge_index_ = -1; // Use for Execute (Reset after Executed).
int switch_index_ = -1; // Use for Schedule (Reset after Prepared). int switch_index_ = -1; // Use for Schedule (Reset after Prepared).
int group_ = -1; int group_ = -1;
bool skip_infershape_ = false;
}; };
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge
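
The SetGroup/SetFrameState setters above are folded into the new fallible Init. A minimal sketch of what such an Init can look like, assuming it also builds the cached TaskContext the same way ResetContext does in the hunk further up (illustrative only, not the verbatim implementation):

    // Sketch: combines the removed setters with task-context creation,
    // so a failure can be reported instead of silently ignored.
    Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) {
      GE_CHECK_NOTNULL(frame_state);
      group_ = group;              // replaces the removed SetGroup()
      frame_state_ = frame_state;  // replaces the removed SetFrameState()
      auto unique_task_context = TaskContext::Create(this, subgraph_context_);
      GE_CHECK_NOTNULL(unique_task_context);
      task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
      return SUCCESS;
    }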


+ 19  - 8  ge/hybrid/executor/subgraph_context.cc

@@ -19,7 +19,7 @@


 namespace ge {
 namespace hybrid {
-SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
+SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context)
     : graph_item_(graph_item), execution_context_(execution_context) {
 }
 
@@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
     return nullptr;
   }
 
+  return CreateNodeState(node_item);
+}
+
+NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) {
   GELOGD("[%s] lock for write", node_item->NodeName().c_str());
   if (mmRWLockWRLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str());
     GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str());
     return nullptr;
   }
 
   auto &node_state = node_states_[node_item];
-  if (node_state == nullptr) {
-    const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
-    node_state.reset(new(std::nothrow)NodeState(*node_item, this));
-    node_state->SetFrameState(GetOrCreateFrameState(*node_item));
-    node_state->SetGroup(group_);
-    (void)guard;
-  }
+  do {
+    if (node_state == nullptr) {
+      const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
+      node_state.reset(new(std::nothrow)NodeState(*node_item, this));
+      if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) {
+        GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str());
+        REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str());
+        break;
+      }
+      (void)guard;
+    }
+  } while (0);
 
   GELOGD("[%s] unlock for write", node_item->NodeName().c_str());
   if (mmWRLockUnLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str());
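
The `do { ... } while (0)` wrapper above is the usual single-exit idiom for code running under a manually paired lock: `break` replaces an early `return`, so the `mmWRLockUnLock` below is always reached. The generic shape of the idiom, with hypothetical StepA/StepB/Lock/Unlock names (not code from this change):

    Lock();
    do {
      if (StepA() != SUCCESS) {
        break;  // jump past the remaining steps, not past the unlock
      }
      if (StepB() != SUCCESS) {
        break;
      }
    } while (0);
    Unlock();  // single exit point, reached on success and failure alike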


+ 3  - 2  ge/hybrid/executor/subgraph_context.h

@@ -30,7 +30,7 @@ namespace ge {
 namespace hybrid {
 class SubgraphContext {
  public:
-  explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context);
+  explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context);
   ~SubgraphContext();
 
   Status Init();
@@ -51,10 +51,11 @@ class SubgraphContext {
   void NodeDone(const NodePtr &node);
 
  private:
+  NodeStatePtr CreateNodeState(const NodeItem *node_item);
   FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item);  // no lock
   friend class TaskContext;
   const GraphItem *graph_item_;
-  const GraphExecutionContext *execution_context_;
+  GraphExecutionContext *execution_context_;
   mmRWLock_t rw_lock_;
   std::vector<TensorValue> all_inputs_;
   std::vector<TensorValue> all_outputs_;


+ 11  - 15  ge/hybrid/executor/subgraph_executor.cc

@@ -103,6 +103,14 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue
     auto node_state = subgraph_context_->GetOrCreateNodeState(input_node);
     GE_CHECK_NOTNULL(node_state);
     node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc);
+    auto op_desc = input_node->GetOpDesc();
+    GE_CHECK_NOTNULL(op_desc);
+    auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex);
+    GE_CHECK_NOTNULL(output_desc);
+    output_desc->SetShape(tensor_desc->GetShape());
+    output_desc->SetOriginShape(tensor_desc->GetOriginShape());
+    output_desc->SetDataType(tensor_desc->GetDataType());
+    node_state->SetSkipInferShape(true);
   }
 }
 
@@ -175,16 +183,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue
   GE_CHECK_NOTNULL(node_state);
   node_state->SetKernelTask(node_item->kernel_task);
 
-  known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(known_shape_task_context_);
-  node_state->SetTaskContext(known_shape_task_context_);
-
   std::function<void()> callback;
   GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));
-  HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback),
+  HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback),
                         "[%s] Failed to execute node [%s] for known subgraph.",
                         graph_item_->GetName().c_str(),
-                        known_shape_task_context_->GetNodeName());
+                        node_state->GetName().c_str());
 
   GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str());
   return SUCCESS;
@@ -271,16 +275,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {
   } else {
     node_state->SetKernelTask(node_item.kernel_task);
   }
-  auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(unique_task_context);
   const auto &task = node_state->GetKernelTask();
   if (task == nullptr) {
     GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
     REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
     return INTERNAL_ERROR;
   }
-  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
   GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
   return AfterPrepared(p_node_state);
 }
@@ -480,19 +480,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
   } else {
     node_state.SetKernelTask(node_item.kernel_task);
   }
-  auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(unique_task_context);
   const auto &task = node_state.GetKernelTask();
   if (task == nullptr) {
     GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str());
     REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str());
     return INTERNAL_ERROR;
   }
-  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
-  node_state.SetTaskContext(shared_task_context);
   GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
   RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start");
-  GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context));  // update op_desc before alloc ws
+  GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext()));  // update op_desc before alloc ws
   RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end");
   return SUCCESS;
 }
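
The four hunks above all make the same move: the per-node TaskContext is no longer created ad hoc at each prepare/execute site (and the known_shape_task_context_ member disappears from the header below); it is owned by the NodeState and created once in NodeState::Init. A call site therefore reduces to roughly this shape (sketch assembled from the hunks):

    node_state->SetKernelTask(node_item->kernel_task);
    auto task_context = node_state->GetTaskContext();  // cached, created in Init
    HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, task_context, *context_, callback),
                          "[%s] Failed to execute node [%s].",
                          graph_item_->GetName().c_str(), node_state->GetName().c_str());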


+ 2  - 1  ge/hybrid/executor/subgraph_executor.h

@@ -41,6 +41,8 @@ class SubgraphExecutor {


   Status PartialExecuteAsync(int task_group);
 
+  void ReleaseContext() { subgraph_context_.reset(nullptr); }
+
   /**
    * Execute subgraph async, output tensor address(not data) and output tensor descriptions are
    * valid after this method returned
@@ -125,7 +127,6 @@ class SubgraphExecutor {
   ThreadPool pre_run_pool_;
   BlockingQueue<NodeState *> ready_queue_;
   std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
-  std::shared_ptr<TaskContext> known_shape_task_context_;
 
   std::mutex mu_;  // Guard for prepare_queues_.
   std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_;


+ 2  - 1  ge/hybrid/executor/worker/shape_inference_engine.cc

@@ -68,8 +68,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
   }
 
   // Do shape inference
+  // Skipping infer shape of input node.
   GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str());
-  {
+  if (!node_state.MaySkipShapeInference()) {
     RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start");
     GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true),
                       "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str());


+ 22  - 0  ge/hybrid/model/hybrid_model_builder.cc

@@ -1227,6 +1227,28 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr
     hybrid_model_.known_shape_sub_models_.emplace(parent_node, ge_model);
   }
 
+  GE_CHK_STATUS_RET_NOLOG(InitHcclExecutorOnDemand(ge_model));
+  return SUCCESS;
+}
+
+Status HybridModelBuilder::InitHcclExecutorOnDemand(const GeModelPtr &ge_model) {
+  if (NodeExecutorManager::GetInstance().IsExecutorInitialized(NodeExecutorManager::ExecutorType::HCCL)) {
+    return SUCCESS;
+  }
+
+  // HCCL tasks in known-shaped subgraph which resides in a dynamic root graph
+  // still depends on the initialization of the HcclExecutor
+  auto tasks = ge_model->GetModelTaskDefPtr()->task();
+  for (int i = 0; i < tasks.size(); ++i) {
+    const domi::TaskDef &task_def = tasks[i];
+    auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
+    if (task_type == RT_MODEL_TASK_HCCL) {
+      const NodeExecutor *unused = nullptr;
+      GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance()
+                                  .GetOrCreateExecutor(NodeExecutorManager::ExecutorType::HCCL, &unused));
+      return SUCCESS;
+    }
+  }
   return SUCCESS;
 }
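
The helper scans the known-shape model's task definitions and initializes the HCCL executor lazily, so a dynamic root graph whose only HCCL work lives in a known-shaped subgraph still gets a live HcclExecutor. Requesting any executor on demand follows the same two-line pattern taken from the hunk above:

    // On-demand creation; the returned pointer stays owned by the manager.
    const NodeExecutor *executor = nullptr;
    GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance()
                                .GetOrCreateExecutor(NodeExecutorManager::ExecutorType::HCCL, &executor));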




+ 1  - 0  ge/hybrid/model/hybrid_model_builder.h

@@ -57,6 +57,7 @@ class HybridModelBuilder {
   Status ValidateParams();
   Status LoadGraph();
   Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
+  static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model);
   Status LoadTask(NodeItem &node_item);
   Status LoadTasks();
   Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph);


+ 2  - 3  ge/hybrid/model/node_item.cc

@@ -398,12 +398,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
   data_send_.emplace(node_item);
   node_item->data_recv_[this] = anchor_index;
   if (is_root_node_) {
-    node_item->root_data_.emplace(this);
+    node_item->root_data_[anchor_index] = this;
   }
   // If Enter feed Not Merge, take as root Node.
   if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
-    node_item->enter_data_.emplace(this);
-    node_item->enter_inside_.emplace(anchor_index);
+    node_item->enter_data_[anchor_index] = this;
   }
   GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
 }


+ 2  - 3  ge/hybrid/model/node_item.h

@@ -148,15 +148,14 @@ struct NodeItem {
   int64_t frame_index_ = -1;
   int64_t parent_frame_ = -1;
   std::set<const NodeItem *> root_ctrl_;  // Recv ctrl from root node
-  std::set<const NodeItem *> root_data_;  // Recv data from root node
+  std::map<int, const NodeItem *> root_data_;  // Recv data from root node
   std::set<const NodeItem *> enter_ctrl_;  // Recv ctrl from Enter node
-  std::set<const NodeItem *> enter_data_;  // Recv data from Enter node
+  std::map<int, const NodeItem *> enter_data_;  // Recv data from Enter node
   std::set<const NodeItem *> data_send_;  // Send data notify to
   std::map<const NodeItem *, int> data_recv_;  // Recv data notify from
   std::set<const NodeItem *> ctrl_send_;  // Send ctrl notify to
   std::set<const NodeItem *> ctrl_recv_;  // Recv ctrl notify from
   std::vector<std::vector<const NodeItem *>> switch_groups_;  // Send ctrl notify to
-  std::set<int> enter_inside_;  // Enter feed loop inside Node, Not cross Merge.
 
   std::shared_ptr<NodeTask> kernel_task;
   std::unique_ptr<FusedSubgraph> fused_subgraph;
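
Switching root_data_/enter_data_ from std::set to std::map keyed by the input anchor index is what makes the per-index replay in NodeState::ResetContext possible (see the first hunk of this section):

    // item.first is the input anchor index, item.second the producing NodeItem.
    for (auto item : node_item_->root_data_) {
      UpdatePersistTensor(item.first);
    }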


+ 0  - 4  ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc

@@ -64,10 +64,6 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_
   GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id),
                     "[Update][SessionInfoSessionId] failed, session_id:%ld.", session_id);
 
-  bool execute_mode = !aicpu_ext_handle_.IsNeedRefreshIOAddr() && !node_item_->is_dynamic;
-  GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateExecuteMode(execute_mode),
-                    "[Update][ExecuteMode] failed, node:%s.", node_name_.c_str());
-
   // copy task args buf
   GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_),
                     "[Invoke][AllocTensorBuffer]Node[%s] alloc kernel_ext_info buf failed, size=%zu",


+ 2  - 1  ge/hybrid/node_executor/hccl/hccl_node_executor.cc

@@ -24,6 +24,7 @@
#include "graph/types.h" #include "graph/types.h"
#include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/hybrid_execution_context.h"
#include "hccl/hcom.h" #include "hccl/hcom.h"
#include "runtime/event.h"


namespace ge { namespace ge {
namespace { namespace {
@@ -325,7 +326,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do


rtEvent_t evt = nullptr; rtEvent_t evt = nullptr;
if (context.GetExecutionContext()->hccl_stream != nullptr) { if (context.GetExecutionContext()->hccl_stream != nullptr) {
GE_CHK_RT_RET(rtEventCreateWithFlag(&evt, 0x01));
GE_CHK_RT_RET(rtEventCreateWithFlag(&evt, RT_EVENT_WITH_FLAG));
GE_CHK_RT_RET(rtStreamWaitEvent(context.GetExecutionContext()->hccl_stream, evt)); GE_CHK_RT_RET(rtStreamWaitEvent(context.GetExecutionContext()->hccl_stream, evt));
} }
TaskContext *p_ctx = &context; TaskContext *p_ctx = &context;


+ 40  - 40  ge/hybrid/node_executor/node_executor.cc

@@ -58,8 +58,8 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node,
 }
 
 Status NodeExecutorManager::EnsureInitialized() {
-  GE_CHK_STATUS_RET(InitializeExecutors());
   std::lock_guard<std::mutex> lk(mu_);
+  ++ref_count_;
   if (initialized_) {
     return SUCCESS;
   }
@@ -115,17 +115,14 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node
   return it->second;
 }
 
-Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) const {
+Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) {
   auto executor_type = ResolveExecutorType(node);
-  GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type));
   const auto it = executors_.find(executor_type);
   if (it == executors_.end()) {
-    REPORT_INNER_ERROR("E19999", "Failed to get executor by type: %d.", static_cast<int>(executor_type));
-    GELOGE(INTERNAL_ERROR, "[Check][ExecutorType]Failed to get executor by type: %d.",
-           static_cast<int>(executor_type));
-    return INTERNAL_ERROR;
+    return GetOrCreateExecutor(executor_type, executor);
   }
 
+  GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type));
   *executor = it->second.get();
   return SUCCESS;
 }
@@ -178,51 +175,55 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const {
   return OpsKernelBuilderManager::Instance().CalcOpRunningParam(node);
 }
 
-Status NodeExecutorManager::InitializeExecutors() {
+bool NodeExecutorManager::IsExecutorInitialized(NodeExecutorManager::ExecutorType executor_type) {
+  std::lock_guard<std::mutex> lk(mu_);
+  return executors_.find(executor_type) != executors_.end();
+}
+
+Status NodeExecutorManager::GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **out_executor) {
   std::lock_guard<std::mutex> lk(mu_);
-  if (executor_initialized_) {
-    ++ref_count_;
-    GELOGI("Executor is already initialized. add ref count to [%d]", ref_count_);
+  const auto executor_it = executors_.find(executor_type);
+  if (executor_it != executors_.end()) {
+    *out_executor = executor_it->second.get();
     return SUCCESS;
   }
 
-  GELOGI("Start to Initialize NodeExecutors");
-  for (auto &it : builders_) {
-    auto engine_type = it.first;
-    auto build_fn = it.second;
-    GE_CHECK_NOTNULL(build_fn);
-    auto executor = std::unique_ptr<NodeExecutor>(build_fn());
-    if (executor == nullptr) {
-      REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for engine type = %d",
-                        static_cast<int>(engine_type));
-      GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(engine_type));
-      return INTERNAL_ERROR;
-    }
+  GELOGI("Start to Initialize NodeExecutor, type = %d", static_cast<int>(executor_type));
+  auto it = builders_.find(executor_type);
+  if (it == builders_.end()) {
+    REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d",
+                      static_cast<int>(executor_type));
+    GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for executor type = %d", static_cast<int>(executor_type));
+    return INTERNAL_ERROR;
+  }
 
-    GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(engine_type));
-    auto ret = executor->Initialize();
-    if (ret != SUCCESS) {
-      REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(engine_type));
-      GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(engine_type));
-      for (auto &executor_it : executors_) {
-        executor_it.second->Finalize();
-      }
-      executors_.clear();
-      return ret;
-    }
+  auto build_fn = it->second;
+  GE_CHECK_NOTNULL(build_fn);
+  auto executor = std::unique_ptr<NodeExecutor>(build_fn());
+  if (executor == nullptr) {
+    REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d",
+                      static_cast<int>(executor_type));
+    GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(executor_type));
+    return INTERNAL_ERROR;
+  }
 
-    executors_.emplace(engine_type, std::move(executor));
+  GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(executor_type));
+  auto ret = executor->Initialize();
+  if (ret != SUCCESS) {
+    REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(executor_type));
+    GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(executor_type));
+    return ret;
   }
 
-  ++ref_count_;
-  executor_initialized_ = true;
-  GELOGI("Initializing NodeExecutors successfully.");
+  *out_executor = executor.get();
+  executors_.emplace(executor_type, std::move(executor));
+  GELOGI("Initializing NodeExecutor successfully, type = %d", static_cast<int>(executor_type));
   return SUCCESS;
 }
 
 void NodeExecutorManager::FinalizeExecutors() {
   std::lock_guard<std::mutex> lk(mu_);
-  if (!executor_initialized_) {
+  if (ref_count_ <= 0) {
     GELOGD("No need for finalizing for not initialized.");
     return;
   }
@@ -237,7 +238,6 @@ void NodeExecutorManager::FinalizeExecutors() {
     it.second->Finalize();
   }
   executors_.clear();
-  executor_initialized_ = false;
   GELOGD("Done invoking Finalize successfully.");
 }




+ 5  - 4  ge/hybrid/node_executor/node_executor.h

@@ -179,8 +179,6 @@ class NodeExecutorManager {
    */
   Status EnsureInitialized();
 
-  Status InitializeExecutors();
-
   void FinalizeExecutors();
 
   /**
@@ -196,7 +194,7 @@ class NodeExecutorManager {
    * @param executor executor
    * @return SUCCESS on success, error code otherwise
    */
-  Status GetExecutor(Node &node, const NodeExecutor **executor) const;
+  Status GetExecutor(Node &node, const NodeExecutor **executor);
 
   /**
    * Resolve executor type by node
@@ -205,13 +203,16 @@ class NodeExecutorManager {
    */
   ExecutorType ResolveExecutorType(Node &node) const;
 
+  Status GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **executor);
+
+  bool IsExecutorInitialized(ExecutorType executor_type);
+
 private:
   std::map<ExecutorType, std::unique_ptr<NodeExecutor>> executors_;
   std::map<ExecutorType, std::function<NodeExecutor *()>> builders_;
   std::map<std::string, NodeExecutorManager::ExecutorType> engine_mapping_;
   std::mutex mu_;
   bool initialized_ = false;
-  bool executor_initialized_ = false;
   int ref_count_ = 0;
 };
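
Net effect of the node_executor.cc/.h changes: the eager InitializeExecutors pass is gone, EnsureInitialized only bumps ref_count_ and performs one-time manager setup, executors are built lazily by GetOrCreateExecutor, and FinalizeExecutors tears everything down once the reference count drains. A hedged sketch of the resulting lifetime from a caller's point of view (error handling elided):

    auto &manager = NodeExecutorManager::GetInstance();
    GE_CHK_STATUS_RET(manager.EnsureInitialized());       // ++ref_count_, one-time setup
    const NodeExecutor *executor = nullptr;
    GE_CHK_STATUS_RET(manager.GetExecutor(node, &executor));  // creates on first miss
    // ... run tasks through *executor ...
    manager.FinalizeExecutors();  // finalizes and clears executors_ once ref_count_ reaches 0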




+ 7  - 10  ge/hybrid/node_executor/task_context.cc

@@ -52,9 +52,7 @@ void TaskContext::ReleaseWorkspace() {
   }
 }
 
-std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
-                                                 GraphExecutionContext *execution_context,
-                                                 SubgraphContext *subgraph_context) {
+std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) {
   const NodeItem &node_item = *node_state->GetNodeItem();
   GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.",
          node_item.NodeName().c_str(),
@@ -75,7 +73,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
   }
 
   auto task_context = std::unique_ptr<TaskContext>(
-      new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context));
+      new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context));
   if (task_context == nullptr) {
     REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str());
     GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str());
@@ -85,7 +83,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
   task_context->node_item_ = &node_item;
   task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start;
   task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start;
-  task_context->iteration_ = execution_context->iteration;
+  task_context->iteration_ = subgraph_context->execution_context_->iteration;
   return task_context;
 }
 
@@ -460,6 +458,10 @@ Status TaskContext::PropagateOutputs() {
       subgraph_context_->all_inputs_[input_offset].SetName(
           node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx));
     }
+
+    auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item);
+    GE_CHECK_NOTNULL(dst_node_state);
+    dst_node_state->SavePersistTensor(dst_input_idx, *tensor);
   }
 }
 (void)guard;
@@ -489,11 +491,6 @@ void TaskContext::ReleaseInputsAndOutputs() {
 }
 
 void TaskContext::ReleaseInput(int index) {
-  if (node_item_->enter_inside_.count(index) > 0) {
-    GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index);
-    return;
-  }
-
   auto input_tensor = MutableInput(index);
   if (input_tensor != nullptr) {
     input_tensor->Destroy();
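
PropagateOutputs now records every tensor fed into a downstream node via SavePersistTensor, and NodeState::ResetContext (first hunk of this section) replays them with UpdatePersistTensor after rebuilding the task context each iteration; that is what lets ReleaseInput drop the enter_inside_ special case. A sketch of the NodeState side, consistent with the declarations above but with hypothetical bodies (not the verbatim implementation):

    // Sketch: persisted root/Enter inputs, keyed by input anchor index.
    void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
      root_tensor_values_[input_idx] = tensor;
    }

    void NodeState::UpdatePersistTensor(int input_idx) {
      const auto it = root_tensor_values_.find(input_idx);
      if (it == root_tensor_values_.end()) {
        return;  // nothing persisted for this input
      }
      auto tensor = task_context_->MutableInput(input_idx);
      if (tensor != nullptr) {
        *tensor = it->second;  // restore the tensor released in the previous iteration
      }
    }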


+ 1  - 3  ge/hybrid/node_executor/task_context.h

@@ -36,9 +36,7 @@ class SubgraphContext;


 class TaskContext {
  public:
-  static std::unique_ptr<TaskContext> Create(NodeState *node_state,
-                                             GraphExecutionContext *execution_context,
-                                             SubgraphContext *subgraph_context);
+  static std::unique_ptr<TaskContext> Create(NodeState *node_state, SubgraphContext *subgraph_context);
 
   ~TaskContext();




+ 49  - 22  ge/ir_build/ge_ir_build.cc

@@ -263,6 +263,7 @@ class Impl {
     omg_context_.user_attr_index_valid = false;
   };
   ~Impl() { (void)generator_.Finalize(); };
+  graphStatus CheckBuildModeAndBuildStep();
   graphStatus GetSupportedOptions(const std::map<std::string, std::string> &in,
                                   std::map<std::string, std::string> &out);
   graphStatus CheckOptions(const std::map<std::string, std::string> &options);
@@ -451,6 +452,37 @@ graphStatus Impl::UpdateDataOpAttr(const Graph &graph) {
   return GRAPH_SUCCESS;
 }
 
+graphStatus Impl::CheckBuildModeAndBuildStep() {
+  std::string build_mode;
+  auto it = options_.find(BUILD_MODE);
+  if (it != options_.end() && !(it->second.empty())) {
+    if (build_mode_options.find(it->second) == build_mode_options.end()) {
+      REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}),
+                         std::vector<std::string>({BUILD_MODE, it->second, "value is unsupported. Please check!"}));
+      GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode]:%s is unsupported. Please check!", it->second.c_str());
+      return GRAPH_PARAM_INVALID;
+    }
+    build_mode = it->second;
+  }
+  it = options_.find(BUILD_STEP);
+  if (it != options_.end() && !(it->second.empty())) {
+    if (build_step_options.find(it->second) == build_step_options.end()) {
+      REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}),
+                         std::vector<std::string>({BUILD_STEP, it->second, "value is unsupported. Please check!"}));
+      GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildStep]:%s is unsupported. Please check!", it->second.c_str());
+      return GRAPH_PARAM_INVALID;
+    }
+  } else {
+    if (build_mode == BUILD_MODE_TUNING) {
+      REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}),
+                         std::vector<std::string>({BUILD_MODE, it->second, "tuning must specify build step. Please check!"}));
+      GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode] tuning must specify build step. Please check!");
+      return GRAPH_PARAM_INVALID;
+    }
+  }
+  return GRAPH_SUCCESS;
+}
+
 graphStatus Impl::GetSupportedOptions(const std::map<std::string, std::string> &in,
                                       std::map<std::string, std::string> &out) {
   for (auto &ele : in) {
@@ -475,29 +507,12 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options
   }
 
   // Check options build_mode and build_step.
-  std::string build_mode;
-  auto it = options_.find(BUILD_MODE);
-  if (it != options_.end() && !(it->second.empty())) {
-    if (build_mode_options.find(it->second) == build_mode_options.end()) {
-      GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode]:%s is unsupported. Please check!", it->second.c_str());
-      return GRAPH_PARAM_INVALID;
-    }
-    build_mode = it->second;
-  }
-  it = options_.find(BUILD_STEP);
-  if (it != options_.end() && !(it->second.empty())) {
-    if (build_step_options.find(it->second) == build_step_options.end()) {
-      GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildStep]:%s is unsupported. Please check!", it->second.c_str());
-      return GRAPH_PARAM_INVALID;
-    }
-  } else {
-    if (build_mode == BUILD_MODE_TUNING) {
-      GELOGE(GRAPH_PARAM_INVALID, "[Check][BuildMode] tuning must specify build step. Please check!");
-      return GRAPH_PARAM_INVALID;
-    }
+  ret = CheckBuildModeAndBuildStep();
+  if (ret != GRAPH_SUCCESS) {
+    return ret;
   }
   // Check option EXEC_DISABLE_REUSED_MEMORY
-  it = options_.find(ge::ir_option::EXEC_DISABLE_REUSED_MEMORY);
+  auto it = options_.find(ge::ir_option::EXEC_DISABLE_REUSED_MEMORY);
   if (it != options_.end() && (CheckDisableReuseMemoryParamValid(it->second) != GRAPH_SUCCESS)) {
     return GRAPH_PARAM_INVALID;
   }
@@ -505,6 +520,18 @@ graphStatus Impl::CheckOptions(const std::map<std::string, std::string> &options
   if (ge::CheckModifyMixlistParamValid(options_) != GRAPH_SUCCESS) {
     return GRAPH_PARAM_INVALID;
   }
+  // Check option OP_PRECISION_MODE
+  it = options_.find(ge::ir_option::OP_PRECISION_MODE);
+  if (it != options_.end() && !it->second.empty() && !ge::CheckInputPathValid(it->second)) {
+    REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}),
+                       std::vector<std::string>({ge::ir_option::OP_PRECISION_MODE, it->second, "path is not found"}));
+    GELOGE(GRAPH_PARAM_INVALID, "[Check][OP_PRECISION_MODE] %s not found", it->second.c_str());
+    return GRAPH_PARAM_INVALID;
+  }
+  if (it != options_.end()) {
+    GELOGI("Option set successfully, option_key=%s, option_value=%s",
+           ge::ir_option::OP_PRECISION_MODE, it->second.c_str());
+  }
   // Check Input Format
   if (options_.find(kInputFormat) != options_.end()) {
     return CheckInputFormat(options_[kInputFormat]);
@@ -559,8 +586,8 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
   std::string output_type = GetParam(ge::ir_option::OUTPUT_TYPE);
   GE_CHK_BOOL_EXEC(ge::CheckOutputTypeParamValid(output_type) == ge::SUCCESS,
                    return ge::GRAPH_PARAM_INVALID, "[Check][OutputType] failed!");
-  // check insert_op_conf
 
+  // check insert_op_conf
   std::string insert_op_conf = GetParam(ge::ir_option::INSERT_OP_FILE);
   GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(insert_op_conf)) == ge::SUCCESS,
                    return ge::GRAPH_PARAM_INVALID, "[Check][InsertOpConf] failed!");


+ 1  - 1  ge/ir_build/option_utils.cc

@@ -204,7 +204,7 @@ bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map
   if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(input_format.c_str())) {
     GELOGE(ge::PARAM_INVALID,
            "[Check][DynamicImagesizeInputShape] input_format [%s] invalid, can not support now.", input_format.c_str());
-    REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter","value","reason"}),
+    REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter", "value", "reason"}),
                        std::vector<std::string>({"input_format", input_format, "this format is not support"}));
     return false;
   }


+ 33  - 7  ge/offline/main.cc

@@ -106,10 +106,14 @@ DEFINE_string(out_nodes, "",
"Optional; output nodes designated by users." "Optional; output nodes designated by users."
"Format: \"node_name1:0;node_name1:1;node_name2:0\""); "Format: \"node_name1:0;node_name1:1;node_name2:0\"");


DEFINE_string(op_precision_mode, "", "Optional; operator precision mode configuration file path");

DEFINE_string(precision_mode, "force_fp16", DEFINE_string(precision_mode, "force_fp16",
"Optional; precision mode." "Optional; precision mode."
"Support force_fp16, force_fp32, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); "Support force_fp16, force_fp32, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype.");


DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path");

DEFINE_string(keep_dtype, "", DEFINE_string(keep_dtype, "",
"Optional; config file to specify the precision used by the operator during compilation."); "Optional; config file to specify the precision used by the operator during compilation.");


@@ -192,8 +196,11 @@ DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, war


DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0.");


DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug;"
"1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler");
DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug; "
"1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler; "
"3: disable debug, and keep generating kernel file (.o and .json); 4: disable debug, "
"keep generation kernel file (.o and .json) and generate the operator CCE file (.cce) "
"and the UB fusion computing description file (.json)");
DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass,"
"multiple names can be set and separated by ','."); "multiple names can be set and separated by ','.");
DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation"); DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation");
@@ -210,8 +217,6 @@ DEFINE_string(display_model_info, "0", "Optional; display model info");


DEFINE_string(device_id, "0", "Optional; user device id"); DEFINE_string(device_id, "0", "Optional; user device id");


DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path");

class GFlagUtils { class GFlagUtils {
public: public:
/** /**
@@ -298,8 +303,10 @@ class GFlagUtils {
"\"l1_optimize\", \"off_optimize\"\n" "\"l1_optimize\", \"off_optimize\"\n"
" --mdl_bank_path Set the path of the custom repository generated after model tuning.\n" " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n"
"\n[Operator Tuning]\n" "\n[Operator Tuning]\n"
" --op_precision_mode Set the path of operator precision mode configuration file (.ini)\n"
" --precision_mode precision mode, support force_fp16(default), force_fp32, allow_mix_precision, " " --precision_mode precision mode, support force_fp16(default), force_fp32, allow_mix_precision, "
"allow_fp32_to_fp16, must_keep_origin_dtype.\n" "allow_fp32_to_fp16, must_keep_origin_dtype.\n"
" --modify_mixlist Set the path of operator mixed precision configuration file.\n"
" --keep_dtype Retains the precision of certain operators in inference " " --keep_dtype Retains the precision of certain operators in inference "
"scenarios by using a configuration file.\n" "scenarios by using a configuration file.\n"
" --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n"
@@ -315,7 +322,8 @@ class GFlagUtils {
" 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file "
"(.json), and enable the CCE compiler -O0-g.\n" "(.json), and enable the CCE compiler -O0-g.\n"
" 3: Disable debug, and keep generating kernel file (.o and .json)\n" " 3: Disable debug, and keep generating kernel file (.o and .json)\n"
" --modify_mixlist Set the path of operator mixed precision configuration file.\n"
" 4: Disable debug, keep generation kernel file (.o and .json) and generate the "
"operator CCE file (.cce) and the UB fusion computing description file (.json)"
"\n[Debug]\n" "\n[Debug]\n"
" --save_original_model Control whether to output original model. E.g.: true: output original model\n" " --save_original_model Control whether to output original model. E.g.: true: output original model\n"
" --log Generate log with level. Support debug, info, warning, error, null\n" " --log Generate log with level. Support debug, info, warning, error, null\n"
@@ -365,6 +373,14 @@ class GFlagUtils {
FLAGS_op_select_implmode) != ge::SUCCESS, FLAGS_op_select_implmode) != ge::SUCCESS,
ret = ge::FAILED, "[Check][ImplMode]check optypelist_for_implmode and op_select_implmode failed!"); ret = ge::FAILED, "[Check][ImplMode]check optypelist_for_implmode and op_select_implmode failed!");


if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"op_precision_mode", FLAGS_op_precision_mode.c_str(),
"path is not found"});
GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str());
ret = ge::FAILED;
}

if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"modify_mixlist", FLAGS_modify_mixlist.c_str(), {"modify_mixlist", FLAGS_modify_mixlist.c_str(),
@@ -847,6 +863,7 @@ domi::Status GenerateInfershapeJson() {
ge::Graph graph; ge::Graph graph;
std::map<string, string> atc_params; std::map<string, string> atc_params;
atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format)); atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report));
ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework,
"", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false);
if (ret != ge::SUCCESS) { if (ret != ge::SUCCESS) {
@@ -953,8 +970,7 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
ge::Model load_model = ge::Model("loadmodel", "version2"); ge::Model load_model = ge::Model("loadmodel", "version2");
auto ret1 = load_model.LoadFromFile(FLAGS_model); auto ret1 = load_model.LoadFromFile(FLAGS_model);
if (ret1 != ge::GRAPH_SUCCESS) { if (ret1 != ge::GRAPH_SUCCESS) {
REPORT_INPUT_ERROR("E10041", std::vector<std::string>({"file"}), std::vector<std::string>({FLAGS_model}));
REPORT_CALL_ERROR("E19999", "load from model file:%s failed", FLAGS_model.c_str());
REPORT_INPUT_ERROR("E10041", std::vector<std::string>({"parameter"}), std::vector<std::string>({FLAGS_model}));
DOMI_LOGE("Load model from %s failed, please check model file or " DOMI_LOGE("Load model from %s failed, please check model file or "
"input parameter[--framework] is correct", FLAGS_model.c_str()); "input parameter[--framework] is correct", FLAGS_model.c_str());
(void)ge_generator.Finalize(); (void)ge_generator.Finalize();
@@ -1050,6 +1066,7 @@ static void SetEnvForSingleOp(std::map<string, string> &options) {
options.emplace(ge::RUN_FLAG, flag_off); options.emplace(ge::RUN_FLAG, flag_off);
options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off);
options.emplace(ge::SINGLE_OP_FLAG, flag_on); options.emplace(ge::SINGLE_OP_FLAG, flag_on);
options.emplace(ge::OP_PRECISION_MODE, FLAGS_op_precision_mode);
options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode);
options.emplace(ge::SOC_VERSION, FLAGS_soc_version); options.emplace(ge::SOC_VERSION, FLAGS_soc_version);
options.emplace(ge::CORE_TYPE, FLAGS_core_type); options.emplace(ge::CORE_TYPE, FLAGS_core_type);
@@ -1077,6 +1094,14 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) {
ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS,
return ge::FAILED, "[Check][ImplmodeParam] fail for input optypelist_for_implmode and op_select_implmode."); return ge::FAILED, "[Check][ImplmodeParam] fail for input optypelist_for_implmode and op_select_implmode.");


if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"op_precision_mode", FLAGS_op_precision_mode.c_str(),
"path is not found"});
GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str());
return ge::FAILED;
}

if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) { if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"modify_mixlist", FLAGS_modify_mixlist.c_str(), {"modify_mixlist", FLAGS_modify_mixlist.c_str(),
@@ -1160,6 +1185,7 @@ domi::Status GenerateOmModel() {
options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf));
options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes));
options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf));
options.insert(std::pair<string, string>(string(ge::OP_PRECISION_MODE), FLAGS_op_precision_mode));
options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode)); options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode));
options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id)); options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id));
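
The new --op_precision_mode flag is validated as a readable path and then threaded through both build paths. The three touch points, collected from the hunks above:

    // Flag definition:
    DEFINE_string(op_precision_mode, "", "Optional; operator precision mode configuration file path");
    // Single-op path (SetEnvForSingleOp):
    options.emplace(ge::OP_PRECISION_MODE, FLAGS_op_precision_mode);
    // Whole-model path (GenerateOmModel):
    options.insert(std::pair<string, string>(string(ge::OP_PRECISION_MODE), FLAGS_op_precision_mode));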




+ 0  - 193  ge/offline/proto/ge_ir.proto

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+ 0  - 140  ge/offline/proto/insert_op.proto

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;

// related_input_rank is required; integer; valid range >= 0 and <= the number of input Data operators; default is 0.
// It identifies which model input receives AIPP processing: e.g. if the model has two inputs and AIPP is needed on the second one, set related_input_rank to 1.
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx is optional; integer; valid range >= 0.
// It applies different AIPP processing to different outputs of the Data operator; if unset, AIPP is applied by default to all output edges of the model input selected by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;

// [Begin] dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;

// Whether rotation is supported. Off by default; enabling rotation costs extra memory and performance.
bool support_rotation = 5;

// [End] dynamic AIPP parameters


// [Begin] static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] static AIPP parameters

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0;      // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}

MultiShapeMode mode = 1;       // operator mode
uint32 related_input_rank = 2; // which input the inserted operator is attached to


repeated uint32 batch_list = 11; // batch_list values; the number of entries must be between 2 and 8
}

+ 0  - 396  ge/offline/proto/om.proto

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4; // optional, default = 1e-5
bool use_global_stats = 5; // optional, default = true (testing mode)
float moving_average_fraction = 6; // optional, default = 0.999

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // NPU is used by default.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}
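As a usage sketch, attributes ride on the attr maps above; the following hedged example (assuming an om.pb.h header generated from this file; the attribute names are invented for illustration) shows the generated C++ map and oneof accessors:

#include <iostream>
#include "om.pb.h"  // assumed protoc output for this file

int main() {
  domi::OpDef op;
  op.set_name("conv1");
  op.set_type("Convolution");

  // AttrDef is a oneof-backed tagged union keyed by attribute name.
  (*op.mutable_attr())["fused"].set_b(true);
  (*op.mutable_attr())["pad_mode"].set_s("SAME");
  domi::AttrDef::ListValue* strides = (*op.mutable_attr())["strides"].mutable_list();
  strides->add_i(1);
  strides->add_i(2);

  std::cout << op.attr().at("pad_mode").s() << "\n";  // prints SAME
  return 0;
}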


+ 0
- 179
ge/offline/proto/task.proto

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identifies whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}
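A hedged sketch of how a ModelTaskDef carrying one async-memcpy task could be assembled (assuming a task.pb.h header generated from this file; the numeric task type codes are defined by the runtime, so 0 below is only a placeholder, as are the addresses):

#include <iostream>
#include "task.pb.h"  // assumed protoc output for this file

int main() {
  domi::ModelTaskDef model_task;
  model_task.set_version("1.0");
  model_task.set_stream_num(1);

  domi::TaskDef* task = model_task.add_task();
  task->set_stream_id(0);
  task->set_type(0);  // placeholder; real type codes come from the runtime

  domi::MemcpyAsyncDef* cpy = task->mutable_memcpy_async();
  cpy->set_src(0x1000);    // illustrative device addresses
  cpy->set_dst(0x2000);
  cpy->set_dst_max(4096);  // capacity of the destination buffer
  cpy->set_count(4096);    // bytes to copy
  cpy->set_op_index(0);

  std::cout << "tasks: " << model_task.task_size() << "\n";
  return 0;
}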

+ 0
- 1829
ge/proto/caffe/caffe.proto
File diff suppressed because it is too large


+ 0
- 113
ge/proto/dump_task.proto

@@ -1,113 +0,0 @@
syntax = "proto3";
package toolkit.dump;

enum OutputDataType {
DT_UNDEFINED = 0;
DT_FLOAT = 1;
DT_FLOAT16 = 2;
DT_INT8 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_UINT16 = 6;
DT_INT32 = 7;
DT_INT64 = 8;
DT_UINT32 = 9;
DT_UINT64 = 10;
DT_BOOL = 11;
DT_DOUBLE = 12;
DT_STRING = 13;
DT_DUAL_SUB_INT8 = 14;
DT_DUAL_SUB_UINT8 = 15;
DT_COMPLEX64 = 16;
DT_COMPLEX128 = 17;
DT_QINT8 = 18;
DT_QINT16 = 19;
DT_QINT32 = 20;
DT_QUINT8 = 21;
DT_QUINT16 = 22;
DT_RESOURCE = 23;
DT_STRING_REF = 24;
DT_DUAL = 25;
DT_VARIANT = 26;
}

enum OutputFormat {
FORMAT_NCHW = 0;
FORMAT_NHWC = 1;
FORMAT_ND = 2;
FORMAT_NC1HWC0 = 3;
FORMAT_FRACTAL_Z = 4;
FORMAT_NC1C0HWPAD = 5;
FORMAT_NHWC1C0 = 6;
FORMAT_FSR_NCHW = 7;
FORMAT_FRACTAL_DECONV = 8;
FORMAT_C1HWNC0 = 9;
FORMAT_FRACTAL_DECONV_TRANSPOSE = 10;
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11;
FORMAT_NC1HWC0_C04 = 12;
FORMAT_FRACTAL_Z_C04 = 13;
FORMAT_CHWN = 14;
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15;
FORMAT_HWCN = 16;
FORMAT_NC1KHKWHWC0 = 17;
FORMAT_BN_WEIGHT = 18;
FORMAT_FILTER_HWCK = 19;
FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20;
FORMAT_HASHTABLE_LOOKUP_KEYS = 21;
FORMAT_HASHTABLE_LOOKUP_VALUE = 22;
FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23;
FORMAT_HASHTABLE_LOOKUP_HITS=24;
FORMAT_C1HWNCoC0 = 25;
FORMAT_MD = 26;
FORMAT_NDHWC = 27;
FORMAT_FRACTAL_ZZ = 28;
FORMAT_FRACTAL_NZ = 29;
FORMAT_RESERVED = 30;
}

message OriginalOp {
string name = 1;
uint32 output_index = 2;
OutputDataType data_type = 3;
OutputFormat format = 4;
}

message Shape {
repeated uint64 dim = 1;
}

message OpOutput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
OriginalOp original_op = 4; // the original op corresponding to the output
bytes data = 5;
uint64 size = 6;
}

message OpInput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
bytes data = 4;
uint64 size = 5;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
bytes data = 2;
uint64 size = 3;
}

message DumpData{
string version = 1;
uint64 dump_time = 2;
repeated OpOutput output = 3;
repeated OpInput input = 4;
repeated OpBuffer buffer = 5;
string op_name = 6;
}
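A minimal reader sketch for this format (assuming a dump_task.pb.h header generated from this file and a raw serialized DumpData blob on disk; any file layout beyond plain protobuf serialization is an assumption here):

#include <fstream>
#include <iostream>
#include "dump_task.pb.h"  // assumed protoc output for this file

int main(int argc, char** argv) {
  if (argc < 2) return 1;
  std::ifstream in(argv[1], std::ios::binary);

  toolkit::dump::DumpData dump;
  if (!dump.ParseFromIstream(&in)) {
    std::cerr << "not a valid DumpData blob\n";
    return 1;
  }

  std::cout << dump.op_name() << " @ " << dump.dump_time() << "\n";
  for (const auto& out : dump.output()) {
    std::cout << "  output: " << toolkit::dump::OutputDataType_Name(out.data_type())
              << " / " << toolkit::dump::OutputFormat_Name(out.format())
              << ", " << out.size() << " bytes\n";
  }
  return 0;
}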

+ 0
- 21
ge/proto/fusion_model.proto

@@ -1,21 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

import "om.proto";

package domi;

message FusionModelDef {
string version = 1;
repeated OpDef fusion_op = 2;
}

+ 0
- 37
ge/proto/fwk_adapter.proto

@@ -1,37 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package aicpu.FWKAdapter;
option cc_enable_arenas = true;


// Defines a struct for input and output.
message TensorDataInfo {

// value DataType
uint32 dtype = 1;

// shape dim
repeated int64 dim = 2;

// data pointer address
int64 data_addr = 3;
}

message KernelRunParam {
// input
repeated TensorDataInfo input = 1;
// output
repeated TensorDataInfo output = 2;
}
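A short population sketch (assuming a fwk_adapter.pb.h header generated from this file; the dtype code and buffer address are illustrative placeholders, since the dtype encoding is defined elsewhere):

#include <iostream>
#include "fwk_adapter.pb.h"  // assumed protoc output for this file

int main() {
  aicpu::FWKAdapter::KernelRunParam param;

  // Describe one 2x3 input tensor.
  aicpu::FWKAdapter::TensorDataInfo* in = param.add_input();
  in->set_dtype(1);                   // placeholder dtype code
  in->add_dim(2);
  in->add_dim(3);
  in->set_data_addr(0x7f0000000000);  // placeholder device address

  std::cout << "inputs: " << param.input_size() << "\n";
  return 0;
}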


+ 0
- 88
ge/proto/ge_api.proto

@@ -1,88 +0,0 @@
syntax = "proto3";
package ge.api_pb;

import "ge_ir.proto";

// GE initialize
message GEInitialize {
map<string, string> options = 1;
};

// initialize response
message GEInitializeResponse {
uint32 status = 1;
uint32 clientId = 2;
};

// GE finalize
message GEFinalize {
bool final = 1;
uint32 clientId = 2;
};

message GEFinalizeResponse {
uint32 status = 1;
};

// GE Session
message CreateSession{
map<string, string> options = 1;
};

message CreateSessionResponse {
uint32 status = 1;
uint64 sessionId = 2;
};

// GE AddGraph
// model serialization: serialized GraphDef
message SessionAddGraph{
uint32 graphId = 1;
uint64 sessionId = 2;
ge.proto.GraphDef graph = 3;
};

message SessionAddGraphResponse {
uint32 status = 1;
};

// GE SessionRemoveGraph
message SessionRemoveGraph{
uint32 graphId = 1;
uint64 sessionId = 2;
};

message SessionRemoveGraphResponse {
uint32 status = 1;
};

message SessionRunGraph{
uint32 graphId = 1;
uint64 sessionId = 2;
repeated ge.proto.TensorDef tensor = 3;
};

message SessionBuildGraph{
uint32 graphId = 1;
uint64 sessionId = 2;
repeated ge.proto.TensorDef tensor = 3;
string savePath = 4;
};

message SessionRunGraphResponse {
uint32 status = 1;
repeated ge.proto.TensorDef tensor = 2;
};

message SessionBuildGraphResponse {
uint32 status = 1;
};

message DestroySession{
bool final = 1;
uint64 sessionId = 2;
};

message DestroySessionResponse {
uint32 status = 1;
};
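A request/response round-trip sketch over these messages (assuming a ge_api.pb.h header generated from this file; the transport and the option key ge.exec.deviceId are assumptions, and the reply is faked locally for illustration):

#include <iostream>
#include <string>
#include "ge_api.pb.h"  // assumed protoc output for this file

int main() {
  // Build a CreateSession request.
  ge::api_pb::CreateSession req;
  (*req.mutable_options())["ge.exec.deviceId"] = "0";  // assumed option key
  std::string payload = req.SerializeAsString();       // bytes to send to the service
  std::cout << "request payload: " << payload.size() << " bytes\n";

  // In lieu of a real transport, fake the service reply locally.
  ge::api_pb::CreateSessionResponse fake;
  fake.set_status(0);
  fake.set_sessionid(42);
  std::string reply = fake.SerializeAsString();

  ge::api_pb::CreateSessionResponse resp;
  if (resp.ParseFromString(reply) && resp.status() == 0) {
    std::cout << "session id: " << resp.sessionid() << "\n";
  }
  return 0;
}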

Some files were not shown because too many files changed in this diff
