Compare commits

...

148 Commits

Author SHA1 Message Date
  i-robot 98bb610710 !2067 get hbm and ddr size adaptively 3 years ago
  lichun 5f9d382856 get hbm and ddr size adaptively 3 years ago
  i-robot cc2bb8bd75 !2066 use rtMemGetInfo instead of rtMemGetInfoEx 3 years ago
  lichun 6861b40dda use rtMemGetInfo instead of rtMemGetInfoEx 3 years ago
  i-robot daa46ed304 !2062 remove SetDataType 3 years ago
  lichun b077bd4983 remove SetDataType 3 years ago
  i-robot f2655248fe !2051 bugfix for InferFormatForSingleOp 3 years ago
  y00500818 8f18d82eb8 bugfix for InferFormatForSingleOp 3 years ago
  i-robot d807b2ca32 !2025 assign graph memory max size and variable memory max size adaptively 3 years ago
  i-robot 23cbc29fd3 !2011 fix coverity 3 years ago
  lichun 4d1864634d assign graph memory max size and variable memory max size adaptively 3 years ago
  i-robot 9f0ee55282 !1965 fixed coverity warning 3 years ago
  i-robot 00273f1f42 !1966 fix coverity 3 years ago
  李磊 6cb69109b6 fixed coverity warning 3 years ago
  i-robot 057e9529d8 !1964 Adaptation rectification of op_tiling. 3 years ago
  zhaozhixuan 2d9b53f391 Adaptation rectification of op_tiling. 3 years ago
  zhaozhixuan c22a2d5781 Adaptation rectification of op_tiling. 3 years ago
  i-robot f0468551c7 !1957 fix coverity 3 years ago
  i-robot 8c854d67ef !1951 op_select_implmode support high_precision_for_all and high_performance_for_all 3 years ago
  lianghuikang 69ef9eb70c op_select_implmode support high_precision_for_all and high_performance_for_all 3 years ago
  i-robot 641a5f9cd3 !1942 Fix cross merge for switch 3 years ago
  i-robot dfa12939c0 !1944 add_global_step_info_for_known_subgraph_in_dynamic_model and generate_remained_om_when_some_single_op_case_failed_in_atc_multicase_scene 3 years ago
  lichun cafa837a4f add_global_step_info_for_known_subgraph_in_dynamic_model and generate_remained_om_when_some_single_op_case_failed_in_atc_multicase_scene 3 years ago
  zhangxiaokun 6b0dacc00f Fix lambda expression name 4 years ago
  zhangxiaokun 36a90e0313 Fix cross merge for switch 4 years ago
  i-robot cc7175217c !1928 cherry-pick fix for dynamic shape V1 4 years ago
  i-robot 230010b770 !1933 parallel group 4 years ago
  陈华 3929578dee fix parallel group 4 years ago
  i-robot d0f986ea46 !1903 fix sc 4 years ago
  i-robot 0684bd48cf !1927 fix safe 4 years ago
  i-robot 2444f46b8d !1926 set size for dynamic input 4 years ago
  i-robot ca24b76141 !1922 FindLastBpFromBpNode c78 4 years ago
  zhangxiaokun 2785670745 fix printf format 4 years ago
  lianghao 2d9e3da649 IsEnterFeedNode 4 years ago
  zhangxiaokun ac1f4eb1c2 DSP: Switch -> TransData -> Cast -> Exit 4 years ago
  zhangxiaokun 61c203619c Remove UT dump env 4 years ago
  zhangxiaokun b23075b62f UT for control flow group 4 years ago
  zhangxiaokun dbc989a4c3 Clear UpdatePersistTensor Warning for first run 4 years ago
  zhangxiaokun 796513222a UpdatePersistTensor from ExecutionEngine 4 years ago
  zhangxiaokun 002583e4ef Fix Set Control flow group for -1 4 years ago
  zhangxiaokun 65cafdd034 Replace MemcpyAsyncNodeTask 4 years ago
  zhangxiaokun 64d312ab12 UT for LaunchKernelCustAicpuSo 4 years ago
  zhangxiaokun ded54e73af Fix Guard for variable release 4 years ago
  zhangxiaokun 7b1331770a Fix multi control from one node 4 years ago
  zhangxiaokun c440980918 Fix BuildPartitionFrame failed 4 years ago
  wangzhengjun 5da987eb3a set size for dynamic input 4 years ago
  lianghao 8c55572e12 FindLastBpFromBpNode c78 4 years ago
  i-robot 656cd3f3d5 !1873 add copy graph 4 years ago
  wangxiaotian22 e5457d5949 fix sc + 4 years ago
  wangxiaotian22 062757756c fix sc 4 years ago
  i-robot 333bbb700a !1885 fix mem leak 4 years ago
  i-robot 9088d5f696 !1871 FillKernel c78 4 years ago
  i-robot 96b0db9bd5 !1877 step info 4 years ago
  lianghao 1e0a3c0bca FillKernel c78 4 years ago
  wuweikang bb2c55fac8 add copy graph 4 years ago
  wuweikang e51ffe2f54 fix mem leak 4 years ago
  mindspore-ci-bot e0ee4afa7d !1878 update submodule 4 years ago
  zhupuxu be6b3a176f step info 4 years ago
  mindspore-ci-bot ee744e6bae !1872 Create NodeExecute on-demand 4 years ago
  mindspore-ci-bot 17370647b4 !1870 update protobuf to 3.13.0 4 years ago
  李磊 0cf2a59463 update version of protobuf to v3.13.0 4 years ago
  chuxing e2cad9c2ec fixed ad3e707 from https://gitee.com/mindspore/graphengine/pulls/1821 4 years ago
  mindspore-ci-bot dd2ba23718 !1867 add atc_params: check_report for ConvertModelToJson 4 years ago
  mindspore-ci-bot 4d9fe5505a !1862 add op_precision_mode option and support op_debug_level = 4 4 years ago
  mindspore-ci-bot 5c51e07d61 !1864 Fix mem leak and recursive depth protection. 4 years ago
  lichun 6fd9337505 add atc_params: check_report for ConvertModelToJson 4 years ago
  zhaozhixuan 2a0a6eaf2c Fix bug. 4 years ago
  zhaozhixuan cd9869c99d Fix bug. 4 years ago
  mindspore-ci-bot 122a05cda3 !1863 fix opt info 4 years ago
  zhaozhixuan a562b4b6be Fix mem leak and recursive depth protection. 4 years ago
  zhaozhixuan a431199716 Fix mem leak and recursive depth protection. 4 years ago
  陈华 d3bda362d9 fix opt info 4 years ago
  mindspore-ci-bot 83dffb39f6 !1859 skip control flow op when replace node with empty tensor 4 years ago
  mindspore-ci-bot 26fc202953 !1848 add op_precision_mode option and support op_debug_level = 4 4 years ago
  wangzhengjun 49c2eadb7c skip control flow op when replace node with empty tensor 4 years ago
  mindspore-ci-bot 3aec2fa1c8 !1850 fix cmetric 4 years ago
  mindspore-ci-bot f410cde8f2 !1835 move opt to ge_compile 4 years ago
  陈华 310959e5d9 move to ge_compile 4 years ago
  mindspore-ci-bot 802f671006 !1826 opt info 4 years ago
  陈华 bba62ec5f3 opt_info 4 years ago
  wqtshg 0854139134 update submodule 4 years ago
  mindspore-ci-bot 9084063b48 !1818 回退 'Pull Request !1806 : add header targets for link' 4 years ago
  mindspore-ci-bot 287d6db86f !1817 update submodule file 4 years ago
  王涛 a70ab7c91a 回退 'Pull Request !1806 : add header targets for link' 4 years ago
  王涛 65ba89fb34 update .gitmodules. 4 years ago
  计晨 3e510ffc8d !1815 回退 'Pull Request !1784 : Create NodeExecute on-demand' 4 years ago
  储星 c2a1076a87 回退 'Pull Request !1784 : Create NodeExecute on-demand' 4 years ago
  mindspore-ci-bot 3e57016ffa !1806 add header targets for link 4 years ago
  mindspore-ci-bot 076e4222ed !1807 回退 'Pull Request ls : Adaptation rectification of op_tiling.' 4 years ago
  zhaozhixuan 676ce23b55 回退 'Pull Request ls : Adaptation rectification of op_tiling.' 4 years ago
  wangkai 6f130e2290 add link header targets 4 years ago
  mindspore-ci-bot 7610fa5393 !1712 Adaptation rectification of op_tiling. 4 years ago
  mindspore-ci-bot 1e9558e8dd !1784 Create NodeExecute on-demand 4 years ago
  zhaozhixuan fd51637c46 Fix zip bug. 4 years ago
  zhaozhixuan bd1beee90c Fix zip bug. 4 years ago
  mindspore-ci-bot caba8dbf6b !1796 bugfix for restore context 4 years ago
  mindspore-ci-bot b59fbaca6b !1790 fix sc 4 years ago
  zhaozhixuan 9476853d22 Adaptation rectification of op_tiling. 4 years ago
  mindspore-ci-bot b0a017d406 !1785 Optimize performance of single_op executor. 4 years ago
  mindspore-ci-bot 535b674d1a !1791 modify dump config 4 years ago
  mindspore-ci-bot c173f92bcc !1797 Remove reduplicated useless proto 4 years ago
  zhangxiaokun 1bed26c72e Remove reduplicated useless proto 4 years ago
  mindspore-ci-bot 82d489a5e6 !1794 update submodule metadef 4 years ago
  y00500818 246d7e4fd8 bugfix for restore context 4 years ago
  wq160 5bcb04dfb7 update submodule 4 years ago
  zhaozhixuan 4bc0f6f2af Fix bug. 4 years ago
  zhaozhixuan b17eafe3db Fix bug. 4 years ago
  zhaozhixuan 23c8a0d581 Fix ut. 4 years ago
  mindspore-ci-bot c4610153e6 !1787 ge code for 1981 4 years ago
  zhou_chao1993 6927a8eef3 modif dump config 4 years ago
  zhou_lili 116167dc88 ge code for 1981 4 years ago
  zhaozhixuan 58086ab187 Release mem. 4 years ago
  zhaozhixuan 69da59b6b7 Fix ut. 4 years ago
  zhaozhixuan d0e83c26a7 Merge branch 'my_dev4' of https://gitee.com/zhao_zhixuan/graphengine into my_dev4 4 years ago
  zhaozhixuan d1eba02e1e Fix ut. 4 years ago
  xchu42 b64048a39f Init NodeExecutor on demand 4 years ago
  zhaozhixuan 24eedfa3b4 Fix ut. 4 years ago
  zhaozhixuan 7ce31b2e0e Fix ut. 4 years ago
  zhaozhixuan 0c2d07eb72 Fix ut. 4 years ago
  zhaozhixuan e765135b86 Merge branch 'my_dev4' of https://gitee.com/zhao_zhixuan/graphengine into my_dev4 4 years ago
  zhaozhixuan 492d36b237 Fix ut. 4 years ago
  zhaozhixuan 3a54dc6dd8 Merge https://gitee.com/mindspore/graphengine into my_dev4 4 years ago
  zhaozhixuan ab7334ed78 Release context in execute end. 4 years ago
  mindspore-ci-bot f7370bc074 !1788 enable optimization 4 years ago
  mindspore-ci-bot 78bbbda571 !1783 Fix dynamic shape partition 4 years ago
  zhaozhixuan 181cd5891b Release context in execute end. 4 years ago
  wangzhengjun 367774c5b0 enable optimization 4 years ago
  zhangxiaokun f578e8fff4 Fix NodeState for UT 4 years ago
  zhangxiaokun ab65075326 Add Init to NodeState 4 years ago
  zhangxiaokun 8852766766 Fix hccl_node_executor_unittest 4 years ago
  zhangxiaokun e85bbe2181 Fix dynamic shape partition 4 years ago
  mindspore-ci-bot 5393b67679 !1774 ut fix 4 years ago
  zhaozhixuan b35412f5ea Add ut. 4 years ago
  zhaozhixuan 1ab9ae32dc Add ut. 4 years ago
  zhaozhixuan 13c98395e2 Add ut. 4 years ago
  zhaozhixuan 4c3c819129 Optimize performance of single_op executor. 4 years ago
  mindspore-ci-bot 83421b6c16 !1775 run_flag 4 years ago
  lianghao 1847efba53 run_flag 4 years ago
  wjm 68aabd7872 fix 4 years ago
  mindspore-ci-bot 38bea2e8a6 !1780 set step profiling 4 years ago
  wjm 4e49c16585 fix coverity 4 years ago
  zhupuxu 0be66d0dca set step 4 years ago
  zhengyuanhua d8ba1fb2c0 remove graph ut form ge 4 years ago
  wjm 346dd7d6ff fix 4 years ago
  wjm eee4d7b492 fix safe 4 years ago
  wjm 9e33784220 Merge branch 'master' of gitee.com:jiming6/graphengine 4 years ago
  wjm 2abf8be621 fix sc 4 years ago
  wjm f826880f75 fix sc 4 years ago
100 changed files with 1509 additions and 4442 deletions
Split View
  1. +2
    -2
      .gitmodules
  2. +1
    -0
      CMakeLists.txt
  3. +4
    -4
      build.sh
  4. +6
    -6
      cmake/external_libs/protobuf_shared.cmake
  5. +5
    -7
      cmake/external_libs/protobuf_static.cmake
  6. +5
    -7
      cmake/external_libs/protoc.cmake
  7. +10
    -0
      ge/CMakeLists.txt
  8. +0
    -1
      ge/client/proto/ge_api.proto
  9. +0
    -193
      ge/client/proto/ge_ir.proto
  10. +0
    -140
      ge/client/proto/insert_op.proto
  11. +0
    -396
      ge/client/proto/om.proto
  12. +0
    -179
      ge/client/proto/task.proto
  13. +1
    -0
      ge/common/CMakeLists.txt
  14. +2
    -2
      ge/common/dump/dump_manager.cc
  15. +1
    -0
      ge/common/dump/exception_dumper.cc
  16. +4
    -0
      ge/common/formats/format_transfers/format_transfer_fractal_nz.cc
  17. +4
    -0
      ge/common/formats/format_transfers/format_transfer_fractal_zz.cc
  18. +10
    -2
      ge/common/ge/tbe_plugin_manager.cc
  19. +2
    -1
      ge/common/ge/tbe_plugin_manager.h
  20. +7
    -0
      ge/common/helper/model_cache_helper.cc
  21. +38
    -0
      ge/common/profiling/ge_profiling.cc
  22. +0
    -193
      ge/common/proto/ge_ir.proto
  23. +0
    -140
      ge/common/proto/insert_op.proto
  24. +0
    -396
      ge/common/proto/om.proto
  25. +0
    -75
      ge/common/proto/op_mapping.proto
  26. +0
    -179
      ge/common/proto/task.proto
  27. +0
    -70
      ge/common/proto/tensorflow/attr_value.proto
  28. +0
    -108
      ge/common/proto/tensorflow/function.proto
  29. +0
    -64
      ge/common/proto/tensorflow/graph.proto
  30. +0
    -22
      ge/common/proto/tensorflow/graph_library.proto
  31. +0
    -71
      ge/common/proto/tensorflow/node_def.proto
  32. +0
    -172
      ge/common/proto/tensorflow/op_def.proto
  33. +0
    -37
      ge/common/proto/tensorflow/resource_handle.proto
  34. +0
    -102
      ge/common/proto/tensorflow/tensor.proto
  35. +0
    -53
      ge/common/proto/tensorflow/tensor_shape.proto
  36. +0
    -82
      ge/common/proto/tensorflow/types.proto
  37. +0
    -39
      ge/common/proto/tensorflow/versions.proto
  38. +1
    -0
      ge/executor/CMakeLists.txt
  39. +0
    -113
      ge/executor/proto/dump_task.proto
  40. +0
    -193
      ge/executor/proto/ge_ir.proto
  41. +0
    -140
      ge/executor/proto/insert_op.proto
  42. +0
    -396
      ge/executor/proto/om.proto
  43. +0
    -75
      ge/executor/proto/op_mapping.proto
  44. +0
    -179
      ge/executor/proto/task.proto
  45. +4
    -7
      ge/ge_local_engine/engine/host_cpu_engine.cc
  46. +0
    -179
      ge/ge_local_engine/proto/task.proto
  47. +58
    -0
      ge/ge_opt_info/ge_opt_info.cc
  48. +32
    -0
      ge/ge_opt_info/ge_opt_info.h
  49. +1
    -1
      ge/ge_runtime/task/label_goto_task.cc
  50. +13
    -8
      ge/generator/ge_generator.cc
  51. +5
    -0
      ge/graph/build/label_allocator.cc
  52. +5
    -0
      ge/graph/build/logical_stream_allocator.cc
  53. +1
    -1
      ge/graph/build/model_builder.cc
  54. +10
    -1
      ge/graph/build/stream_allocator.cc
  55. +47
    -29
      ge/graph/build/task_generator.cc
  56. +2
    -1
      ge/graph/build/task_generator.h
  57. +0
    -15
      ge/graph/common/omg_util.cc
  58. +0
    -9
      ge/graph/common/omg_util.h
  59. +175
    -58
      ge/graph/load/model_manager/davinci_model.cc
  60. +7
    -7
      ge/graph/load/model_manager/davinci_model.h
  61. +17
    -10
      ge/graph/load/model_manager/model_manager.cc
  62. +393
    -0
      ge/graph/load/model_manager/task_info/ffts_task_info.cc
  63. +66
    -0
      ge/graph/load/model_manager/task_info/ffts_task_info.h
  64. +1
    -1
      ge/graph/load/model_manager/task_info/hccl_task_info.cc
  65. +2
    -0
      ge/graph/load/model_manager/task_info/kernel_task_info.cc
  66. +1
    -1
      ge/graph/load/model_manager/task_info/memcpy_async_task_info.h
  67. +4
    -2
      ge/graph/load/model_manager/zero_copy_offset.cc
  68. +20
    -26
      ge/graph/manager/graph_manager.cc
  69. +1
    -1
      ge/graph/manager/graph_manager.h
  70. +49
    -11
      ge/graph/manager/graph_var_manager.cc
  71. +3
    -0
      ge/graph/manager/graph_var_manager.h
  72. +0
    -2
      ge/graph/optimize/graph_optimize.cc
  73. +29
    -7
      ge/graph/partition/dynamic_shape_partition.cc
  74. +4
    -2
      ge/graph/partition/dynamic_shape_partition.h
  75. +14
    -7
      ge/graph/partition/graph_partition.cc
  76. +107
    -70
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  77. +11
    -0
      ge/graph/passes/mark_force_unknown_for_cond_pass.h
  78. +6
    -0
      ge/graph/passes/mark_graph_unknown_status_pass.cc
  79. +2
    -3
      ge/graph/passes/merge_to_stream_merge_pass.cc
  80. +23
    -6
      ge/graph/passes/next_iteration_pass.cc
  81. +39
    -19
      ge/graph/passes/parallel_group_pass.cc
  82. +1
    -0
      ge/graph/passes/parallel_group_pass.h
  83. +20
    -0
      ge/graph/passes/replace_with_empty_const_pass.cc
  84. +9
    -8
      ge/graph/passes/switch_to_stream_switch_pass.cc
  85. +68
    -55
      ge/graph/preprocess/graph_preprocess.cc
  86. +2
    -1
      ge/graph/preprocess/graph_preprocess.h
  87. +1
    -0
      ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
  88. +1
    -1
      ge/graph/preprocess/multi_batch_copy_graph.cc
  89. +8
    -0
      ge/host_kernels/fill_kernel.cc
  90. +5
    -3
      ge/hybrid/executor/hybrid_model_async_executor.cc
  91. +8
    -7
      ge/hybrid/executor/hybrid_model_executor.cc
  92. +2
    -1
      ge/hybrid/executor/hybrid_model_executor.h
  93. +2
    -0
      ge/hybrid/executor/hybrid_model_pipeline_executor.cc
  94. +79
    -8
      ge/hybrid/executor/node_state.cc
  95. +12
    -8
      ge/hybrid/executor/node_state.h
  96. +19
    -8
      ge/hybrid/executor/subgraph_context.cc
  97. +3
    -2
      ge/hybrid/executor/subgraph_context.h
  98. +10
    -15
      ge/hybrid/executor/subgraph_executor.cc
  99. +2
    -1
      ge/hybrid/executor/subgraph_executor.h
  100. +2
    -1
      ge/hybrid/executor/worker/execution_engine.cc

+ 2
- 2
.gitmodules View File

@@ -1,8 +1,8 @@
[submodule "parser"]
path = parser
url = https://gitee.com/ascend/parser.git
branch = master
branch = r1.5.0
[submodule "metadef"]
path = metadef
url = https://gitee.com/ascend/metadef.git
branch = master
branch = r1.5.0

+ 1
- 0
CMakeLists.txt View File

@@ -95,6 +95,7 @@ else ()
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
else()
find_module(slog libalog.so ${ASCEND_ATC_DIR})
find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR})
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
if(PLATFORM STREQUAL "train")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})


+ 4
- 4
build.sh View File

@@ -355,13 +355,13 @@ generate_package()

if [ "x${PLATFORM}" = "xtrain" ]
then
tar -cf graphengine_lib.tar fwkacllib
tar -zcf graphengine_lib.tar fwkacllib
elif [ "x${PLATFORM}" = "xinference" ]
then
tar -cf graphengine_lib.tar acllib atc
tar -zcf graphengine_lib.tar acllib atc
elif [ "x${PLATFORM}" = "xall" ]
then
tar -cf graphengine_lib.tar fwkacllib acllib atc
tar -zcf graphengine_lib.tar fwkacllib acllib atc
fi
}

@@ -371,6 +371,6 @@ elif [ "X$MINDSPORE_MODE" = "Xon" ]
then
cd "${OUTPUT_PATH}"
find ./ -name graphengine_lib.tar -exec rm {} \;
tar -cf graphengine_lib.tar lib
tar -zcf graphengine_lib.tar lib
fi
echo "---------------- GraphEngine package archive generated ----------------"

+ 6
- 6
cmake/external_libs/protobuf_shared.cmake View File

@@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()
if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif ()
endif()

@@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL
install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL
DESTINATION ${INSTALL_LIBRARY_DIR})
install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL
DESTINATION ${INSTALL_LIBRARY_DIR})


+ 5
- 7
cmake/external_libs/protobuf_static.cmake View File

@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
endif()

if(GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif ()
endif()

@@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
ExternalProject_Add(protobuf_static_build
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}


+ 5
- 7
cmake/external_libs/protoc.cmake View File

@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
endif()

if(GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif ()
endif()

@@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protoc_build
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
BUILD_COMMAND $(MAKE)


+ 10
- 0
ge/CMakeLists.txt View File

@@ -174,6 +174,7 @@ set(TRAIN_SRC_LIST
"graph/load/model_manager/task_info/model_exit_task_info.cc"
"graph/load/model_manager/task_info/event_record_task_info.cc"
"graph/load/model_manager/task_info/event_wait_task_info.cc"
"graph/load/model_manager/task_info/ffts_task_info.cc"
"graph/load/model_manager/task_info/fusion_start_task_info.cc"
"graph/load/model_manager/task_info/fusion_stop_task_info.cc"
"graph/load/model_manager/task_info/hccl_task_info.cc"
@@ -433,6 +434,7 @@ set(TRAIN_SRC_LIST
"graph/build/memory/max_block_mem_assigner.cc"
"graph/build/memory/var_mem_assign_util.cc"
"graph/build/memory/buffer_pool_mem_assigner.cc"
"ge_opt_info/ge_opt_info.cc"
)

set(INFER_SRC_LIST
@@ -662,6 +664,7 @@ set(INFER_SRC_LIST
"graph/load/model_manager/task_info/task_info.cc"
"graph/load/model_manager/task_info/event_record_task_info.cc"
"graph/load/model_manager/task_info/event_wait_task_info.cc"
"graph/load/model_manager/task_info/ffts_task_info.cc"
"graph/load/model_manager/task_info/fusion_start_task_info.cc"
"graph/load/model_manager/task_info/fusion_stop_task_info.cc"
"graph/load/model_manager/task_info/kernel_ex_task_info.cc"
@@ -709,6 +712,7 @@ set(INFER_SRC_LIST
"graph/build/memory/max_block_mem_assigner.cc"
"graph/build/memory/var_mem_assign_util.cc"
"graph/build/memory/buffer_pool_mem_assigner.cc"
"ge_opt_info/ge_opt_info.cc"
)

if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
@@ -770,11 +774,13 @@ target_include_directories(ge_runner SYSTEM PRIVATE
${GE_CODE_DIR}/../inc/cce
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
${GE_CODE_DIR}/../abl/adump/external
${GE_CODE_DIR}/../abl/licctrl
#### blue zone
${ASCEND_DIR}/driver/include
${ASCEND_DIR}/fwkacllib/include
${GE_CODE_DIR}/third_party/fwkacllib/inc
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
)

target_link_options(ge_runner PRIVATE
@@ -797,6 +803,7 @@ target_link_libraries(ge_runner PRIVATE
runtime
error_manager
ascend_hal_stub
opt_feature
-Wl,--as-needed
json
-lrt
@@ -851,11 +858,13 @@ target_include_directories(ge_compiler SYSTEM PRIVATE
${GE_CODE_DIR}/../inc/cce
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
${GE_CODE_DIR}/../abl/adump/external
${GE_CODE_DIR}/../abl/licctrl
#### blue zone ####
${ASCEND_DIR}/driver/include
${ASCEND_DIR}/fwkacllib/include
${GE_CODE_DIR}/third_party/fwkacllib/inc
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
)

target_link_options(ge_compiler PRIVATE
@@ -875,6 +884,7 @@ target_link_libraries(ge_compiler PRIVATE
error_manager
slog
runtime_compile
opt_feature
-Wl,--as-needed
json
-lrt


+ 0
- 1
ge/client/proto/ge_api.proto View File

@@ -1 +0,0 @@
../../proto/ge_api.proto

+ 0
- 193
ge/client/proto/ge_ir.proto View File

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+ 0
- 140
ge/client/proto/insert_op.proto View File

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP模式,区分静态AIPP和动态AIPP
AippMode aipp_mode = 1;

// related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。
repeated uint32 input_edge_idx = 3;

// [Begin] 动态AIPP参数,配置静态AIPP时无效
uint32 max_src_image_size = 4;

// 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失
bool support_rotation = 5;

// [End] 动态AIPP参数


// [Begin] 静态AIPP参数,配置动态AIPP时无效
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] 静态AIPP参数

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; //动态batch
resolution = 1; //动态分辨率,扩展用
}

MultiShapeMode mode = 1; //算子模式
uint32 related_input_rank = 2; //新增算子插入到哪个输入


repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间
}

+ 0
- 396
ge/client/proto/om.proto View File

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4;//optinal,[default = 1e-5]
bool use_global_stats = 5; //optinal,by default true,testing mode
float moving_average_fraction = 6; //optinal,[default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load directtiono 0:col load 1:row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+ 0
- 179
ge/client/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 1
- 0
ge/common/CMakeLists.txt View File

@@ -106,6 +106,7 @@ target_link_libraries(ge_common PRIVATE
c_sec
error_manager
slog
opt_feature
-Wl,--as-needed
json
$<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt>


+ 2
- 2
ge/common/dump/dump_manager.cc View File

@@ -33,7 +33,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetIn

bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) {
if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) {
dump_properties_map_.emplace(kInferSessionId, dump_properties);
dump_properties_map_[kInferSessionId] = dump_properties;
GELOGI("Dump does not open");
return false;
}
@@ -41,7 +41,7 @@ bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump
if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) &&
dump_config.dump_debug == kDumpoff) {
dump_properties.ClearDumpPropertyValue();
dump_properties_map_.emplace(kInferSessionId, dump_properties);
dump_properties_map_[kInferSessionId] = dump_properties;
return false;
}
if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) {


+ 1
- 0
ge/common/dump/exception_dumper.cc View File

@@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &ex

uint64_t proto_size = dump_data.ByteSizeLong();
std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
GE_CHECK_NOTNULL(proto_msg);
bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size);
if (!ret || proto_size == 0) {
REPORT_INNER_ERROR("E19999", "Serialize proto to string fail");


+ 4
- 0
ge/common/formats/format_transfers/format_transfer_fractal_nz.cc View File

@@ -185,6 +185,7 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con
auto src_offset = (src_h_head + w1_idx * w0) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ?
dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0));
if (ret != EOK) {
@@ -202,6 +203,7 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con
auto src_offset = (src_h_head + src_w_idx) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ?
dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
if (ret != EOK) {
@@ -267,6 +269,7 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con
auto dst_offset = (dst_h_head + w1_idx * w0) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ?
dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0));
if (ret != EOK) {
@@ -285,6 +288,7 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con
auto dst_offset = (dst_h_head + dst_w_idx) * size;
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ?
dst_size - dst_offset : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
if (ret != EOK) {


+ 4
- 0
ge/common/formats/format_transfers/format_transfer_fractal_zz.cc View File

@@ -193,6 +193,7 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0));
if (ret != EOK) {
@@ -213,6 +214,7 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
if (ret != EOK) {
@@ -284,6 +286,7 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size * w0));
if (ret != EOK) {
@@ -304,6 +307,7 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
? dst_size - dst_offset
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
GE_CHECK_GE(protected_size, 0);
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
static_cast<size_t>(size));
if (ret != EOK) {


+ 10
- 2
ge/common/ge/tbe_plugin_manager.cc View File

@@ -104,7 +104,15 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff
}
}

void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) {
void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list,
string &caffe_parser_path, int recursive_depth) {
static const int kMaxRecursiveDepth = 20; // For recursive depth protection

if (recursive_depth >= kMaxRecursiveDepth) {
GELOGW("Recursive depth is become %d, Please check input!", recursive_depth);
return;
}

// Path, change to absolute path
string real_path = RealPath(path.c_str());
// Plugin path does not exist
@@ -138,7 +146,7 @@ void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_lis
ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff,
aicpu_host_so_suff);
} else {
FindParserSo(full_name, file_list, caffe_parser_path);
FindParserSo(full_name, file_list, caffe_parser_path, recursive_depth + 1);
}
}
mmScandirFree(entries, ret);


+ 2
- 1
ge/common/ge/tbe_plugin_manager.h View File

@@ -57,7 +57,8 @@ class TBEPluginManager {
static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
const string &caffe_parser_so_suff, const string &aicpu_so_suff,
const string &aicpu_host_so_suff);
static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);
static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path,
int recursive_depth = 0);
static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
static void GetCustomOpPath(std::string &customop_path);
void LoadCustomOpLib();


+ 7
- 0
ge/common/helper/model_cache_helper.cc View File

@@ -1679,6 +1679,13 @@ Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const {
GELOGW("LoadOmModelFromCache: Load model from file failed. ret = %u", ret);
return ret;
}
std::function<void()> callback = [&]() {
if (model_data.model_data != nullptr) {
delete[] reinterpret_cast<char *>(model_data.model_data);
model_data.model_data = nullptr;
}
};
GE_MAKE_GUARD(release, callback);

ModelHelper model_helper;
ret = model_helper.LoadModel(model_data);


+ 38
- 0
ge/common/profiling/ge_profiling.cc View File

@@ -22,6 +22,7 @@
#include "graph/load/graph_loader.h"
#include "init/gelib.h"
#include "framework/common/ge_inner_error_codes.h"
#include "model/ge_model.h"

namespace {
const uint32_t kDeviceListIndex = 3;
@@ -42,6 +43,10 @@ const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = {
{kProfCommandhandleFinalize, kProfilingFinalize},
{kProfCommandhandleModelSubscribe, kProfModelSubscribe},
{kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}};

const uint64_t kModelId = ge::INVALID_MODEL_ID;
const uint16_t kStepStart = 0;
const uint16_t kStepEnd = 1;
} // namespace

bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<string> &prof_config_params) {
@@ -216,3 +221,36 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le
return ge::SUCCESS;
}

ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) {
static bool is_first_run = true;
int32_t device_id = 0;
rtError_t rt_ret = rtGetDevice(&device_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(rt_ret, "[Get][LogicDeviceId]Failed, ret 0x%X", rt_ret);
REPORT_CALL_ERROR("E19999", "Get logic device id failed, ret 0x%X", rt_ret);
return ge::FAILED;
}
if (is_first_run && tag_id == kStepStart) {
GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id,
kModelId,
tag_id,
stream,
device_id));
is_first_run = false;
return ge::SUCCESS;
}
if (!is_first_run && tag_id == kStepEnd) {
GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id,
kModelId,
tag_id,
stream,
device_id));
is_first_run = true;
return ge::SUCCESS;
}
GELOGE(ge::FAILED, "Param tag_id:%u invalid when is_first_run is %d", tag_id, is_first_run);
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"value", "parameter", "reason"}),
std::vector<std::string>({std::to_string(tag_id), "tag_id",
"tag id must be 0 when first run, must be 1 when second run"}));
return ge::FAILED;
}

+ 0
- 193
ge/common/proto/ge_ir.proto View File

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+ 0
- 140
ge/common/proto/insert_op.proto View File

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP模式,区分静态AIPP和动态AIPP
AippMode aipp_mode = 1;

// related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。
repeated uint32 input_edge_idx = 3;

// [Begin] 动态AIPP参数,配置静态AIPP时无效
uint32 max_src_image_size = 4;

// 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失
bool support_rotation = 5;

// [End] 动态AIPP参数


// [Begin] 静态AIPP参数,配置动态AIPP时无效
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] 静态AIPP参数

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; //动态batch
resolution = 1; //动态分辨率,扩展用
}

MultiShapeMode mode = 1; //算子模式
uint32 related_input_rank = 2; //新增算子插入到哪个输入


repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间
}

+ 0
- 396
ge/common/proto/om.proto View File

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4;//optinal,[default = 1e-5]
bool use_global_stats = 5; //optinal,by default true,testing mode
float moving_average_fraction = 6; //optinal,[default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load directtiono 0:col load 1:row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+ 0
- 75
ge/common/proto/op_mapping.proto View File

@@ -1,75 +0,0 @@
syntax = "proto3";
package toolkit.aicpu.dump;

message Shape {
repeated uint64 dim = 1;
}

message Output {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
string original_name = 5;
int32 original_output_index = 6;
int32 original_output_data_type = 7;
int32 original_output_format = 8;
uint64 size = 9;
Shape origin_shape = 10;
}

message Input {
int32 data_type =1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
uint64 size = 5;
Shape origin_shape = 6;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
uint64 address = 2;
uint64 size = 3;
}

message Op {
string op_name = 1;
string op_type = 2;
}

message Task {
uint32 task_id = 1;
uint32 stream_id = 2;
Op op = 3;
repeated Output output = 4;
bool end_graph = 5;
repeated Input input = 6;
repeated OpBuffer buffer = 7;
}

message OpMappingInfo {
string dump_path = 1;
oneof model_name_param {
string model_name = 2;
}
oneof model_id_param {
uint32 model_id = 3;
}
oneof step_id {
uint64 step_id_addr = 4;
}
oneof iterations_per_loop {
uint64 iterations_per_loop_addr = 5;
}
oneof loop_cond {
uint64 loop_cond_addr = 6;
}
uint32 flag = 7; // 0x01 load, 0x00 unload
repeated Task task = 8;
string dump_step = 9;
}

+ 0
- 179
ge/common/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 0
- 70
ge/common/proto/tensorflow/attr_value.proto View File

@@ -1,70 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
// LINT.IfChange
message ListValue {
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated DataType type = 6 [packed = true]; // "list(type)"
repeated TensorShapeProto shape = 7; // "list(shape)"
repeated TensorProto tensor = 8; // "list(tensor)"
repeated NameAttrList func = 9; // "list(attr)"
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)

oneof value {
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
DataType type = 6; // "type"
TensorShapeProto shape = 7; // "shape"
TensorProto tensor = 8; // "tensor"
ListValue list = 1; // any "list(...)"

// "func" represents a function. func.name is a function's name or
// a primitive op's name. func.attr.first is the name of an attr
// defined for that function. func.attr.second is the value for
// that attr in the instantiation.
NameAttrList func = 10;

// This is a placeholder only used in nodes defined inside a
// function. It indicates the attr value will be supplied when
// the function is instantiated. For example, let us suppose a
// node "N" in function "FN". "N" has an attr "A" with value
// placeholder = "foo". When FN is instantiated with attr "foo"
// set to "bar", the instantiated node N's attr A will have been
// given the value "bar".
string placeholder = 9;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
string name = 1;
map<string, AttrValue> attr = 2;
}

+ 0
- 108
ge/common/proto/tensorflow/function.proto View File

@@ -1,108 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";

// A library is a set of named functions.
message FunctionDefLibrary {
repeated FunctionDef function = 1;
repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
// * device spec, etc.
message FunctionDef {
// The definition of the function's name, arguments, return values,
// attrs etc.
OpDef signature = 1;

// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;

// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
reserved 2;

// In both of the following fields, there is the need to specify an
// output that is used as either the input to another node (in
// `node_def`) or as a return value of the function (in `ret`).
// Unlike the NodeDefs in GraphDef, we need to be able to specify a
// list in some cases (instead of just single outputs). Also, we
// need to be able to deal with lists of unknown length (so the
// output index may not be known at function definition time). So
// we use the following format instead:
// * "fun_in" where "fun_in" is the name of a function input arg in
// the `signature` field above. This represents that input, whether
// it is a single tensor or a list.
// * "fun_in:0" gives the first element of a function input arg (a
// non-list input is considered a list of length 1 for these
// purposes).
// * "node:out" where "node" is the name of a node in `node_def` and
// "out" is the name one of its op's output arguments (the name
// comes from the OpDef of the node's op). This represents that
// node's output, whether it is a single tensor or a list.
// Note: We enforce that an op's output arguments are never
// renamed in the backwards-compatibility test.
// * "node:out:0" gives the first element of a node output arg (a
// non-list output is considered a list of length 1 for these
// purposes).
//
// NOT CURRENTLY SUPPORTED (but may be in the future):
// * "node:out:-1" gives last element in a node output list
// * "node:out:1:" gives a list with all but the first element in a
// node output list
// * "node:out::-1" gives a list with all but the last element in a
// node output list

// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
// may have values of type `placeholder` and the `input` field uses
// the "output" format above.

// By convention, "op" in node_def is resolved by consulting with a
// user-defined library first. If not resolved, "func" is assumed to
// be a builtin op.
repeated NodeDef node_def = 3;

// A mapping from the output arg names from `signature` to the
// outputs from `node_def` that should be returned by the function.
map<string, string> ret = 4;
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g', which is a
// function taking N + M inputs and produces N outputs.
//
// I.e. if we have
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
// dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
string function_name = 1; // The function name.
string gradient_func = 2; // The gradient function's name.
}

+ 0
- 64
ge/common/proto/tensorflow/graph.proto View File

@@ -1,64 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "node_def.proto";
import "function.proto";
import "versions.proto";

// Represents the graph of operations
message GraphDef {
repeated NodeDef node = 1;

// Compatibility versions of the graph. See core/public/version.h for version
// history. The GraphDef version is distinct from the TensorFlow version, and
// each release of TensorFlow will support a range of GraphDef versions.
VersionDef versions = 4;

// Deprecated single version field; use versions above instead. Since all
// GraphDef changes before "versions" was introduced were forward
// compatible, this field is entirely ignored.
int32 version = 3 [deprecated = true];

// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
//
// Naming:
// * library.function.name are in a flat namespace.
// NOTE: We may need to change it to be hierarchical to support
// different orgs. E.g.,
// { "/google/nn", { ... }},
// { "/google/vision", { ... }}
// { "/org_foo/module_bar", { ... }}
// map<string, FunctionDefLib> named_lib;
// * If node[i].op is the name of one function in "library",
// node[i] is deemed as a function call. Otherwise, node[i].op
// must be a primitive operation supported by the runtime.
//
//
// Function call semantics:
//
// * The callee may start execution as soon as some of its inputs
// are ready. The caller may want to use Tuple() mechanism to
// ensure all inputs are ready in the same time.
//
// * The consumer of return values may start executing as soon as
// the return values the consumer depends on are ready. The
// consumer may want to use Tuple() mechanism to ensure the
// consumer does not start until all return values of the callee
// function are ready.
FunctionDefLibrary library = 2;
};

+ 0
- 22
ge/common/proto/tensorflow/graph_library.proto View File

@@ -1,22 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;

import "graph.proto";

message GeGraphDef {
string name = 1;
GraphDef graph = 2;
}

message GraphDefLibrary {
repeated GeGraphDef graph_def = 1;
};

+ 0
- 71
ge/common/proto/tensorflow/node_def.proto View File

@@ -1,71 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";

message NodeDef {
// The name given to this operator. Used for naming inputs,
// logging, visualization, etc. Unique within a single GraphDef.
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
string name = 1;

// The operation name. There may be custom parameters in attrs.
// Op names starting with an underscore are reserved for internal use.
string op = 2;

// Each input is "node:src_output" with "node" being a string name and
// "src_output" indicating which output tensor to use from "node". If
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
// may optionally be followed by control inputs that have the format
// "^node".
repeated string input = 3;

// A (possibly partial) specification for the device on which this
// node should be placed.
// The expected syntax for this string is as follows:
//
// DEVICE_SPEC ::= PARTIAL_SPEC
//
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
// CONSTRAINT ::= ("job:" JOB_NAME)
// | ("replica:" [1-9][0-9]*)
// | ("task:" [1-9][0-9]*)
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
//
// Valid values for this string include:
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification)
// * "/job:worker/device:GPU:3" (partial specification)
// * "" (no specification)
//
// If the constraints do not resolve to a single device (or if this
// field is empty or not present), the runtime will attempt to
// choose a device automatically.
string device = 4;

// Operation-specific graph-construction-time configuration.
// Note that this should include all attrs defined in the
// corresponding OpDef, including those with a value matching
// the default -- this allows the default to change and makes
// NodeDefs easier to interpret on their own. However, if
// an attr with a default is not specified in this list, the
// default will be used.
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
// one of the names from the corresponding OpDef's attr field).
// The values must have a type matching the corresponding OpDef
// attr's type field.
// Add some examples here showing best practices.
map<string, AttrValue> attr = 5;
};

+ 0
- 172
ge/common/proto/tensorflow/op_def.proto View File

@@ -1,172 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "types.proto";

// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field which should match the name of a OpDef.
// LINT.IfChange
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;

// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;

// Human readable description.
string description = 2;

// Describes the type of one or more tensors that are accepted/produced
// by this input/output arg. The only legal combinations are:
// * For a single tensor: either the "type" field is set or the
// "type_attr" field is set to the name of an attr with type "type".
// * For a sequence of tensors with the same type: the "number_attr"
// field will be set to the name of an attr with type "int", and
// either the "type" or "type_attr" field will be set as for
// single tensors.
// * For a sequence of tensors, the "type_list_attr" field will be set
// to the name of an attr with type "list(type)".
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;

// For inputs: if true, the inputs are required to be refs.
// By default, inputs can be either refs or non-refs.
// For outputs: if true, outputs are refs, otherwise they are not.
bool is_ref = 16;
};

// Description of the input(s).
repeated ArgDef input_arg = 2;

// Description of the output(s).
repeated ArgDef output_arg = 3;

// Description of the graph-construction-time configuration of this
// Op. That is to say, this describes the attr fields that will
// be specified in the NodeDef.
message AttrDef {
// A descriptive name for the argument. May be used, e.g. by the
// Python client, as a keyword argument name, and so should match
// the regexp "[a-z][a-z0-9_]+".
string name = 1;

// One of the type names from attr_value.proto ("string", "list(string)",
// "int", etc.).
string type = 2;

// A reasonable default for this attribute if the user does not supply
// a value. If not specified, the user must supply a value.
AttrValue default_value = 3;

// Human-readable description.
string description = 4;


// --- Constraints ---
// These constraints are only in effect if specified. Default is no
// constraints.

// For type == "int", this is a minimum value. For "list(___)"
// types, this is the minimum length.
bool has_minimum = 5;
int64 minimum = 6;

// The set of allowed values. Has type that is the "list" version
// of the "type" field above (uses the "list" field of AttrValue).
// If type == "type" or "list(type)" above, then the "type" field
// of "allowed_values.list" has the set of allowed DataTypes.
// If type == "string" or "list(string)", then the "s" field of
// "allowed_values.list" has the set of allowed strings.
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;

// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;

// One-line human-readable description of what the Op does.
string summary = 5;

// Additional, longer human-readable description of what the Op does.
string description = 6;

// -------------------------------------------------------------------------
// Which optimizations this operation can participate in.

// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
bool is_commutative = 18;

// If is_aggregate is true, then this operation accepts N >= 2
// inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same
// shape as the input. The optimizer may replace an aggregate op
// taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating.
bool is_aggregate = 16; // for things like add

// Other optimizations go here, like
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.

// -------------------------------------------------------------------------
// Optimization constraints.

// Ops are marked as stateful if their behavior depends on some state beyond
// their input tensors (e.g. variable reading op) or if they have
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have
// no side-effects.
//
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
// Stateful ops are guaranteed to never be optimized away by Common
// Subexpression Elimination (CSE).
bool is_stateful = 17; // for things like variables, queue

// -------------------------------------------------------------------------
// Non-standard options.

// By default, all inputs to an Op must be initialized Tensors. Ops
// that may initialize tensors for the first time should set this
// field to true, to allow the Op to take an uninitialized Tensor as
// input.
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc)

// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;

// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};

// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
};

+ 0
- 37
ge/common/proto/tensorflow/resource_handle.proto View File

@@ -1,37 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
// Unique name for the device containing the resource.
string device = 1;

// Container in which this resource is placed.
string container = 2;

// Unique name of this resource.
string name = 3;

// Hash code for the type of the resource. Is only valid in the same device
// and in the same execution.
uint64 hash_code = 4;

// For debug-only, the name of the type pointed to by this handle, if
// available.
string maybe_type_name = 5;
};

+ 0
- 102
ge/common/proto/tensorflow/tensor.proto View File

@@ -1,102 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;

// Shape of the tensor.
TensorShapeProto tensor_shape = 2;

// Only one of the representations below is set, one of "tensor_contents" and
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot
// contain repeated fields it would require another extra set of messages.

// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;

// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;

// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.

// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];

// DT_FLOAT.
repeated float float_val = 5 [packed = true];

// DT_DOUBLE.
repeated double double_val = 6 [packed = true];

// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];

// DT_STRING
repeated bytes string_val = 8;

// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];

// DT_INT64
repeated int64 int64_val = 10 [packed = true];

// DT_BOOL
repeated bool bool_val = 11 [packed = true];

// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];

// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;

// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;

// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];

// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}

+ 0
- 53
ge/common/proto/tensorflow/tensor_shape.proto View File

@@ -1,53 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

// Protocol buffer representing the shape of tensors.

syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

package domi.tensorflow;

// Dimensions of a tensor.
message TensorShapeProto {
// One dimension of the tensor.
message Dim {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension). Certain wrappers
// that work with TensorShapeProto may fail at runtime when deserializing
// a TensorShapeProto containing a dim value of -1.
int64 size = 1;

// Optional name of the tensor dimension.
string name = 2;
};

// Dimensions of the tensor, such as {"input", 30}, {"output", 40}
// for a 30 x 40 2D tensor. If an entry has size -1, this
// corresponds to a dimension of unknown size. The names are
// optional.
//
// The order of entries in "dim" matters: It indicates the layout of the
// values in the tensor in-memory representation.
//
// The first entry in "dim" is the outermost dimension used to layout the
// values, the last entry is the innermost dimension. This matches the
// in-memory layout of RowMajor Eigen tensors.
//
// If "dim.size()" > 0, "unknown_rank" must be false.
repeated Dim dim = 2;

// If true, the number of dimensions in the shape is unknown.
//
// If true, "dim.size()" must be 0.
bool unknown_rank = 3;
};

+ 0
- 82
ge/common/proto/tensorflow/types.proto View File

@@ -1,82 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
DT_INVALID = 0;

// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8; // Single-precision complex
DT_INT64 = 9;
DT_BOOL = 10;
DT_QINT8 = 11; // Quantized int8
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;

// Do not use! These are only for parameters. Every enum above
// should have a corresponding value below (verified by types_test).
DT_FLOAT_REF = 101;
DT_DOUBLE_REF = 102;
DT_INT32_REF = 103;
DT_UINT8_REF = 104;
DT_INT16_REF = 105;
DT_INT8_REF = 106;
DT_STRING_REF = 107;
DT_COMPLEX64_REF = 108;
DT_INT64_REF = 109;
DT_BOOL_REF = 110;
DT_QINT8_REF = 111;
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/c/c_api.h,
// https://www.tensorflow.org/code/tensorflow/go/tensor.go,
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)

+ 0
- 39
ge/common/proto/tensorflow/versions.proto View File

@@ -1,39 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;

// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;

// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};

+ 1
- 0
ge/executor/CMakeLists.txt View File

@@ -37,6 +37,7 @@ set(SRC_LIST
"../graph/load/model_manager/task_info/task_info.cc"
"../graph/load/model_manager/task_info/event_record_task_info.cc"
"../graph/load/model_manager/task_info/event_wait_task_info.cc"
"../graph/load/model_manager/task_info/ffts_task_info.cc"
"../graph/load/model_manager/task_info/fusion_start_task_info.cc"
"../graph/load/model_manager/task_info/fusion_stop_task_info.cc"
"../graph/load/model_manager/task_info/kernel_ex_task_info.cc"


+ 0
- 113
ge/executor/proto/dump_task.proto View File

@@ -1,113 +0,0 @@
syntax = "proto3";
package toolkit.dump;

enum OutputDataType {
DT_UNDEFINED = 0;
DT_FLOAT = 1;
DT_FLOAT16 = 2;
DT_INT8 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_UINT16 = 6;
DT_INT32 = 7;
DT_INT64 = 8;
DT_UINT32 = 9;
DT_UINT64 = 10;
DT_BOOL = 11;
DT_DOUBLE = 12;
DT_STRING = 13;
DT_DUAL_SUB_INT8 = 14;
DT_DUAL_SUB_UINT8 = 15;
DT_COMPLEX64 = 16;
DT_COMPLEX128 = 17;
DT_QINT8 = 18;
DT_QINT16 = 19;
DT_QINT32 = 20;
DT_QUINT8 = 21;
DT_QUINT16 = 22;
DT_RESOURCE = 23;
DT_STRING_REF = 24;
DT_DUAL = 25;
DT_VARIANT = 26;
}

enum OutputFormat {
FORMAT_NCHW = 0;
FORMAT_NHWC = 1;
FORMAT_ND = 2;
FORMAT_NC1HWC0 = 3;
FORMAT_FRACTAL_Z = 4;
FORMAT_NC1C0HWPAD = 5;
FORMAT_NHWC1C0 = 6;
FORMAT_FSR_NCHW = 7;
FORMAT_FRACTAL_DECONV = 8;
FORMAT_C1HWNC0 = 9;
FORMAT_FRACTAL_DECONV_TRANSPOSE = 10;
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11;
FORMAT_NC1HWC0_C04 = 12;
FORMAT_FRACTAL_Z_C04 = 13;
FORMAT_CHWN = 14;
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15;
FORMAT_HWCN = 16;
FORMAT_NC1KHKWHWC0 = 17;
FORMAT_BN_WEIGHT = 18;
FORMAT_FILTER_HWCK = 19;
FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20;
FORMAT_HASHTABLE_LOOKUP_KEYS = 21;
FORMAT_HASHTABLE_LOOKUP_VALUE = 22;
FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23;
FORMAT_HASHTABLE_LOOKUP_HITS=24;
FORMAT_C1HWNCoC0 = 25;
FORMAT_MD = 26;
FORMAT_NDHWC = 27;
FORMAT_FRACTAL_ZZ = 28;
FORMAT_FRACTAL_NZ = 29;
FORMAT_RESERVED = 30;
}

message OriginalOp {
string name = 1;
uint32 output_index = 2;
OutputDataType data_type = 3;
OutputFormat format = 4;
}

message Shape {
repeated uint64 dim = 1;
}

message OpOutput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
OriginalOp original_op = 4; // the original op corresponding to the output
bytes data = 5;
uint64 size = 6;
}

message OpInput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
bytes data = 4;
uint64 size = 5;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
bytes data = 2;
uint64 size = 3;
}

message DumpData{
string version = 1;
uint64 dump_time = 2;
repeated OpOutput output = 3;
repeated OpInput input = 4;
repeated OpBuffer buffer = 5;
string op_name = 6;
}

+ 0
- 193
ge/executor/proto/ge_ir.proto View File

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+ 0
- 140
ge/executor/proto/insert_op.proto View File

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP模式,区分静态AIPP和动态AIPP
AippMode aipp_mode = 1;

// related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。
repeated uint32 input_edge_idx = 3;

// [Begin] 动态AIPP参数,配置静态AIPP时无效
uint32 max_src_image_size = 4;

// 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失
bool support_rotation = 5;

// [End] 动态AIPP参数


// [Begin] 静态AIPP参数,配置动态AIPP时无效
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] 静态AIPP参数

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; //动态batch
resolution = 1; //动态分辨率,扩展用
}

MultiShapeMode mode = 1; //算子模式
uint32 related_input_rank = 2; //新增算子插入到哪个输入


repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间
}

+ 0
- 396
ge/executor/proto/om.proto View File

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4;//optinal,[default = 1e-5]
bool use_global_stats = 5; //optinal,by default true,testing mode
float moving_average_fraction = 6; //optinal,[default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load directtiono 0:col load 1:row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+ 0
- 75
ge/executor/proto/op_mapping.proto View File

@@ -1,75 +0,0 @@
syntax = "proto3";
package toolkit.aicpu.dump;

message Shape {
repeated uint64 dim = 1;
}

message Output {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
string original_name = 5;
int32 original_output_index = 6;
int32 original_output_data_type = 7;
int32 original_output_format = 8;
uint64 size = 9;
Shape origin_shape = 10;
}

message Input {
int32 data_type =1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
uint64 size = 5;
Shape origin_shape = 6;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
uint64 address = 2;
uint64 size = 3;
}

message Op {
string op_name = 1;
string op_type = 2;
}

message Task {
uint32 task_id = 1;
uint32 stream_id = 2;
Op op = 3;
repeated Output output = 4;
bool end_graph = 5;
repeated Input input = 6;
repeated OpBuffer buffer = 7;
}

message OpMappingInfo {
string dump_path = 1;
oneof model_name_param {
string model_name = 2;
}
oneof model_id_param {
uint32 model_id = 3;
}
oneof step_id {
uint64 step_id_addr = 4;
}
oneof iterations_per_loop {
uint64 iterations_per_loop_addr = 5;
}
oneof loop_cond {
uint64 loop_cond_addr = 6;
}
uint32 flag = 7; // 0x01 load, 0x00 unload
repeated Task task = 8;
string dump_step = 9;
}

+ 0
- 179
ge/executor/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 4
- 7
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -13,15 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "host_cpu_engine.h"
#include "graph/common/omg_util.h"
#include "ge_local_engine/engine/host_cpu_engine.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_adapter.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h"
#include "register/op_kernel_registry.h"
#include "register/host_cpu_context.h"
#include "common/ge/ge_util.h"
#include "common/ge/plugin_manager.h"
#include "graph/utils/type_utils.h"
#include "common/fp16_t.h"
#include "common/math/math_util.h"

@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) {
}

Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) {
std::string op_type;
auto status = GetOriginalType(node, op_type);
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status);

const std::string op_type = NodeUtils::GetNodeType(node);
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type);
if (kernel == nullptr) {
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str());


+ 0
- 179
ge/ge_local_engine/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 58
- 0
ge/ge_opt_info/ge_opt_info.cc View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_opt_info/ge_opt_info.h"

#include <string>
#include <map>
#include "graph/ge_local_context.h"
#include "ge/ge_api_types.h"
#include "common/debug/ge_log.h"
#include "opt_info.h"

namespace ge {
Status GeOptInfo::SetOptInfo() {
std::string soc_ver;
graphStatus ret = GetThreadLocalContext().GetOption(SOC_VERSION, soc_ver);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get soc version failed.");
GELOGE(FAILED, "[Get][SocVersion]Get soc version failed.");
return FAILED;
}
GELOGD("Soc version:%s.", soc_ver.c_str());
std::map<std::string, std::string> opt_info;
// the first arg does not work at present.
if (gelc::GetOptInfo(gelc::kOffline, soc_ver, opt_info) != gelc::SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get optional information failed, is_offline:%d, soc version:%s",
gelc::kOffline, soc_ver.c_str());
GELOGE(FAILED, "[Get][OptInfo]Get optional information failed, is_offline:%d, soc version:%s",
gelc::kOffline, soc_ver.c_str());
return FAILED;
}
// do nothing if get empty information
if (opt_info.empty()) {
GELOGI("Optional information is empty.");
return SUCCESS;
}
std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
for (const auto &itr : opt_info) {
graph_options.emplace(itr.first, itr.second);
GELOGI("Get optional information success, key:%s, value:%s.", itr.first.c_str(), itr.second.c_str());
}
GetThreadLocalContext().SetGraphOption(graph_options);
return SUCCESS;
}
} // namespace ge

+ 32
- 0
ge/ge_opt_info/ge_opt_info.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_OPT_INFO_GE_OPT_INFO_H_
#define GE_OPT_INFO_GE_OPT_INFO_H_

#include "ge/ge_api_error_codes.h"
#include "register/register_types.h"

namespace ge {
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeOptInfo {
public:
GeOptInfo() = default;
~GeOptInfo() = default;
static Status SetOptInfo();
};
} // namespace ge

#endif // GE_OPT_INFO_GE_OPT_INFO_H_

+ 1
- 1
ge/ge_runtime/task/label_goto_task.cc View File

@@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() {
return false;
}

rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
rt_ret = rtLabelListCpy(reinterpret_cast<void**>(label_list.data()), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;


+ 13
- 8
ge/generator/ge_generator.cc View File

@@ -674,6 +674,12 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
GELOGD("Current ctx is null.");
ctx = nullptr;
}
std::function<void()> callback = [&]() {
if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx);
}
};
GE_MAKE_GUARD(restore, callback);

GeRootModelPtr ge_root_model = nullptr;
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
@@ -712,11 +718,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
}
return ret;
}

if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx);
}

return SUCCESS;
}

@@ -806,7 +807,7 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor>
return SUCCESS;
}

Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) {
Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) {
GE_CHECK_NOTNULL(op_desc);
if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) {
auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType());
@@ -830,7 +831,11 @@ Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) {
}
node_op.BreakConnect();
}
auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);
auto comp_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(comp_graph);
auto node = comp_graph->FindNode(op_desc->GetName());
GE_CHECK_NOTNULL(node);
auto op = OpDescUtils::CreateOperatorFromNode(node);
auto ret = op_desc->CallInferFormatFunc(op);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail",
@@ -877,7 +882,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
Graph graph;
GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph),
"[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str());
GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc));
GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph));

// 2. check engine type when compile online
if (model_file_name == kFileNameSuffix) {


+ 5
- 0
ge/graph/build/label_allocator.cc View File

@@ -86,6 +86,11 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node
return false;
}

if (func_node->GetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str());
return true;
}

ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph();
if (owner_graph == nullptr) {
REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s",


+ 5
- 0
ge/graph/build/logical_stream_allocator.cc View File

@@ -474,6 +474,11 @@ Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<Subgr
for (ge::NodePtr &node : graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID)) {
op_desc->SetStreamId(kInvalidStream);
GELOGI("Ffts node %s of type %s reassign to invalid stream.", node->GetName().c_str(), node->GetType().c_str());
continue;
}
int64_t stream_id = op_desc->GetStreamId();
if (ops_without_label.find(op_desc) != ops_without_label.end()) {
if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) {


+ 1
- 1
ge/graph/build/model_builder.cc View File

@@ -707,7 +707,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) {
if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) {
GE_CHECK_NOTNULL(kernel_buffer.GetData());
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data));
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(),
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());


+ 10
- 1
ge/graph/build/stream_allocator.cc View File

@@ -432,7 +432,11 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() {

// Insert the send/recv event id to the graph
Status StreamAllocator::InsertSyncEvents() {
for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
};

for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
// Take the adjacent points, then judge whether need to insert the event
for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) {
for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
@@ -531,6 +535,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const
Status StreamAllocator::InsertEventsForSubgraph() {
for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) {
GE_CHECK_NOTNULL(subgraph);
const auto parent_node = subgraph->GetParentNode();
if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str());
continue;
}
for (const auto &node : subgraph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);


+ 47
- 29
ge/graph/build/task_generator.cc View File

@@ -354,7 +354,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
};
GE_MAKE_GUARD(release, callback);

for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
};
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
node_index++;
@@ -380,10 +383,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str());
continue;
}
if (op_kernel_lib_name.empty()) {
GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
continue;
}
GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue,
"Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name);
if (kernel_info_store == nullptr) {
REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s",
@@ -394,6 +395,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
}
GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(),
type.c_str());
if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed",
name.c_str(), type.c_str());
}
// Profiling task
size_t task_list_size_before = task_def_list.size();
GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
@@ -571,7 +576,24 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info
return ret;
}

Status TaskGenerator::UpdateAnchorStatusForFfts(const NodePtr &node) {
GELOGD("Start UpdateAnchorStatusForFfts for %s.", node->GetName().c_str());
if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_graph = NodeUtils::GetSubgraph(*node, i);
GE_CHECK_NOTNULL(sub_graph);
GELOGD("Start update anchor status for %s.", sub_graph->GetName().c_str());
for (auto &ffts_node : sub_graph->GetDirectNode()) {
GE_CHK_STATUS_RET(UpdateAnchorStatus(ffts_node), "[Call][UpdateAnchorStatus] node:%s(%s) failed",
ffts_node->GetName().c_str(), ffts_node->GetType().c_str());
}
}
}
return SUCCESS;
}

Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) {
GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str());
if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)",
node->GetName().c_str(), node->GetType().c_str());
@@ -771,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
GELOGI("Start AutoFindBpOpIndex");
NodePtr bp_node = nullptr;
uint32_t current_idx = 0;
uint32_t netoutput_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -789,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) {
if (bp_node == nullptr) {
bp_node = node;
netoutput_idx = current_idx - 1;
}
}
if (graph->GetNeedIteration()) {
@@ -814,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (bp_node == nullptr) {
GELOGW("not find bp_node.");
return SUCCESS;
} else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) {
profiling_point.bp_index = netoutput_idx;
GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx);
} else {
profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node);
}

return SUCCESS;
return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index);
}

uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const {
uint32_t last_bp = 0;
Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node,
uint32_t &bp_index) const {
bp_index = 0;
auto target_desc = target_node->GetOpDesc();
GE_CHECK_NOTNULL(target_desc);
OpDescPtr bp_op_desc = nullptr;
for (auto &in_anchor : bp_node->GetAllInDataAnchors()) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
continue;
for (auto &in_node : target_node->GetInAllNodes()) {
GE_CHECK_NOTNULL(in_node);
auto in_node_desc = in_node->GetOpDesc();
GE_CHECK_NOTNULL(in_node_desc);
if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) &&
(in_node_desc->GetStreamId() == target_desc->GetStreamId())){
bp_op_desc = in_node_desc;
}
auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL(out_node_desc);
if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) {
bp_op_desc = out_node_desc;
}
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId());
}

if (bp_op_desc == nullptr) {
return last_bp;
GELOGI("Did not find bp node.");
return SUCCESS;
}
uint32_t current_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
@@ -849,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const
GE_CHECK_NOTNULL(op_desc);
current_idx++;
if (op_desc->GetName() == bp_op_desc->GetName()) {
last_bp = current_idx;
GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp);
bp_index = current_idx;
GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index);
break;
}
}
return last_bp;
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(),
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId());
return SUCCESS;
}

Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,


+ 2
- 1
ge/graph/build/task_generator.h View File

@@ -80,6 +80,7 @@ class TaskGenerator {
Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
std::vector<uint32_t> &all_reduce_nodes);
private:
Status UpdateAnchorStatusForFfts(const NodePtr &node);
Status UpdateAnchorStatus(const NodePtr &node);

Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id);
@@ -115,7 +116,7 @@ class TaskGenerator {
Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
vector<uint32_t> &all_reduce_nodes) const;
uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const;
Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const;

Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
ProfilingPoint &profiling_point) const;


+ 0
- 15
ge/graph/common/omg_util.cc View File

@@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
}

///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) {
if (!force_unknown) {
return;
}

SetControlFlowGroup(node, group_index);
}

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.


+ 0
- 9
ge/graph/common/omg_util.h View File

@@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size);
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);

///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.


+ 175
- 58
ge/graph/load/model_manager/davinci_model.cc View File

@@ -99,6 +99,9 @@ const uint32_t kEndOfSequenceNew = 507005;
const int32_t kModelAbortNormal = 0x0704000e;
const int32_t kModelAbortNormalNew = 507024;
const uint32_t kInteval = 2;
const uint32_t kFftsTbeHandleElementSize = 2;
const uint32_t kNonTailBlock = 0;
const uint32_t kTailBlock = 1;
const char *const kModelName = "model_name";
const char *const kModeleId = "model_id";
const char *const kLoadStartTime = "load_start_time";
@@ -116,14 +119,15 @@ const char *const kWorkSpaceSize = "workspace_size";
const char *const kTotalSize = "total_size";
const char *const kTaskCount = "task_count";
const char *const kTaskId = "task_id";
const char* const kRequestId = "request_id";
const char* const kThreadId = "thread_id";
const char* const kInputBeginTime = "input_begin_time";
const char* const kInputEndTime = "input_end_time";
const char* const kInferBeginTime = "infer_begin_time";
const char* const kInferEndTime = "infer_end_time";
const char* const kOutputBeginTime = "output_start_time";
const char* const kOutputEndTime = "output_end_time";
const char *const kRequestId = "request_id";
const char *const kThreadId = "thread_id";
const char *const kInputBeginTime = "input_begin_time";
const char *const kInputEndTime = "input_end_time";
const char *const kInferBeginTime = "infer_begin_time";
const char *const kInferEndTime = "infer_end_time";
const char *const kOutputBeginTime = "output_start_time";
const char *const kOutputEndTime = "output_end_time";
const char *const kStubFuncName = "_register_stub_func";
const uint32_t kStringHeadElems = 2;
const uint32_t kPlacementHostData = 0;
const size_t kAlignment = 64;
@@ -902,10 +906,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
SetLabelForDynamic(node);
auto it = op_desc_handle.find(op_desc->GetType());
if (it != op_desc_handle.end()) {
if ((this->*it->second)(op_desc) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
return PARAM_INVALID;
}
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((this->*it->second)(op_desc) != SUCCESS, return PARAM_INVALID,
"[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
continue;
}

@@ -935,7 +937,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {

GE_TIMESTAMP_RESTART(InitTbeHandle);
if (IsTbeTask(op_desc)) {
Status status = InitTbeHandle(op_desc);
Status status =
op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc);
if (status != SUCCESS) {
GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str());
return status;
@@ -1477,6 +1480,11 @@ Status DavinciModel::GetLabelGotoAddr(uint32_t label_index, rtMemType_t mem_type
return SUCCESS;
}

void DavinciModel::SetGlobalStep(void *global_step, uint64_t global_step_size) {
global_step_addr_ = global_step;
global_step_size_ = global_step_size;
}

/// @ingroup ge
/// @brief LabelSet Op Initialize.
/// @param [in] op_desc: LabelSet Op descriptor.
@@ -1539,14 +1547,16 @@ Status DavinciModel::InitLabelSet(const OpDescPtr &op_desc) {
}

Status DavinciModel::InitVariable(const OpDescPtr &op_desc, map<string, OpDescPtr> &variable_by_name) {
if (op_desc->GetName() == NODE_NAME_GLOBAL_STEP) {
const auto output_sizes = ModelUtils::GetOutputSize(op_desc);
if (!output_sizes.empty()) {
global_step_size_ = output_sizes[0];
}
const auto output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc);
if (!output_addrs.empty()) {
global_step_addr_ = output_addrs[0];
if (!known_node_) {
if (op_desc->GetName() == NODE_NAME_GLOBAL_STEP) {
const auto output_sizes = ModelUtils::GetOutputSize(op_desc);
if (!output_sizes.empty()) {
global_step_size_ = output_sizes[0];
}
const auto output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc);
if (!output_addrs.empty()) {
global_step_addr_ = output_addrs[0];
}
}
}

@@ -3673,6 +3683,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) {
elem_num = 1;
}
uint64_t *buff = reinterpret_cast<uint64_t *>(tensor->MutableData().data());
GE_CHECK_NOTNULL(buff);
if (ge::CheckInt64Uint32MulOverflow(elem_num, kBytes * kStringHeadElems) != SUCCESS) {
GELOGE(FAILED, "[Call][CheckInt64Uint32MulOverflow] Shape size:%ld is invalid", elem_num);
return FAILED;
@@ -3700,6 +3711,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) {
/// @return Status
///
Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
string bin_file = op_desc->GetName();
auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName());
auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
if (tbe_kernel == nullptr) {
@@ -3708,12 +3720,61 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed",
bin_file.c_str());
return SUCCESS;
}

std::string session_graph_model_id;
GetUniqueId(op_desc, session_graph_model_id);
const char *bin_file_key = GetRegisterStub(op_desc->GetName(), session_graph_model_id); // from set, always valid.
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) {
std::vector<OpKernelBinPtr> tbe_kernel;
tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel);
GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size());
if (tbe_kernel.size() != kFftsTbeHandleElementSize) {
REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts",
op_desc->GetName().c_str(), tbe_kernel.size());
return INTERNAL_ERROR;
}
if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) {
REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
vector<string> bin_file_keys;
(void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys);
if (bin_file_keys.size() != kFftsTbeHandleElementSize) {
REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(),
op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts",
op_desc->GetName().c_str(), bin_file_keys.size());
return INTERNAL_ERROR;
}
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true,
kNonTailBlock),
"Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str());
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock),
"Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str());
return SUCCESS;
}

Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel,
bool is_ffts, size_t thread_index) {
if (thread_index > 1) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu should less than 1.", thread_index);
return INTERNAL_ERROR;
}
const char *bin_file_key;
if (is_ffts) {
bin_file_key = GetRegisterStub(bin_file, "");
GELOGI("Node:%s inherit func name:%s directly.", op_desc->GetName().c_str(), bin_file_key);
} else {
std::string session_graph_model_id;
GetUniqueId(op_desc, session_graph_model_id);
bin_file_key = GetRegisterStub(bin_file, session_graph_model_id); // from set, always valid.
}

TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
std::lock_guard<std::mutex> lock(tvm_bin_mutex_);
if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) {
void *bin_handle = nullptr;
@@ -3721,59 +3782,115 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key);

rtDevBinary_t binary;
std::string json_string;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string),
GELOGD("Get original type of session_graph_id."));
if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICPU") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICPU;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE;
} else {
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
return PARAM_INVALID;
}

GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.",
op_desc->GetName().c_str());
binary.version = 0;
binary.data = tbe_kernel->GetBinData();
binary.length = tbe_kernel->GetBinDataSize();

GELOGD("TBE: binary.length: %lu", binary.length);
GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle));

std::string meta_data;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data),
GELOGI("Get original type of json_string"));
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
GE_IF_BOOL_EXEC(!meta_data.empty(), GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str())));

GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.",
op_desc->GetName().c_str());
kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel);
} else {
GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key);
kernel_store.ReferTBEHandle(bin_file_key);
}

std::string kernel_name;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name),
GELOGD("Get original type of kernel_name"));
GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.",
op_desc->GetName().c_str());
GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0));
used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1.
return SUCCESS;
}

// Kernel registed, Increase used num in store.
StoreTbeHandle(bin_file_key);
return SUCCESS;
}

Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index,
rtDevBinary_t &binary) {
string json_string;
const string &tvm_magic = is_ffts ? TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC;
const static std::map<std::string, uint32_t> binary_magics = {
{"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU},
{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF},
{"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC},
{"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE}
};
if (is_ffts) {
vector<string> json_list;
(void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list);
if (json_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Attr is %s, thread index is %zu, json list size is %zu.",
tvm_magic.c_str(), thread_index, json_list.size());
return INTERNAL_ERROR;
}
json_string = json_list[thread_index];
} else {
(void)AttrUtils::GetStr(op_desc, tvm_magic, json_string);
}
auto iter = binary_magics.find(json_string);
if (iter == binary_magics.end()) {
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
tvm_magic.c_str(), json_string.c_str(), op_desc->GetName().c_str(),
op_desc->GetType().c_str(), model_id_);
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
return PARAM_INVALID;
}
binary.magic = iter->second;
return SUCCESS;
}

Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) {
string meta_data;
const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA;
if (is_ffts) {
vector<string> meta_data_list;
(void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list);
if (meta_data_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.",
tvm_metadata.c_str(), thread_index, meta_data_list.size());
return INTERNAL_ERROR;
}
meta_data = meta_data_list[thread_index];
} else {
(void)AttrUtils::GetStr(op_desc, tvm_metadata, meta_data);
}
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
if (!meta_data.empty()) {
GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str()));
}
return SUCCESS;
}

Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) {
if (is_ffts) {
// delete prefix, eg: *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile
vector<string> kernel_name_list;
auto pos = op_desc->GetName().find("/");
if (pos == std::string::npos) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, subgraph node name: %s.", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname";
(void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list);
if (kernel_name_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.",
attr_kernel_name.c_str(), thread_index, kernel_name_list.size());
return INTERNAL_ERROR;
}
kernel_name = kernel_name_list[thread_index];
} else {
string attr_kernel_name = op_desc->GetName() + "_kernelname";
(void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name);
}
return SUCCESS;
}

void DavinciModel::StoreTbeHandle(const std::string &handle_key) {
// Online mode FE may call rtFunctionRegister.
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
@@ -4256,7 +4373,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &graph, const map<str
data_dumper_.SetDeviceId(device_id);

if (known_node_) {
data_dumper_.SetLoopAddr(known_shape_global_step_, nullptr, nullptr);
data_dumper_.SetLoopAddr(global_step_addr_, nullptr, nullptr);
} else {
// set loop count addr
auto get_var_addr = [&](const string &name) -> void *{


+ 7
- 7
ge/graph/load/model_manager/davinci_model.h View File

@@ -300,6 +300,7 @@ class DavinciModel {
return op_list_.at(index);
}

void SetGlobalStep(void *global_step, uint64_t global_step_size);
void *GetGlobalStep() const { return global_step_addr_; }

// get task info for profiling
@@ -498,10 +499,6 @@ class DavinciModel {
return exception_dumper_.DumpExceptionInfo(exception_infos);
}

void SetKnownShapeGlobalStep(void *global_step) {
known_shape_global_step_ = global_step;
}

void DumperShrink() {
data_dumper_.DumpShrink();
}
@@ -771,6 +768,12 @@ class DavinciModel {
/// @return Status
///
Status InitTbeHandle(const OpDescPtr &op_desc);
Status InitTbeHandleWithFfts(const OpDescPtr &op_desc);
Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts,
size_t thread_index = 0);
Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary);
Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle);
Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name);

void StoreTbeHandle(const string &handle_key);
void CleanTbeHandle();
@@ -1102,9 +1105,6 @@ class DavinciModel {
vector<InputOutputDescInfo> output_descs_;
vector<uint32_t> output_formats_;

// known shape node for dump
void *known_shape_global_step_;

// op name to attrs mapping
std::map<std::string, std::map<std::string, std::vector<std::string>>> op_name_to_attrs_;
};


+ 17
- 10
ge/graph/load/model_manager/model_manager.cc View File

@@ -570,6 +570,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<ge::Te
uint32_t length = static_cast<uint32_t>(cur_dynamic_dims.size() * sizeof(int32_t));
GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK,
REPORT_CALL_ERROR("E19999", "memcpy data failed, size:%u", length);
delete[] reinterpret_cast<int32_t *>(data.data);
return INTERNAL_ERROR, "[Memcpy][Data] failed, size:%u.", length);
data.length = length;
input_data.blobs.push_back(data);
@@ -1378,7 +1379,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_
Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str());
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
if (cust_aicpu_so_.size() == 0) return SUCCESS;
if (cust_aicpu_so_.empty()) {
return SUCCESS;
}
// get current context
rtContext_t rt_cur_ctx = nullptr;
auto rt_error = rtCtxGetCurrent(&rt_cur_ctx);
@@ -1394,9 +1397,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
return SUCCESS;
}

rtStream_t stream = nullptr;
vector<void *> allocated_mem;
std::function<void()> callback = [&]() {
for (auto mem : allocated_mem) {
GE_CHK_RT(rtFree(mem));
}
if (stream != nullptr) {
GE_CHK_RT(rtStreamDestroy(stream));
}
};
GE_MAKE_GUARD(release, callback);

rtError_t status;
rtStream_t stream = nullptr;
vector<CustAicpuSoBuf> v_cust_so;
void *args = nullptr;

@@ -1471,13 +1484,6 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status);
return RT_ERROR_TO_GE_STATUS(status);
}
std::function<void()> callback = [&]() {
for (auto mem : allocated_mem) {
GE_CHK_RT(rtFree(mem));
}
GE_CHK_RT(rtStreamDestroy(stream));
};
GE_MAKE_GUARD(release, callback);
GELOGI("Cpu kernel launch task success.");
return SUCCESS;
}
@@ -1786,7 +1792,8 @@ Status ModelManager::LaunchKernelCheckAicpuOp(std::vector<std::string> &aicpu_op
std::vector<char> op_name;
op_name.clear();
op_name.resize(kOpNameMaxSize);
GE_CHK_RT(rtMemcpy(op_name.data(), aicpu_info.opLen, reinterpret_cast<void *>(aicpu_info.opType),
GE_CHK_RT(rtMemcpy(op_name.data(), aicpu_info.opLen,
reinterpret_cast<void *>(static_cast<uintptr_t>(aicpu_info.opType)),
aicpu_info.opLen, RT_MEMCPY_DEVICE_TO_HOST));
std::string kernel_type =
(static_cast<OpKernelType>(aicpu_info.kernelsType) == TF_KERNEL) ? "TF_KERNEL" : "CPU_KERNEL";


+ 393
- 0
ge/graph/load/model_manager/task_info/ffts_task_info.cc View File

@@ -0,0 +1,393 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/load/model_manager/task_info/ffts_task_info.h"

#include <vector>

#include "graph/load/model_manager/davinci_model.h"

namespace {
constexpr uint32_t kAddrLen = sizeof(void *);
}
namespace ge {
FftsTaskInfo::~FftsTaskInfo() {
GE_FREE_RT_LOG(args_);
}

Status FftsTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
GELOGI("FftsTaskInfo Init Start.");
GE_CHECK_NOTNULL(davinci_model);
davinci_model_ = davinci_model;
GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList()));

const domi::FftsTaskDef &ffts_task_def = task_def.ffts_task();
OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_task_def.op_index());
GE_CHECK_NOTNULL(op_desc);

if ((ffts_task_def.sub_task_size() > static_cast<int>(RT_FFTS_MAX_SUB_TASK_NUM)) ||
(ffts_task_def.ticket_cache_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_NUM))) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Node: %s, sub task desc size: %d, ticket cache size: %d",
op_desc->GetName().c_str(), ffts_task_def.sub_task_size(), ffts_task_def.ticket_cache_size());
return INTERNAL_ERROR;
}
args_size_ = kAddrLen * ffts_task_def.addr_size();
GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM));
InitFftsDescInfo(ffts_task_def.ffts_desc(), sub_task_info_.fftsDesc);

sub_task_info_.fftsType = static_cast<rtFftsType_t>(ffts_task_def.ffts_type());
sub_task_info_.subTaskNum = ffts_task_def.sub_task_size();
for (int idx = 0; idx < ffts_task_def.sub_task_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitSubTaskInfo(ffts_task_def.sub_task(idx), sub_task_info_.subTask[idx]));
}

sub_task_info_.tickCacheNum = ffts_task_def.ticket_cache_size();
for (int idx = 0; idx < ffts_task_def.ticket_cache_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitTicketCache(ffts_task_def.ticket_cache(idx), sub_task_info_.ticketCache[idx]));
}

size_t data_size = kAddrLen * io_addrs_.size();
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE));
GELOGI("FftsTaskInfo::Init Success. Node: %s, input/output size: %zu", op_desc->GetName().c_str(), io_addrs_.size());
return SUCCESS;
}

void FftsTaskInfo::InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc) {
ffts_desc.tm = static_cast<uint8_t>(ffts_desc_def.tm());
ffts_desc.di = static_cast<uint8_t>(ffts_desc_def.di());
ffts_desc.dw = static_cast<uint8_t>(ffts_desc_def.dw());
ffts_desc.df = static_cast<uint8_t>(ffts_desc_def.df());
ffts_desc.dataSplitUnit = static_cast<uint8_t>(ffts_desc_def.data_split_unit());
ffts_desc.prefetchOstNum = static_cast<uint8_t>(ffts_desc_def.prefetch_ost_num());
ffts_desc.cacheMaintainOstNum = static_cast<uint8_t>(ffts_desc_def.cache_maintain_ost_num());
ffts_desc.aicPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_upper());
ffts_desc.aicPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_lower());
ffts_desc.aivPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_upper());
ffts_desc.aivPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_lower());
}

Status FftsTaskInfo::InitSubTaskInfo(const domi::FftsSubTaskDef &sub_task_def, rtFftsSubTaskInfo_t &sub_task_desc) {
if ((sub_task_def.dst_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) ||
(sub_task_def.src_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, dst tick cache id size: %d, src tick cache id size: %d",
sub_task_def.dst_tick_cache_id_size(), sub_task_def.src_tick_cache_id_size());
return FAILED;
}

if (sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv()) {
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, auto thread aic/aiv: %d, manual thread aic/aiv: %d",
sub_task_def.has_auto_thread_aic_aiv(), sub_task_def.has_manual_thread_aic_aiv());
return FAILED;
}

thread_dim_ = sub_task_def.thread_dim();
GE_CHK_BOOL_RET_STATUS(thread_dim_ != 0, FAILED, "[Get][thread_dim] failed, Invalid thread dim: %u!", thread_dim_);

sub_task_desc.subTaskType = static_cast<rtFftsSubTaskType_t>(sub_task_def.sub_task_type());
sub_task_desc.threadDim = sub_task_def.thread_dim();

sub_task_desc.dstTickCacheVldBitmap = sub_task_def.dst_tick_cache_vld_bitmap();
sub_task_desc.srcTickCacheVldBitmap = sub_task_def.src_tick_cache_vld_bitmap();
sub_task_desc.srcDataOutOfSubGraphBitmap = sub_task_def.src_data_out_of_subgraph_bitmap();

for (int idx = 0; idx < sub_task_def.dst_tick_cache_id_size(); ++idx) {
sub_task_desc.dstTickCacheID[idx] = sub_task_def.dst_tick_cache_id(idx);
}

for (int idx = 0; idx < sub_task_def.src_tick_cache_id_size(); ++idx) {
sub_task_desc.srcTickCacheID[idx] = sub_task_def.src_tick_cache_id(idx);
}

if (sub_task_def.has_auto_thread_aic_aiv()) {
GE_CHK_STATUS_RET_NOLOG(InitAutoAicAiv(sub_task_def.auto_thread_aic_aiv(), sub_task_desc.custom.autoThreadAicAiv));
}

if (sub_task_def.has_manual_thread_aic_aiv()) {
GE_CHK_STATUS_RET_NOLOG(
InitManualAicAiv(sub_task_def.manual_thread_aic_aiv(), sub_task_desc.custom.manualThreadAicAiv));
}

if (sub_task_def.has_manual_thread_nop()) {
GE_CHK_STATUS_RET_NOLOG(InitManualNop(sub_task_def.manual_thread_nop(), sub_task_desc.custom.manualThreadNop));
}

return SUCCESS;
}

Status FftsTaskInfo::InitTicketCache(const domi::TicketCacheDef &ticket_cache_def, rtTicketCache_t &ticket_cache) {
if (ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache()) {
GELOGE(FAILED, "[Check][Param] Invalid TicketCacheDef, has auto thread cache: %d, has manual thread cache: %d",
ticket_cache_def.has_auto_thread_cache(), ticket_cache_def.has_manual_thread_cache());
return FAILED;
}

ticket_cache.cacheOption = static_cast<rtCacheOp_t>(ticket_cache_def.cache_option());
ticket_cache.ticketCacheWindow = ticket_cache_def.ticket_cache_window();

if (ticket_cache_def.has_auto_thread_cache()) {
InitAutoCacheInfo(ticket_cache_def.auto_thread_cache(), ticket_cache.custom.autoThreadCache);
}
if (ticket_cache_def.has_manual_thread_cache()) {
GE_CHK_STATUS_RET_NOLOG(
InitManualCacheInfo(ticket_cache_def.manual_thread_cache(), ticket_cache.custom.manualThreadCache));
}

return SUCCESS;
}

// task_addr = {0,200,700,1000,2000, 3500}
// task_addr_offset = {20,40,2,100,200}
template <typename T>
Status FftsTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim,
uint32_t addr_count) {
for (uint32_t i = 0; i < addr_count; ++i) {
uintptr_t logic_addr = aic_aiv_def.task_addr(i) + thread_dim * aic_aiv_def.task_addr_offset(i);
uint8_t *io_addr = nullptr;
if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress]GetRtAddress failed.");
return INTERNAL_ERROR;
}
GELOGD("aic_aiv_def task base addr is %ld, offset is %ld, thread is %d, logic addrs is 0x%lx, io addr is %p",
aic_aiv_def.task_addr(i), aic_aiv_def.task_addr_offset(i), thread_dim, logic_addr, io_addr);
io_addrs_.emplace_back(io_addr);
}
return SUCCESS;
}

Status FftsTaskInfo::InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv) {
if (aic_aiv_def.src_prefetch_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
GELOGE(FAILED, "[Check][Param] Invalid AutoThreadAicAivInfo, prefetch size: %d", aic_aiv_def.src_prefetch_size());
return FAILED;
}

aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
GELOGD("AutoThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
const auto &rts_param = davinci_model_->GetRuntimeParam();
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
}
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
for (int k = 0; k < last_thread_workspace_size; ++k) {
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
uint8_t *io_addr = nullptr;
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
GELOGD("logic addr is 0x%lx, io addr is %p.", logic_addr, io_addr);
io_addrs_.emplace_back(io_addr);
}

aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();
GELOGD("args_: %p, io_addrs size: %zu, task param offset: %u.", args_, io_addrs_.size(), aic_aiv.taskParamOffset);
aic_aiv.satMode = aic_aiv_def.sat_mode();
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();

aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap();
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap();

aic_aiv.tailBlkDim = aic_aiv_def.tail_blk_dim();
aic_aiv.nonTailBlkDim = aic_aiv_def.non_tail_blk_dim();

aic_aiv.nonTailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.non_tail_task_func_stub(), "");
aic_aiv.tailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.tail_task_func_stub(), "");

GELOGI("Set func name[%s][%s] succ.", aic_aiv.nonTailTaskFuncStub, aic_aiv.tailTaskFuncStub);
for (int idx = 0; idx < aic_aiv_def.src_prefetch_size(); ++idx) {
InitAutoPrefetch(aic_aiv_def.src_prefetch(idx), aic_aiv.srcPrefetch[idx]);
}

return SUCCESS;
}

void FftsTaskInfo::InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache) {
cache.dataAddr = cache_def.data_addr();
cache.dataAddrOffset = cache_def.data_addr_offset();
cache.nonTailDataLen = cache_def.non_tail_data_len();
cache.tailDataLen = cache_def.tail_data_len();
cache.ticketCacheRefCnt = cache_def.ticket_cache_ref_cnt();
}

void FftsTaskInfo::InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch) {
prefetch.dataAddr = prefetch_def.data_addr();
prefetch.dataAddrOffset = prefetch_def.data_addr_offset();
prefetch.nonTailDataLen = prefetch_def.non_tail_data_len();
prefetch.tailDataLen = prefetch_def.tail_data_len();
}

Status FftsTaskInfo::InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def,
rtManualThreadAicAivInfo_t &aic_aiv) {
if ((aic_aiv_def.thread_prefetch_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.thread_blk_dim_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.thread_task_func_stub_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadAicAivInfo, thread prefetch dmu desc size: %d, "
"thread blk dim size: %d, thread task func stub size: %d, src dep tbl size: %d",
aic_aiv_def.thread_prefetch_dmu_idx_size(), aic_aiv_def.thread_blk_dim_size(),
aic_aiv_def.thread_task_func_stub_size(), aic_aiv_def.src_dep_tbl_size());
return FAILED;
}
aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
GELOGD("ManualThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
const auto &rts_param = davinci_model_->GetRuntimeParam();
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
}
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
for (int k = 0; k < last_thread_workspace_size; ++k) {
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
uint8_t *io_addr = nullptr;
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
io_addrs_.emplace_back(io_addr);
}
aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();

aic_aiv.satMode = aic_aiv_def.sat_mode();
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();

aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); // 8 bit bitmap 1 0 1 0
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); // 8 bit bitmap 1 0 1 0
aic_aiv.prefetchOnceDmuNum = aic_aiv_def.prefetch_once_dmu_num();

for (int idx = 0; idx < aic_aiv_def.thread_prefetch_dmu_idx_size(); ++idx) {
aic_aiv.threadPrefetchDmuIdx[idx] = aic_aiv_def.thread_prefetch_dmu_idx(idx);
}
for (int idx = 0; idx < aic_aiv_def.thread_blk_dim_size(); ++idx) {
aic_aiv.threadBlkDim[idx] = aic_aiv_def.thread_blk_dim(idx);
}
for (int idx = 0; idx < aic_aiv_def.thread_task_func_stub_size(); ++idx) {
aic_aiv.threadTaskFuncStub[idx] = aic_aiv_def.thread_task_func_stub(idx).c_str();
}

InitManualDmuInfo(aic_aiv_def, aic_aiv.prefetchList);
for (int idx = 0; idx < aic_aiv_def.src_dep_tbl_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(aic_aiv_def.src_dep_tbl(idx), aic_aiv.srcDepTbl[idx]));
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def,
rtManualThreadCacheInfo_t &cache_info) {
if ((cache_def.slice_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(cache_def.ticket_cache_ref_cnt_tbl_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM))) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadCacheInfo slice dum desc index %d, ticket cache ref cnt %d",
cache_def.slice_dmu_idx_size(), cache_def.ticket_cache_ref_cnt_tbl_size());
return FAILED;
}

InitManualDmuInfo(cache_def, cache_info.dmuList);
for (int idx = 0; idx < cache_def.slice_dmu_idx_size(); ++idx) {
cache_info.sliceDmuIdx[idx] = cache_def.slice_dmu_idx(idx);
}

for (int idx = 0; idx < cache_def.ticket_cache_ref_cnt_tbl_size(); ++idx) {
cache_info.ticketCacheRefCntTbl[idx] = cache_def.ticket_cache_ref_cnt_tbl(idx);
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualDependency(const domi::ManualThreadDependencyDef &dependency_def,
rtManualThreadDependency_t &dependency) {
if (dependency_def.dependency_size() > static_cast<int>(RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN)) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadDependency size: %d", dependency_def.dependency_size());
return FAILED;
}

for (int idx = 0; idx < dependency_def.dependency_size(); ++idx) {
dependency.dependency[idx] = dependency_def.dependency(idx);
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop_info) {
if (nop_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadNopInfo, src dep tbl size: %d", nop_def.src_dep_tbl_size());
return FAILED;
}

for (int idx = 0; idx < nop_def.src_dep_tbl_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(nop_def.src_dep_tbl(idx), nop_info.srcDepTbl[idx]));
}

return SUCCESS;
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu) {
if (aic_aiv_def.prefetch_list().empty()) {
return;
}

std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * aic_aiv_def.prefetch_list_size());
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data());
for (int idx = 0; idx < aic_aiv_def.prefetch_list_size(); ++idx) {
InitManualDmuInfo(aic_aiv_def.prefetch_list(idx), dmu[idx]);
}
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu) {
if (cache_def.dmu_list().empty()) {
return;
}

std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * cache_def.dmu_list_size());
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data());
for (int idx = 0; idx < cache_def.dmu_list_size(); ++idx) {
InitManualDmuInfo(cache_def.dmu_list(idx), dmu[idx]);
}
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu) {
dmu.dataAddr = dmu_def.data_addr();
dmu.numOuter = dmu_def.num_outer();
dmu.numInner = dmu_def.num_inner();
dmu.strideOuter = dmu_def.stride_outer();
dmu.lenInner = dmu_def.len_inner();
dmu.strideInner = dmu_def.stride_inner();
}

Status FftsTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
return SUCCESS;
}

Status FftsTaskInfo::UpdateArgs() {
GE_CHECK_NOTNULL(davinci_model_);
std::vector<void *> io_addrs = io_addrs_;
davinci_model_->UpdateKnownZeroCopyAddr(io_addrs);
auto addr_size = kAddrLen * io_addrs.size();
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE));
return SUCCESS;
}

Status FftsTaskInfo::Distribute() {
GELOGI("FftsTaskInfo Distribute Start.");
rtError_t rt_ret = rtFftsTaskLaunch(&sub_task_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsTaskLaunch failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

GELOGI("FftsTaskInfo Distribute Success.");
return SUCCESS;
}

REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_TASK, FftsTaskInfo);
} // namespace ge

+ 66
- 0
ge/graph/load/model_manager/task_info/ffts_task_info.h View File

@@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_

#include "graph/load/model_manager/task_info/task_info.h"
#include "graph/op_desc.h"

namespace ge {
class FftsTaskInfo : public TaskInfo {
public:
FftsTaskInfo() = default;
~FftsTaskInfo() override;

Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

Status Distribute() override;

Status UpdateArgs() override;

Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

private:
void InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc);
Status InitSubTaskInfo(const domi::FftsSubTaskDef &task_def, rtFftsSubTaskInfo_t &task);
Status InitTicketCache(const domi::TicketCacheDef &cache_def, rtTicketCache_t &cache);

Status InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv);
void InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache);
void InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch);

Status InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadAicAivInfo_t &aic_aiv);
Status InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadCacheInfo_t &cache);
Status InitManualDependency(const domi::ManualThreadDependencyDef &depend_def, rtManualThreadDependency_t &depend);
Status InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop);

void InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu);
void InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu);
void InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu);

template<typename T>
Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, uint32_t addr_count);

DavinciModel *davinci_model_{nullptr};
rtFftsTaskInfo_t sub_task_info_;
std::vector<void *> io_addrs_;
uint32_t thread_dim_{0};
void *args_{nullptr}; // runtime args memory
uint32_t args_size_{0}; // runtime args memory length
};
} // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_

+ 1
- 1
ge/graph/load/model_manager/task_info/hccl_task_info.cc View File

@@ -329,7 +329,7 @@ void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) {
// Get privateDef and opsKernelStorePtr from taskDef and save them in taskInfo
GELOGI("get custom info in modelTaskDef.");
ops_kernel_store_ = nullptr;
void *ops_kernel_store_name_temp = reinterpret_cast<void *>(task.ops_kernel_store_ptr());
void *ops_kernel_store_name_temp = reinterpret_cast<void *>(static_cast<uintptr_t>(task.ops_kernel_store_ptr()));
if (ops_kernel_store_name_temp != nullptr) {
ops_kernel_store_ = std::move(ops_kernel_store_name_temp);
std::string private_def_temp = task.private_def();


+ 2
- 0
ge/graph/load/model_manager/task_info/kernel_task_info.cc View File

@@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne
GE_CHECK_NOTNULL(op_desc);

args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) {
REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret);
@@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k

// copy args to new host memory
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_)
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) {


+ 1
- 1
ge/graph/load/model_manager/task_info/memcpy_async_task_info.h View File

@@ -47,7 +47,7 @@ class MemcpyAsyncTaskInfo : public TaskInfo {
uint64_t count_;
uint32_t kind_;
vector<void *> io_addrs_;
int64_t fixed_addr_offset_;
int64_t fixed_addr_offset_ = 0;
DavinciModel *davinci_model_ = nullptr;
uint32_t args_offset_ = 0;
};


+ 4
- 2
ge/graph/load/model_manager/zero_copy_offset.cc View File

@@ -62,7 +62,8 @@ Status ZeroCopyOffset::InitInputDataInfo(int64_t output_size, void *virtual_addr
for (size_t index = 0; index < zero_copy_basic_offset_.size(); ++index) {
if (zero_copy_basic_offset_.at(index) == virtual_addr_offset) {
out_count++;
uint64_t out_offset = reinterpret_cast<uint64_t>(virtual_addr) + zero_copy_relative_offset_.at(index);
uint64_t out_offset = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(virtual_addr)) +
zero_copy_relative_offset_.at(index);
data_info_.emplace_back(output_size, reinterpret_cast<void *>(static_cast<uintptr_t>(out_offset)));
relative_offset_.emplace_back(zero_copy_relative_offset_.at(index));
GELOGI("[ZCPY] virtual_addr: %p has been l2-fusion to %lu, need copy data_size is %ld.", basic_addr_,
@@ -117,7 +118,8 @@ Status ZeroCopyOffset::InitOutputDataInfo(const vector<int64_t> &input_size_list
for (size_t index = 0; index < zero_copy_basic_offset_.size(); ++index) {
if (zero_copy_basic_offset_.at(index) == virtual_addr_offset) {
in_count++;
uint64_t in_offset = reinterpret_cast<uint64_t>(virtual_addr_list[idx]) + zero_copy_relative_offset_.at(index);
uint64_t in_offset = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(virtual_addr_list[idx])) +
zero_copy_relative_offset_.at(index);
int64_t real_data_size = ModelUtils::GetInputSize(op_desc).at(idx);
data_info_.emplace_back(real_data_size, reinterpret_cast<void *>(static_cast<uintptr_t>(in_offset)));
relative_offset_.emplace_back(zero_copy_relative_offset_.at(index));


+ 20
- 26
ge/graph/manager/graph_manager.cc View File

@@ -27,6 +27,7 @@
#include "common/math/math_util.h"
#include "common/thread_pool.h"
#include "common/dump/dump_manager.h"
#include "ge_opt_info/ge_opt_info.h"
#include "analyzer/analyzer.h"
#include "graph/common/ge_call_wrapper.h"
#include "graph/common/local_context.h"
@@ -120,7 +121,6 @@ const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar";
const char *const kCheckPointGraph = "checkpoint_graph";
const char *const kVectorEngine = "VectorEngine";
const char *const kAIcoreEngine = "AIcoreEngine";
const char *const kRunFlagOffline = "0";
const int32_t kDynamicDimsTypeIsGetNext = 0;
const int32_t kDynamicDimsTypeIsData = 1;
const char *const kGetNextName = "IteratorV2";
@@ -950,7 +950,7 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint

rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate faileded, session_id:%lu, graph_id:%u, mode:%d",
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, session_id:%lu, graph_id:%u, mode:%d",
session_id, graph_id, mode);
GELOGE(FAILED, "[Call][RtCtxCreate] faileded, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode);
return FAILED;
@@ -1002,6 +1002,12 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge
return ret;
}

ret = GeOptInfo::SetOptInfo();
if (ret != SUCCESS) {
GELOGE(ret, "[Set][OptInfo] Set optional information failed.");
return ret;
}

/// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph;
/// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph.
/// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph.
@@ -1789,8 +1795,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return GE_GRAPH_OPTIONS_INVALID);

// ge.graphType
ret =
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid");
return GE_GRAPH_OPTIONS_INVALID);
@@ -2436,6 +2441,8 @@ Status GraphManager::RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute
continue;
}
if (n->GetOpDesc()->GetType() == CONSTANT || n->GetOpDesc()->GetType() == CONSTANTOP) {
// reset const type depend on train_flag
options_.train_graph_flag ? n->GetOpDesc()->SetType(CONSTANTOP) : n->GetOpDesc()->SetType(CONSTANT);
if (n->GetOutAllNodes().empty() && n->GetInAllNodes().empty()) {
// it is an isolated constant, just remove it
if (GraphUtils::RemoveJustNode(compute_graph, n) != GRAPH_SUCCESS) {
@@ -2762,35 +2769,22 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
"Please pay attention to it.");
}

GE_CHK_STATUS_RET(ChangeConstType(compute_graph));
ChangeConstTypeWhenTraining(compute_graph);

GELOGI("End optimize after merge sub graph.");
return SUCCESS;
}

Status GraphManager::ChangeConstType(const ComputeGraphPtr &compute_graph) {
// run_flag off means offline, on means online
string run_flag;
(void)ge::GetContext().GetOption(ge::RUN_FLAG, run_flag);
// The constant for online is CONSTANTOP, and is CONSTANT for offline. They will be unified in future.
if (run_flag == kRunFlagOffline) {
GELOGI("Offline mode, change all Constant to Const.");
} else {
GELOGI("Online mode, change all Const to Constant.");
}
for (NodePtr &n : compute_graph->GetAllNodes()) {
GE_CHECK_NOTNULL(n);
if (n->GetType() == CONSTANT || n->GetType() == CONSTANTOP) {
auto op_desc = n->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (run_flag == kRunFlagOffline) {
op_desc->SetType(CONSTANT);
} else {
op_desc->SetType(CONSTANTOP);
void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph) {
// The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future.
if (options_.train_graph_flag) {
for (NodePtr &n : compute_graph->GetAllNodes()) {
// This can ensure that n is not a null pointer
if (n->GetOpDesc()->GetType() == CONSTANT) {
n->GetOpDesc()->SetType(CONSTANTOP);
}
}
}
return SUCCESS;
}

Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
@@ -3145,10 +3139,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
}
// Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency
if (count > 1 && graph_node->GetBuildFlag()) {
graph_node->Lock();
GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id);
// In online inference concurrency senario, graph_node is allowed to be locked for 'count' times
graph_node->SetSemSize(count);
graph_node->Lock();
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context,
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback }));
GELOGI("[PreRunThread] Loop end. Start to run with cached build model.");


+ 1
- 1
ge/graph/manager/graph_manager.h View File

@@ -375,7 +375,7 @@ class GraphManager {
static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback,
Status ret, const string &log);

Status ChangeConstType(const ComputeGraphPtr &compute_graph);
void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph);

Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs,
ge::ComputeGraphPtr &compute_graph, uint64_t session_id);


+ 49
- 11
ge/graph/manager/graph_var_manager.cc View File

@@ -20,6 +20,7 @@
#include "graph/manager/graph_mem_manager.h"
#include "graph/manager/trans_var_data_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/ge_context.h"

using std::map;
using std::string;
@@ -816,25 +817,60 @@ Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &grap
return var_resource_->GetChangedGraphId(var_name, graph_id);
}

Status VarManager::GetTotalMemorySize(size_t &total_mem_size) {
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X",
GetContext().DeviceId(), rt_ret);
GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
return RT_FAILED;
}
size_t free_mem = 0;
rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret);
GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret);
return RT_FAILED;
}
if (total_mem_size == 0) {
rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_DDR, &free_mem, &total_mem_size);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret);
GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret);
return RT_FAILED;
}
}
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X",
GetContext().DeviceId(), rt_ret);
GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
return RT_FAILED;
}
return SUCCESS;
}

Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
if (it == options.end()) {
graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
} else {
string graph_memory_manager_malloc_max_size = it->second;
size_t total_mem_size = 0;
GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size));
GEEVENT("Total memory size is %zu", total_mem_size);

graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio);
var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio);

auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE);
if (it1 != options.end()) {
string graph_memory_manager_malloc_max_size = it1->second;
ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
if (ret != SUCCESS) {
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
return ge::GE_GRAPH_OPTIONS_INVALID;
}
GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
}

it = options.find(VARIABLE_MEMORY_MAX_SIZE);
if (it == options.end()) {
var_mem_max_size_ = kMemoryVarManagerMallocSize;
} else {
string memory_var_manager_malloc_size = it->second;
auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE);
if (it2 != options.end()) {
string memory_var_manager_malloc_size = it2->second;
ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
if (ret != SUCCESS) {
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
@@ -842,6 +878,8 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
}
}

GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_);

var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
if (var_mem_logic_base_ > kMaxMemorySize) {
REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid",


+ 3
- 0
ge/graph/manager/graph_var_manager.h View File

@@ -43,6 +43,8 @@ const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL;
const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY";
const uint64_t kSessionMemAlignSize = 512;
const size_t kSessionMemAlignUnit = 2;
const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0;
const double kVarMemoryManagerMallocRatio = 5.0 / 32.0;

enum MemStatus {
NORMAL = 0,
@@ -316,6 +318,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {
mutable std::recursive_mutex mutex_;

Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size);
Status GetTotalMemorySize(size_t &total_mem_size);
};

class VarManagerPool {


+ 0
- 2
ge/graph/optimize/graph_optimize.cc View File

@@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) {
GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str());
continue;
}
#ifndef ONLY_COMPILE_OPEN_SRC
GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str());
ret = (iter->second)->OptimizeAfterStage1(*compute_graph);
#endif
if (ret != SUCCESS) {
REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, "
"graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str());


+ 29
- 7
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() {
auto cluster = MakeShared<Cluster>(rank++, type, node, this);
REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed.");
node_2_cluster_[node] = cluster;
if (cluster->IsUnknownShape()) {
ordered_cluster_.push_back(cluster);
}

int64_t group_index = -1;
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() {
return SUCCESS;
}

Status DynamicShapePartitioner::TopologicalSortClusters() {
Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) {
ordered_cluster_.clear();
// BFS topological sort clusters for known shape cluster
std::queue<ClusterPtr> ready_clusters;
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() {
auto cluster = ready_clusters.front();
ready_clusters.pop();
cluster->UpdateRank(rank++);
if (cluster->IsKnownShape() || cluster->IsInputNode()) {
if (ordered_filter == nullptr || ordered_filter(cluster)) {
ordered_cluster_.push_back(cluster);
}
for (const auto &out_cluster : cluster->Outputs()) {
@@ -364,6 +361,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
}

void DynamicShapePartitioner::MergeClustersControlFlow() {
std::unordered_set<ClusterPtr> all_merged_clusters;
for (const auto &item : control_clusters_) {
const auto &control_cluster = item.second;
auto rit = control_cluster.rbegin();
@@ -373,12 +371,21 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
}

const auto &cluster = *rit;
if (all_merged_clusters.count(cluster) > 0) {
continue;
}

for (++rit; rit != control_cluster.rend(); ++rit) {
const auto &cluster_from = *rit;
if (all_merged_clusters.count(cluster_from) > 0) {
continue;
}

auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
ToString(merged_clusters).c_str());
for (const auto &merged_cluster : merged_clusters) {
all_merged_clusters.emplace(merged_cluster);
for (const auto &node : merged_cluster->Nodes()) {
node_2_cluster_[node] = cluster;
}
@@ -459,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() {
}

Status DynamicShapePartitioner::MergeClusters() {
const auto filter_known = [](const ClusterPtr &cluster) {
return cluster->IsKnownShape() || cluster->IsInputNode();
};
const auto filter_unknown = [](const ClusterPtr &cluster) {
return cluster->IsUnknownShape();
};

MergeClustersControlFlow();
REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown),
"[TopologicalSort][Clusters] after merge control flow clusters failed.");
MergeClustersUnknownShape();
REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
REQUIRE_SUCCESS(TopologicalSortClusters(filter_known),
"[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
MergeClustersKnownShape();
MergeClustersInputData();
return SUCCESS;
@@ -703,7 +720,12 @@ void Cluster::Merge(ClusterPtr other) {
if (other->min_ < min_) {
min_ = other->min_;
}
};

if (!IsUnknownShape() && other->IsUnknownShape()) {
type_ = UNKNOWN_SHAPE;
}
}

bool Cluster::TryMerge(ClusterPtr other) {
std::queue<ClusterPtr> forward_reached;
forward_reached.push(other);


+ 4
- 2
ge/graph/partition/dynamic_shape_partition.h View File

@@ -111,6 +111,8 @@ class DynamicShapePartitioner {

Status Partition();

using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>;

private:
Status PartitionImpl();
// Collect nodes that satisfy the unknowshape rules:
@@ -138,7 +140,7 @@ class DynamicShapePartitioner {
// Merge clusters step3
void MergeClustersInputData();
// Topological sort clusters after merge unknown shape clusters.
Status TopologicalSortClusters();
Status TopologicalSortClusters(const OrderedFilter &ordered_filter);
// Deduplicate merged clusters
void PruneUniqueClusters();
// Establish the input-output anchors for each partition of the cluster and record links to other clusters
@@ -161,7 +163,7 @@ class DynamicShapePartitioner {
ge::ComputeGraphPtr root_graph_; // The original graph to partition
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to
// V1 control flow cluster, need merge to one Graph.
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
// topological sorted clusters, this field will change with the splitting.
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters


+ 14
- 7
ge/graph/partition/graph_partition.cc View File

@@ -179,6 +179,7 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
}
GE_CHECK_NOTNULL(original_compute_graph);
output_merged_compute_graph->SetName(original_compute_graph->GetName());
// partition sub graph
for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) {
ComputeGraphPtr merged_sub_graph = nullptr;
@@ -188,8 +189,16 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
continue;
}
// this means subgraph added in optimize subgraph and without partitions, so just add to root graph
if (merged_sub_graph == sub_graph) {
GELOGI("Just add subgraph %s (parent node is %s) to root graph %s.", sub_graph->GetName().c_str(),
sub_graph->GetParentNode()->GetName().c_str(), output_merged_compute_graph->GetName().c_str());
sub_graph->SetParentGraph(sub_graph->GetParentNode()->GetOwnerComputeGraph());
GE_IF_BOOL_EXEC(output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph) != SUCCESS,
return FAILED;)
continue;
}
// add sub graph
output_merged_compute_graph->SetName(original_compute_graph->GetName());
merged_sub_graph->SetName(sub_graph->GetName());
merged_sub_graph->SetInputSize(sub_graph->GetInputSize());
merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize());
@@ -245,12 +254,9 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co
}
if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) ||
(graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) {
REPORT_INNER_ERROR("E19999", "original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
original_compute_graph->GetName().c_str());
GELOGE(GE_GRAPH_NULL_INPUT,
"[Check][Param] original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
original_compute_graph->GetName().c_str());
return FAILED;
GELOGW("[GraphPartition]: compute_graph has not found, just return original.");
output_merged_compute_graph = original_compute_graph;
return SUCCESS;
}
GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph];
const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph];
@@ -708,6 +714,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr
}
auto &engine_name = graph_info_.partitions_.at(sub_graph);
(void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName());
(void)sub_graph->SetExtAttr("part_src_graph", compute_graph);
GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(),
compute_graph->GetName().c_str());
GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]);


+ 107
- 70
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -16,8 +16,6 @@

#include "mark_force_unknown_for_cond_pass.h"

#include <queue>

#include "graph/utils/node_utils.h"
#include "graph/common/omg_util.h"

@@ -26,17 +24,7 @@ namespace {
inline bool IsMergeInLoop(const NodePtr &node) {
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };

std::string node_type;
(void)GetOriginalType(node, node_type);
return kLoopMergeInputs.count(node_type) > 0;
}

inline bool IsSwitchInLoop(const NodePtr &node) {
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND };

std::string node_type;
(void)GetOriginalType(node, node_type);
return kLoopSwitchInputs.count(node_type) > 0;
return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0;
}
}

@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
GELOGD("MarkForceUnknownForCondPass Enter");
std::map<NodePtr, std::vector<NodePtr>> switch_groups;
for (const auto &node : graph->GetDirectNode()) {
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
"[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str());
if (kMergeOpTypes.count(node_type) == 0) {
if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) {
continue;
}

@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
}

///
/// @brief Deal with Switch node for LoopCond
/// @param [in] Switch node
/// @param [in] dest span
/// @param [out] Search queue
/// @return true: Switch In while loop / false: Not in while Loop.
///
bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span,
std::queue<std::pair<NodePtr, uint32_t>> &search_queue) {
/// LoopCond --->\.
/// \.
/// Enter-----------+ \.
/// +--> Merge --> Switch --> Exit
/// NextIteration---+
const auto is_loop_op = [](const NodePtr &n) {
return NodeUtils::GetNodeType(n) == LOOPCOND;
};
const auto is_exit_op = [](const NodePtr &n) {
return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0;
};

const auto src_nodes = node->GetInAllNodes();
const auto dst_nodes = node->GetOutAllNodes();
if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) &&
std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) {
return false;
}

for (const auto &m : src_nodes) {
if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) {
for (const auto &n : m->GetInAllNodes()) {
if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) {
continue;
}

search_queue.push({n, dst_span});
GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(),
n->GetName().c_str(), dst_span);
}
}
}

return true;
}

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch group
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) {
// Switch --> {Switch --> Merge} --> Merge
GELOGD("Search Switch node for Merge: %s", node->GetName().c_str());
std::unordered_set<NodePtr> nodes_seen;
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}});
while (!search_queue.empty()) {
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
const auto dst_span = search_queue.front().second;
search_queue.pop();

// Switch --> Identity --> Constant
for (const auto &in_node : dst_node->GetInControlNodes()) {
for (const auto &in_node : dst_node->GetInAllNodes()) {
if (nodes_seen.count(in_node) > 0) {
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
continue;
}
nodes_seen.insert(in_node);

if (in_node->GetType() == IDENTITY) {
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(),
in_node->GetName().c_str(), dst_span);
search_queue.push({in_node, dst_span});
}
}

for (const auto &in_node : dst_node->GetInDataNodes()) {
if (nodes_seen.count(in_node) > 0) {
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
continue;
}
nodes_seen.insert(in_node);

std::string node_type;
(void)GetOriginalType(in_node, node_type);
const std::string node_type = NodeUtils::GetNodeType(in_node);
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(),
in_node->GetName().c_str(), dst_span);
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node.
if (DealAsLoopSwitch(in_node, dst_span, search_queue)) {
continue;
}

if (dst_span > 0) {
search_queue.push({in_node, dst_span - 1});
} else {
const auto &all_in_nodes = in_node->GetInDataNodes();
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) {
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(),
in_node->GetName().c_str());
} else {
switch_group.emplace_back(in_node);
}
switch_group.emplace_back(in_node);
}
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node.
search_queue.push({in_node, dst_span + 1});
@@ -132,39 +145,63 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
/// @return
///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP);
};
// Step 0: no group assigned. such as:
// Merge1{id=0, group=} => {Switch1{id=1, group=}, Switch2{id=2, group=}}
// Merge2{id=3, group=} => {Switch1{id=1, group=}, Switch3{id=4, group=}}
// Merge3{id=5, group=} => {Switch4{id=6, group=}, Switch5{id=7, group=}}
// Merge4{id=8, group=} => {Switch1{id=1, group=}, Switch5{id=7, group=}}
std::map<int64_t, int64_t> unique_groups;
const auto get_group_index = [&unique_groups](const NodePtr &merge, const std::vector<NodePtr> &switch_group) {
int64_t group_index = merge->GetOpDesc()->GetId();
std::set<int64_t> group_ids{group_index};
for (const auto &node : switch_group) {
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
GELOGI("[%s] Get group from [%s], index[%ld]", merge->GetName().c_str(), node->GetName().c_str(), group_index);
group_ids.insert(group_index);
}
}

for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) {
const auto &op_node1 = it1->first;
const auto &op_desc1 = op_node1->GetOpDesc();
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
const auto it = unique_groups.find(group_index);
if (it != unique_groups.end()) {
group_index = it->second;
}

if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) {
int64_t group_index = op_desc1->GetId();
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
MarkForceUnknownShape(op_node1, true, group_index);
for (const auto &n : it1->second) {
MarkForceUnknownShape(n, true, group_index);
}
for (auto id : group_ids) {
unique_groups[id] = group_index;
}

for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
const auto &op_node2 = it2->first;
const auto &op_desc2 = op_node2->GetOpDesc();
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}
return group_index;
};

if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
MarkForceUnknownShape(op_node2, true, group_index);
for (const auto &n : it2->second) {
MarkForceUnknownShape(n, true, group_index);
}
}
}
const auto set_group_index = [](const NodePtr &merge, const std::vector<NodePtr> &switch_group, int64_t group_index) {
SetControlFlowGroup(merge, group_index);
for (const auto &node : switch_group) {
SetControlFlowGroup(node, group_index);
}
};

// Step 1: Set group index to merge, if switch already has group, use assigned group.
// Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}}
// Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}}
// Merge3{id=5, group=5} => {Switch4{id=6, group=5}, Switch5{id=7, group=5}}
// Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}}
for (const auto group : switch_groups) {
int64_t group_index = get_group_index(group.first, group.second);
set_group_index(group.first, group.second, group_index);
}

// Step 2: Adjust crossed merge group for unique group.
// Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}}
// Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}}
// Merge3{id=5, group=0} => {Switch4{id=6, group=0}, Switch5{id=7, group=0}}
// Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}}
for (const auto group : switch_groups) {
int64_t group_index = -1;
(void)AttrUtils::GetInt(group.first->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);

const auto it = unique_groups.find(group_index);
if (it != unique_groups.end() && it->first != it->second) {
set_group_index(group.first, group.second, it->second);
}
}
}


+ 11
- 0
ge/graph/passes/mark_force_unknown_for_cond_pass.h View File

@@ -19,6 +19,8 @@

#include "inc/graph_pass.h"

#include <queue>

namespace ge {
class MarkForceUnknownForCondPass : public GraphPass {
public:
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass {

private:
///
/// @brief Deal with Switch node for LoopCond
/// @param [in] Switch node
/// @param [in] dest span
/// @param [out] Search queue
/// @return true: Switch In while loop / false: Not in while Loop.
///
bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue);

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch group


+ 6
- 0
ge/graph/passes/mark_graph_unknown_status_pass.cc View File

@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) {
}
}

const auto &node = graph->GetParentNode();
if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) {
GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape),
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
}

for (const auto &node : graph->GetDirectNode()) {
GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str());
(void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape);


+ 2
- 3
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
return FAILED, "[Check][Param] Param of pre node is nullptr.");
int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(node, force_unknown, group_index);
(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str());
return FAILED;
}
MarkForceUnknownShape(active_node, force_unknown, group_index);
SetControlFlowGroup(active_node, group_index);
}

return SUCCESS;


+ 23
- 6
ge/graph/passes/next_iteration_pass.cc View File

@@ -24,7 +24,9 @@ using std::string;

namespace ge {
namespace {
const int64_t kLoopType = 1;
constexpr int64_t kLoopType = 1;
constexpr uint8_t kMaxTransOp = 3;
constexpr uint8_t kTransOpIoSize = 1;
}

Status NextIterationPass::Run(ComputeGraphPtr graph) {
@@ -284,13 +286,28 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
/// @return void
///
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) {
std::string node_type;
for (const auto &switch_node : loop_group.switch_nodes) {
SetControlFlowGroup(switch_node, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) {
std::string node_type;
(void)GetOriginalType(node, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(node, group_index);
for (auto node : switch_node->GetOutDataNodes()) {
// Switch --> Exit
// Switch --> Cast --> Exit
// Switch --> TransData --> Cast --> Exit
for (uint8_t i = 0; i < kMaxTransOp; ++i) {
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) {
break;
}

if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
SetControlFlowGroup(node, group_index);
break;
}

const auto &all_nodes = node->GetOutAllNodes();
if (all_nodes.size() != kTransOpIoSize) {
break;
}
node = all_nodes.at(0);
}
}
}


+ 39
- 19
ge/graph/passes/parallel_group_pass.cc View File

@@ -15,7 +15,7 @@
*/

#include "graph/passes/parallel_group_pass.h"
#include <queue>
#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "framework/common/ge_inner_error_codes.h"
@@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu
for (const auto &switch_node : cur_itr->second.first) {
int64_t pre_id = pre_node->GetOpDesc()->GetId();
int64_t switch_id = switch_node->GetOpDesc()->GetId();
// avoid ring
if (pre_id > switch_id) {
auto merge_node = cur_itr->second.second;
if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) {
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
return FAILED;
}
} else {
if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
return FAILED;
}
NodePtr first_node = pre_node;
NodePtr second_node = switch_node;
if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) {
// avoid ring, merge->pre_node
first_node = cur_itr->second.second;
second_node = pre_node;
}
if (AddCtrlEdge(first_node, second_node) != SUCCESS) {
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
first_node->GetName().c_str(), second_node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
first_node->GetName().c_str(), second_node->GetName().c_str());
return FAILED;
}
}
} else {
@@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
stream_switch_type == kLoopType);
}

bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) {
if (node_a == nullptr || node_b == nullptr) {
GELOGW("node_a or node_b is nullptr.");
return false;
}
int64_t end_id = node_b->GetOpDesc()->GetId();
std::queue<NodePtr> nodes;
nodes.push(node_a);
while (!nodes.empty()) {
NodePtr tmp_node = nodes.front();
nodes.pop();
if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr ||
tmp_node->GetOpDesc()->GetId() > end_id) {
continue;
}
if (tmp_node == node_b) {
return true;
}
for (const auto &out_node : tmp_node->GetOutAllNodes()) {
nodes.push(out_node);
}
}
return false;
}
} // namespace ge

+ 1
- 0
ge/graph/passes/parallel_group_pass.h View File

@@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass {

bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc);
bool IsWhileStreamSwitch(OpDescPtr switch_op_desc);
bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H

+ 20
- 0
ge/graph/passes/replace_with_empty_const_pass.cc View File

@@ -21,7 +21,23 @@
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"

namespace {
const std::unordered_set<std::string> kControlFlowOps = {
ge::SWITCH,
ge::REFSWITCH,
ge::MERGE,
ge::REFMERGE,
ge::ENTER,
ge::REFENTER,
ge::NEXTITERATION,
ge::REFNEXTITERATION,
ge::EXIT,
ge::REFEXIT,
ge::LOOPCOND
};
}
namespace ge {
Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGD("ReplaceWithEmptyConstPass in.");
@@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str());
return SUCCESS;
}
if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) {
GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str());
return SUCCESS;
}
// Node like no op, it has no output
if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) {
GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str());


+ 9
- 8
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());

int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(stream_switch, force_unknown, group_index);
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
SetControlFlowGroup(stream_switch, group_index);
}
return stream_switch;
}

@@ -491,8 +492,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
std::set<NodePtr> same_cond_switch;
same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
@@ -524,13 +525,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) {
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
};
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
MarkForceUnknownShape(active_node, is_unknown_shape, group_index);
(void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
SetControlFlowGroup(active_node, group_index);

const std::string &cond_group = cond_node->GetName();
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
GE_IF_BOOL_EXEC(switch_list.empty(), continue);

// select first stream_switch
@@ -559,7 +560,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
"[Add][Edge] between %s and %s failed.",
cast_node->GetName().c_str(), stream_switch->GetName().c_str());

MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index);
SetControlFlowGroup(stream_switch, group_index);
for (const NodePtr &node : switch_list) {
GE_IF_BOOL_EXEC(node != stream_switch, {
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),


+ 68
- 55
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1420,9 +1420,10 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) {
return SUCCESS;
}

Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) {
Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc) {
auto format = desc.GetFormat();
auto origin_format = desc.GetOriginFormat();
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag);
if (need_check_internal_format) {
bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format);
@@ -1439,6 +1440,63 @@ Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTens
return SUCCESS;
}

Status GraphPrepare::UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc) {
auto data_type = desc.GetDataType();
uint32_t length = 1;
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);
if (!type_ret) {
std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" +
std::to_string(index) + " input tensor is not support";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.",
TypeUtils::DataTypeToSerialString(data_type).c_str());
return FAILED;
}
int64_t desc_shape = desc.GetShape().GetShapeSize();
FMK_INT64_UINT32_MULCHECK(desc_shape, length);
int64_t shape_size = desc_shape * length;
GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length));
int64_t size = 0;
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS,
REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index);
GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index); return FAILED);
bool size_check = (size != 0 && shape_size != size);
if (size_check) {
std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) +
"] != shape_size[" + std::to_string(size) + "], check invalid";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size);
return FAILED;
}
ge::TensorUtils::SetSize(desc, shape_size);

auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
if (!tune_flag) {
graphStatus graph_ret = op->UpdateInputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
// Size will be recalculated in the build stage
ge::TensorUtils::SetSize(desc, 0);
graph_ret = op->UpdateOutputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
} else {
GELOGI("data %s skip update info in tune mode", op->GetName().c_str());
}

return SUCCESS;
}

Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input,
const std::map<string, string> &graph_option) {
// Get shape range of input in dynamic_execute mode
@@ -1471,63 +1529,18 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input,
}
GeTensorDesc desc(user_input[index].GetTensorDesc());
// data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM.
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
ret = CheckInternalFormat(input_node, desc, tune_flag);
ret = CheckInternalFormat(input_node, desc);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str());
return ret;
}
auto data_type = desc.GetDataType();
uint32_t length = 1;
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);
if (!type_ret) {
std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" +
std::to_string(index) + " input tensor is not support";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.",
TypeUtils::DataTypeToSerialString(data_type).c_str());
return FAILED;
}
int64_t desc_shape = desc.GetShape().GetShapeSize();
FMK_INT64_UINT32_MULCHECK(desc_shape, length);
int64_t shape_size = desc_shape * length;
GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length));
int64_t size = 0;
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS,
REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index);
GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index);
return FAILED);
bool size_check = (size != 0 && shape_size != size);
if (size_check) {
std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) +
"] != shape_size[" + std::to_string(size) + "], check invalid";
REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size);
return FAILED;
}
ge::TensorUtils::SetSize(desc, shape_size);
if (!tune_flag) {
graphStatus graph_ret = op->UpdateInputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
// Size will be recalculated in the build stage
ge::TensorUtils::SetSize(desc, 0);
graph_ret = op->UpdateOutputDesc(0, desc);
if (graph_ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0",
op->GetName().c_str(), op->GetType().c_str());
return graph_ret;
}
} else {
GELOGI("data %s skip update info in tune mode", op->GetName().c_str());

ret = UpdateDataInputOutputDesc(index, op, desc);
if (ret != SUCCESS) {
GELOGE(FAILED, "[Update][DataInputOutputDesc] on %s failed", op->GetName().c_str());
return ret;
}

if (!dynamic_shape_range_vec.empty()) {
ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc);
GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str());
@@ -1742,8 +1755,8 @@ Status GraphPrepare::CtrlFlowPreProcess() {
PassManager graph_pass;

// After InferShape Mark v1 control flow for unknown shape.
auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass;
GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::MarkForceUnknownForCondPass", mark_force_unknown_pass));
GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::MarkForceUnknownForCondPass",
new (std::nothrow) MarkForceUnknownForCondPass));

GE_CHK_STATUS_RET(graph_pass.Run(compute_graph_));
return SUCCESS;


+ 2
- 1
ge/graph/preprocess/graph_preprocess.h View File

@@ -63,7 +63,8 @@ class GraphPrepare {
Status CheckRefOp();
Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode);
Status AdjustDataOpOutput(const NodePtr &node);
Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag);
Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc);
Status UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc);
Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option);
Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option);
Status CheckConstOp();


+ 1
- 0
ge/graph/preprocess/insert_op/util_insert_aipp_op.cc View File

@@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std:
}

std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams());
GE_CHECK_NOTNULL(aipp_params);
ge::GeAttrValue::NAMED_ATTRS aipp_attr;
GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST,
"[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str());


+ 1
- 1
ge/graph/preprocess/multi_batch_copy_graph.cc View File

@@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_
auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
if (!IsAllDimsPositive(dims)) {
REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s",
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s",
node->GetName().c_str(), formats::ShapeToString(dims).c_str());
return INTERNAL_ERROR;


+ 8
- 0
ge/host_kernels/fill_kernel.cc View File

@@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge
GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr.");
return PARAM_INVALID;
}
GELOGD("FillKernel in, name: %s.", op_desc_ptr->GetName().c_str());

GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex));
GE_CHECK_NOTNULL(input.at(kFillDataInputIndex));
@@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge
return NOT_CHANGED;
}

auto output_desc = op_desc_ptr->GetOutputDescPtr(0);
GE_CHECK_NOTNULL(output_desc);
if (output_desc->GetShape().IsUnknownShape()) {
GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str());
return NOT_CHANGED;
}

GeTensorPtr output_ptr;
output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
if (output_ptr == nullptr) {


+ 5
- 3
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -297,13 +297,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, Hy
}
}
tensor_desc->SetShape(shape);
args.input_desc[input_index] = tensor_desc;
GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str());
GELOGD("Update shape[%s] of input[%zu] to [%s]",
shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str());
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size),
"[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size,"
"index = %zu, shape = [%s], model_id = %u.",
input_index, tensor_desc->GetShape().ToString().c_str(), model_id_);
GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size);
GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size);
TensorUtils::SetSize(*tensor_desc, tensor_size);
args.input_desc[input_index] = tensor_desc;
}

GE_CHECK_GE(tensor_size, 0);


+ 8
- 7
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() {
Status HybridModelExecutor::Init() {
GELOGD("Start to init HybridGraphEngine.");
GE_CHK_STATUS_RET_NOLOG(InitExecutionContext());
root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_));
GE_CHECK_NOTNULL(root_graph_executor_);
GELOGD("HybridGraphEngine initialized successfully.");
return SUCCESS;
}
@@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
}
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
auto ret = ExecuteGraphInternal(executor, args);
auto ret = ExecuteGraphInternal(args);
Cleanup();
RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End");
GELOGD("Model executed successfully.");
@@ -69,6 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
context_.profiler->Dump(std::cout);
context_.profiler->Reset();
}
root_graph_executor_->ReleaseContext();

context_.iteration += 1;
if (ret == END_OF_SEQUENCE) {
@@ -79,8 +81,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
return SUCCESS;
}

Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
HybridModelExecutor::ExecuteArgs &args) {
Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) {
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start");
GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_));
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End");
@@ -94,7 +95,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id));
}

HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs),
HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs),
"Failed to execute partitioned call.");
RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End");

@@ -103,7 +104,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
}

if (!model_->IsSingleOp()) {
Status ret = executor.Synchronize();
Status ret = root_graph_executor_->Synchronize();
if (ret != ge::SUCCESS) {
auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
@@ -123,7 +124,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
}

args.outputs.clear();
HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End");
return SUCCESS;
}


+ 2
- 1
ge/hybrid/executor/hybrid_model_executor.h View File

@@ -48,7 +48,7 @@ class HybridModelExecutor {
Status Execute(ExecuteArgs &args);

private:
Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args);
Status ExecuteGraphInternal(ExecuteArgs &args);
Status Cleanup();
Status InitExecutionContext();
static Status ResetExecutionContext(GraphExecutionContext &context);
@@ -58,6 +58,7 @@ class HybridModelExecutor {
uint32_t device_id_;
rtStream_t stream_;
GraphExecutionContext context_;
std::unique_ptr<SubgraphExecutor> root_graph_executor_;
};
} // namespace hybrid
} // namespace ge


+ 2
- 0
ge/hybrid/executor/hybrid_model_pipeline_executor.cc View File

@@ -172,6 +172,8 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin
config_.num_executors = kNumExecutors;
config_.num_stages = model_->GetRootGraphItem()->NumGroups();
config_.device_id = device_id_;
config_.iteration_end = 0;
config_.rt_context = nullptr;
}

Status StageExecutor::InitExecutionContext() {


+ 79
- 8
ge/hybrid/executor/node_state.cc View File

@@ -19,8 +19,9 @@
#include "framework/common/debug/log.h"
#include "graph/compute_graph.h"
#include "graph/utils/tensor_utils.h"
#include "hybrid_execution_context.h"
#include "subgraph_context.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/task_context.h"

#define INC_ITERATION_COUNT(iteration) \
do { \
@@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex
this->op_desc_ = node_item.node->GetOpDesc();
}

Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) {
GE_CHECK_NOTNULL(frame_state);
group_ = group;
frame_state_ = frame_state;
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
return SUCCESS;
}

Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
if (node_item_->IsMergeOp()) {
GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size());
@@ -314,15 +325,75 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_;
}

void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) {
const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) {
return items.second.count(idx) > 0;
};
return std::any_of(items.begin(), items.end(), is_exist);
};

if (root_tensor_values_.count(input_idx) > 0) {
return;
}

if (is_persist_tensor(node_item_->root_data_, input_idx)) {
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
} else if (is_persist_tensor(node_item_->enter_data_, input_idx)) {
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
}
}

void NodeState::UpdatePersistTensor() {
const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) {
for (const auto &item : items) {
for (const auto idx : item.second) {
UpdatePersistTensor(idx);
}
}
};

if (root_tensor_values_.empty()) {
return;
}

update_tensor(node_item_->root_data_);
if (iteration_count_ > 0) {
update_tensor(node_item_->enter_data_);
}
}

void NodeState::UpdatePersistTensor(int input_idx) {
const auto it = root_tensor_values_.find(input_idx);
if (it == root_tensor_values_.end()) {
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx);
return;
}

auto tensor = task_context_->MutableInput(input_idx);
if (tensor == nullptr) {
GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx);
return;
}

*tensor = it->second;
GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx);
}

void NodeState::ResetContext(uint64_t iteration) {
switch_index_ = -1;
subgraph_context_->ResetContext(node_item_->node);
if (iteration == 0) {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
} else {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size());
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());

data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
if (iteration > 0) {
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size());
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size());
}

iteration_count_ = iteration;


+ 12
- 8
ge/hybrid/executor/node_state.h View File

@@ -100,6 +100,8 @@ struct NodeState {
NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context);
~NodeState() = default;

Status Init(int group, const shared_ptr<FrameState> &frame_state);

OpDesc *GetOpDesc() const {
return op_desc_.get();
}
@@ -129,6 +131,9 @@ struct NodeState {
void RunStreamActive();
void RunNextIteration();

void SavePersistTensor(int input_idx, const TensorValue &tensor);
void UpdatePersistTensor();

Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;

void SetScheduleFuture(std::future<Status> &&future);
@@ -150,18 +155,10 @@ struct NodeState {
return merge_index_;
}

void SetGroup(int group) {
group_ = group;
}

int GetGroup() const {
return group_;
}

void SetFrameState(const shared_ptr<FrameState> &frame_state) {
frame_state_ = frame_state;
}

const shared_ptr<NodeTask> &GetKernelTask() const {
return kernel_task_;
}
@@ -181,12 +178,17 @@ struct NodeState {
void SetTaskContext(std::shared_ptr<TaskContext> &task_context);
std::shared_ptr<TaskContext> GetTaskContext();

void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; }

bool MaySkipShapeInference() const { return skip_infershape_; }

private:
bool IsScheduleReady() const;
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void ResetContext(uint64_t iteration);
void ScheduleContext(const NodeState &node_state);
void UpdatePersistTensor(int input_idx);

const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -199,6 +201,7 @@ struct NodeState {

std::future<Status> schedule_future_;
std::shared_ptr<FrameState> frame_state_;
std::map<int, TensorValue> root_tensor_values_;
uint64_t active_count_ = 0;
uint64_t iteration_count_ = 0;
uint32_t ctrl_scheduled_ = 0;
@@ -206,6 +209,7 @@ struct NodeState {
int merge_index_ = -1; // Use for Execute (Reset after Executed).
int switch_index_ = -1; // Use for Schedule (Reset after Prepared).
int group_ = -1;
bool skip_infershape_ = false;
};
} // namespace hybrid
} // namespace ge


+ 19
- 8
ge/hybrid/executor/subgraph_context.cc View File

@@ -19,7 +19,7 @@

namespace ge {
namespace hybrid {
SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context)
: graph_item_(graph_item), execution_context_(execution_context) {
}

@@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
return nullptr;
}

return CreateNodeState(node_item);
}

NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) {
GELOGD("[%s] lock for write", node_item->NodeName().c_str());
if (mmRWLockWRLock(&rw_lock_) != EN_OK) {
REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str());
GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str());
return nullptr;
}

auto &node_state = node_states_[node_item];
if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
node_state->SetFrameState(GetOrCreateFrameState(*node_item));
node_state->SetGroup(group_);
(void)guard;
}
do {
if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str());
REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str());
break;
}
(void)guard;
}
} while (0);

GELOGD("[%s] unlock for write", node_item->NodeName().c_str());
if (mmWRLockUnLock(&rw_lock_) != EN_OK) {
REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str());


+ 3
- 2
ge/hybrid/executor/subgraph_context.h View File

@@ -30,7 +30,7 @@ namespace ge {
namespace hybrid {
class SubgraphContext {
public:
explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context);
explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context);
~SubgraphContext();

Status Init();
@@ -51,10 +51,11 @@ class SubgraphContext {
void NodeDone(const NodePtr &node);

private:
NodeStatePtr CreateNodeState(const NodeItem *node_item);
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock
friend class TaskContext;
const GraphItem *graph_item_;
const GraphExecutionContext *execution_context_;
GraphExecutionContext *execution_context_;
mmRWLock_t rw_lock_;
std::vector<TensorValue> all_inputs_;
std::vector<TensorValue> all_outputs_;


+ 10
- 15
ge/hybrid/executor/subgraph_executor.cc View File

@@ -103,6 +103,13 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue
auto node_state = subgraph_context_->GetOrCreateNodeState(input_node);
GE_CHECK_NOTNULL(node_state);
node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc);
auto op_desc = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex);
GE_CHECK_NOTNULL(output_desc);
output_desc->SetShape(tensor_desc->GetShape());
output_desc->SetOriginShape(tensor_desc->GetOriginShape());
node_state->SetSkipInferShape(true);
}
}

@@ -175,16 +182,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue
GE_CHECK_NOTNULL(node_state);
node_state->SetKernelTask(node_item->kernel_task);

known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(known_shape_task_context_);
node_state->SetTaskContext(known_shape_task_context_);

std::function<void()> callback;
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback),
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback),
"[%s] Failed to execute node [%s] for known subgraph.",
graph_item_->GetName().c_str(),
known_shape_task_context_->GetNodeName());
node_state->GetName().c_str());

GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str());
return SUCCESS;
@@ -271,16 +274,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {
} else {
node_state->SetKernelTask(node_item.kernel_task);
}
auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state->GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
return AfterPrepared(p_node_state);
}
@@ -480,19 +479,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
} else {
node_state.SetKernelTask(node_item.kernel_task);
}
auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state.GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str());
REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state.SetTaskContext(shared_task_context);
GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start");
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end");
return SUCCESS;
}


+ 2
- 1
ge/hybrid/executor/subgraph_executor.h View File

@@ -41,6 +41,8 @@ class SubgraphExecutor {

Status PartialExecuteAsync(int task_group);

void ReleaseContext() { subgraph_context_.reset(nullptr); }

/**
* Execute subgraph async, output tensor address(not data) and output tensor descriptions are
* valid after this method returned
@@ -125,7 +127,6 @@ class SubgraphExecutor {
ThreadPool pre_run_pool_;
BlockingQueue<NodeState *> ready_queue_;
std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
std::shared_ptr<TaskContext> known_shape_task_context_;

std::mutex mu_; // Guard for prepare_queues_.
std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_;


+ 2
- 1
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
auto executor = node_item.node_executor;
GE_CHECK_NOTNULL(executor);
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start");
node_state.UpdatePersistTensor();
GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.",
node_state.GetName().c_str());
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End");
@@ -427,7 +428,7 @@ Status ExecutionEngine::ValidateInputTensors(const NodeState &node_state, const
continue;
}

int64_t expected_size;
int64_t expected_size = 0;
(void)TensorUtils::GetSize(*tensor_desc, expected_size);
GELOGD("[%s] Input[%d] expects [%ld] bytes.", task_context.GetNodeName(), i, expected_size);
auto size_diff = expected_size - static_cast<int64_t>(input_tensor->GetSize());


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save